[PATCH mptcp-next v16 03/16] tls: add tls_prot_ops pointer to tls_proto

Geliang Tang posted 16 patches 1 month, 3 weeks ago
There is a newer version of this series
[PATCH mptcp-next v16 03/16] tls: add tls_prot_ops pointer to tls_proto
Posted by Geliang Tang 1 month, 3 weeks ago
From: Geliang Tang <tanggeliang@kylinos.cn>

Add a pointer to struct tls_prot_ops, named 'ops', to struct
tls_proto.

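In tls_build_proto(), look up the tls_prot_ops matching
sk->sk_protocol (either 'tls_mptcp_ops' or 'tls_tcp_ops') and store it
in proto->ops. If no matching ops is found, or a reference on its owner
module cannot be taken, no tls_proto is built.

Module references on ops->owner are paired as follows:
 - each newly built tls_proto holds one reference, dropped when that
   tls_proto is freed (tls_proto_cleanup() or tls_ctx_free());
 - tls_init() takes one per-socket reference, released by
   tls_ctx_free() or on tls_init()'s early-exit paths.

This fixes a module reference counting imbalance where each socket
release called module_put() without a matching get when an existing
tls_proto was reused.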

Co-developed-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Geliang Tang <tanggeliang@kylinos.cn>
---
 include/net/tls.h  |  1 +
 net/tls/tls_main.c | 25 ++++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/include/net/tls.h b/include/net/tls.h
index 0865932d8cc7..ee24f9d24324 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -249,6 +249,7 @@ struct tls_proto {
 	refcount_t			refcnt;
 	struct list_head		list;
 	const struct proto		*prot;
+	const struct tls_prot_ops	*ops;
 	struct proto prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 	struct proto_ops proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 };
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 1ee891405cde..68308a42899b 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -145,6 +145,7 @@ static void tls_proto_cleanup(void)
 		if (refcount_dec_and_test(&prot->refcnt)) {
 			list_del_rcu(&prot->list);
 			synchronize_rcu();
+			module_put(prot->ops->owner);
 			kfree(prot);
 		}
 	}
@@ -367,9 +368,11 @@ void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 		return;
 
 	if (ctx->proto) {
+		module_put(ctx->proto->ops->owner);
 		if (refcount_dec_and_test(&ctx->proto->refcnt)) {
 			list_del_rcu(&ctx->proto->list);
 			synchronize_rcu();
+			module_put(ctx->proto->ops->owner);
 			kfree(ctx->proto);
 		}
 	}
@@ -1021,6 +1024,7 @@ static struct tls_proto *tls_build_proto(struct sock *sk)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 	struct proto *prot = READ_ONCE(sk->sk_prot);
+	struct tls_prot_ops *ops;
 	struct tls_proto *proto;
 
 	mutex_lock(&tls_proto_mutex);
@@ -1028,11 +1032,22 @@ static struct tls_proto *tls_build_proto(struct sock *sk)
 	if (proto)
 		goto out;
 
+	rcu_read_lock();
+	ops = tls_prot_ops_find(sk->sk_protocol);
+	if (!ops || !try_module_get(ops->owner)) {
+		rcu_read_unlock();
+		goto out;
+	}
+	rcu_read_unlock();
+
 	proto = kzalloc(sizeof(*proto), GFP_KERNEL);
-	if (!proto)
+	if (!proto) {
+		module_put(ops->owner);
 		goto out;
+	}
 
 	proto->prot = prot;
+	proto->ops = ops;
 	refcount_set(&proto->refcnt, 2);
 	build_protos(proto->prots[ip_ver], prot);
 	build_proto_ops(proto->proto_ops[ip_ver],
@@ -1099,9 +1114,15 @@ static int tls_init(struct sock *sk)
 	if (!proto)
 		return -ENOMEM;
 
+	if (!try_module_get(proto->ops->owner)) {
+		refcount_dec(&proto->refcnt);
+		return -ENOENT;
+	}
+
 #ifdef CONFIG_TLS_TOE
 	if (tls_toe_bypass(sk, proto)) {
 		refcount_dec(&proto->refcnt);
+		module_put(proto->ops->owner);
 		return 0;
 	}
 #endif
@@ -1114,6 +1135,7 @@ static int tls_init(struct sock *sk)
 	 */
 	if (sk->sk_state != TCP_ESTABLISHED) {
 		refcount_dec(&proto->refcnt);
+		module_put(proto->ops->owner);
 		return -ENOTCONN;
 	}
 
@@ -1122,6 +1144,7 @@ static int tls_init(struct sock *sk)
 	ctx = tls_ctx_create(sk, proto);
 	if (!ctx) {
 		refcount_dec(&proto->refcnt);
+		module_put(proto->ops->owner);
 		rc = -ENOMEM;
 		goto out;
 	}
-- 
2.51.0