From: Geliang Tang <tanggeliang@kylinos.cn>
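Add a pointer to struct tls_prot_ops, named 'ops', to struct tls_proto.

In tls_build_proto(), look up the ops for the socket protocol with
tls_prot_ops_find(), so that proto->ops is set to either 'tls_mptcp_ops'
or 'tls_tcp_ops'. A newly built tls_proto takes a reference on
ops->owner, and that reference is dropped when the tls_proto is freed.
In addition, tls_init() takes a per-socket reference on ops->owner,
which tls_ctx_free() drops. This fixes a module reference counting bug
where each socket release called module_put() even when the socket
reused an existing tls_proto and had never taken a matching reference.
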
Co-developed-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Geliang Tang <tanggeliang@kylinos.cn>
---
include/net/tls.h | 1 +
net/tls/tls_main.c | 25 ++++++++++++++++++++++++-
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/include/net/tls.h b/include/net/tls.h
index 0865932d8cc7..ee24f9d24324 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -249,6 +249,7 @@ struct tls_proto {
refcount_t refcnt;
struct list_head list;
const struct proto *prot;
+ const struct tls_prot_ops *ops;
struct proto prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
struct proto_ops proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
};
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 1ee891405cde..68308a42899b 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -145,6 +145,7 @@ static void tls_proto_cleanup(void)
if (refcount_dec_and_test(&prot->refcnt)) {
list_del_rcu(&prot->list);
synchronize_rcu();
+ module_put(prot->ops->owner);
kfree(prot);
}
}
@@ -367,9 +368,11 @@ void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
return;
if (ctx->proto) {
+ module_put(ctx->proto->ops->owner);
if (refcount_dec_and_test(&ctx->proto->refcnt)) {
list_del_rcu(&ctx->proto->list);
synchronize_rcu();
+ module_put(ctx->proto->ops->owner);
kfree(ctx->proto);
}
}
@@ -1021,6 +1024,7 @@ static struct tls_proto *tls_build_proto(struct sock *sk)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
struct proto *prot = READ_ONCE(sk->sk_prot);
+ struct tls_prot_ops *ops;
struct tls_proto *proto;
mutex_lock(&tls_proto_mutex);
@@ -1028,11 +1032,22 @@ static struct tls_proto *tls_build_proto(struct sock *sk)
if (proto)
goto out;
+ rcu_read_lock();
+ ops = tls_prot_ops_find(sk->sk_protocol);
+ if (!ops || !try_module_get(ops->owner)) {
+ rcu_read_unlock();
+ goto out;
+ }
+ rcu_read_unlock();
+
proto = kzalloc(sizeof(*proto), GFP_KERNEL);
- if (!proto)
+ if (!proto) {
+ module_put(ops->owner);
goto out;
+ }
proto->prot = prot;
+ proto->ops = ops;
refcount_set(&proto->refcnt, 2);
build_protos(proto->prots[ip_ver], prot);
build_proto_ops(proto->proto_ops[ip_ver],
@@ -1099,9 +1114,15 @@ static int tls_init(struct sock *sk)
if (!proto)
return -ENOMEM;
+ if (!try_module_get(proto->ops->owner)) {
+ refcount_dec(&proto->refcnt);
+ return -ENOENT;
+ }
+
#ifdef CONFIG_TLS_TOE
if (tls_toe_bypass(sk, proto)) {
refcount_dec(&proto->refcnt);
+ module_put(proto->ops->owner);
return 0;
}
#endif
@@ -1114,6 +1135,7 @@ static int tls_init(struct sock *sk)
*/
if (sk->sk_state != TCP_ESTABLISHED) {
refcount_dec(&proto->refcnt);
+ module_put(proto->ops->owner);
return -ENOTCONN;
}
@@ -1122,6 +1144,7 @@ static int tls_init(struct sock *sk)
ctx = tls_ctx_create(sk, proto);
if (!ctx) {
refcount_dec(&proto->refcnt);
+ module_put(proto->ops->owner);
rc = -ENOMEM;
goto out;
}
--
2.51.0