From: Geliang Tang <tanggeliang@kylinos.cn>
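Extend the TLS subsystem to support MPTCP sockets. Make the MPTCP
counterparts of the TCP helpers used by TLS callable from outside
net/mptcp (splitting mptcp_sendmsg_locked() out of mptcp_sendmsg() and
adding a new mptcp_read_done()), and have TLS call them whenever
sk->sk_protocol is IPPROTO_MPTCP:
- mptcp_sendmsg_locked() for TLS record transmission;
- mptcp_inq_hint() and mptcp_recv_skb() for receive-side handling;
- mptcp_read_sock() and mptcp_read_done() for reading data;
- mptcp_disconnect() for disconnecting the socket.
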
Signed-off-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Geliang Tang <tanggeliang@kylinos.cn>
---
include/net/mptcp.h | 42 ++++++++++++++++++++++++++++
net/mptcp/protocol.c | 66 ++++++++++++++++++++++++++++++++++++--------
net/tls/tls_main.c | 6 +++-
net/tls/tls_strp.c | 20 +++++++++++---
4 files changed, 118 insertions(+), 16 deletions(-)
diff --git a/include/net/mptcp.h b/include/net/mptcp.h
index 4cf59e83c1c5..ffbbeb08a8be 100644
--- a/include/net/mptcp.h
+++ b/include/net/mptcp.h
@@ -237,6 +237,20 @@ static inline __be32 mptcp_reset_option(const struct sk_buff *skb)
}
void mptcp_active_detect_blackhole(struct sock *sk, bool expired);
+
+/* MPTCP counterparts of the TCP helpers used by the TLS ULP (net/tls). */
+int mptcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len);
+
+unsigned int mptcp_inq_hint(const struct sock *sk);
+
+struct sk_buff *mptcp_recv_skb(struct sock *sk, u32 *off);
+
+int mptcp_read_sock(struct sock *sk, read_descriptor_t *desc,
+ sk_read_actor_t recv_actor);
+
+void mptcp_read_done(struct sock *sk, size_t len);
+
+int mptcp_disconnect(struct sock *sk, int flags);
#else
static inline void mptcp_init(void)
@@ -323,6 +337,39 @@ static inline struct request_sock *mptcp_subflow_reqsk_alloc(const struct reques
static inline __be32 mptcp_reset_option(const struct sk_buff *skb) { return htonl(0u); }
static inline void mptcp_active_detect_blackhole(struct sock *sk, bool expired) { }
+
+/* Without CONFIG_MPTCP no IPPROTO_MPTCP socket can exist, so the TLS
+ * callers should never reach these stubs. Report an error rather than
+ * a fake success should that ever change.
+ */
+static inline int mptcp_sendmsg_locked(struct sock *sk, struct msghdr *msg,
+ size_t len)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline unsigned int mptcp_inq_hint(const struct sock *sk)
+{
+ return 0;
+}
+
+static inline struct sk_buff *mptcp_recv_skb(struct sock *sk, u32 *off)
+{
+ return NULL;
+}
+
+static inline int mptcp_read_sock(struct sock *sk, read_descriptor_t *desc,
+ sk_read_actor_t recv_actor)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void mptcp_read_done(struct sock *sk, size_t len) { }
+
+static inline int mptcp_disconnect(struct sock *sk, int flags)
+{
+ return -EOPNOTSUPP;
+}
#endif /* CONFIG_MPTCP */
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index b31724523ed5..e5e2ba1cd976 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -1752,8 +1752,6 @@ static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk, bool
}
}
-static int mptcp_disconnect(struct sock *sk, int flags);
-
static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
size_t len, int *copied_syn)
{
@@ -1862,7 +1860,7 @@ static void mptcp_rps_record_subflows(const struct mptcp_sock *msk)
}
}
-static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+int mptcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len)
{
struct mptcp_sock *msk = mptcp_sk(sk);
struct page_frag *pfrag;
@@ -1873,8 +1871,8 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
/* silently ignore everything else */
msg->msg_flags &= MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_FASTOPEN;
- lock_sock(sk);
+ msk_owned_by_me(msk);
mptcp_rps_record_subflows(msk);
if (unlikely(inet_test_bit(DEFER_CONNECT, sk) ||
@@ -1982,7 +1980,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
__mptcp_push_pending(sk, msg->msg_flags);
out:
- release_sock(sk);
return copied;
do_error:
@@ -1993,6 +1990,18 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
goto out;
}
+EXPORT_SYMBOL_GPL(mptcp_sendmsg_locked);
+static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+{
+ int ret;
+
+ lock_sock(sk);
+ ret = mptcp_sendmsg_locked(sk, msg, len);
+ release_sock(sk);
+
+ return ret;
+}
+
static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied);
static void mptcp_eat_recv_skb(struct sock *sk, struct sk_buff *skb)
@@ -2224,7 +2233,7 @@ static bool mptcp_move_skbs(struct sock *sk)
return enqueued;
}
-static unsigned int mptcp_inq_hint(const struct sock *sk)
+unsigned int mptcp_inq_hint(const struct sock *sk)
{
const struct mptcp_sock *msk = mptcp_sk(sk);
const struct sk_buff *skb;
@@ -3329,7 +3338,7 @@ static void mptcp_destroy_common(struct mptcp_sock *msk)
mptcp_pm_destroy(msk);
}
-static int mptcp_disconnect(struct sock *sk, int flags)
+int mptcp_disconnect(struct sock *sk, int flags)
{
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -4271,7 +4280,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
return mask;
}
-static struct sk_buff *mptcp_recv_skb(struct sock *sk, u32 *off)
+struct sk_buff *mptcp_recv_skb(struct sock *sk, u32 *off)
{
struct mptcp_sock *msk = mptcp_sk(sk);
struct sk_buff *skb;
@@ -4295,8 +4304,8 @@ static struct sk_buff *mptcp_recv_skb(struct sock *sk, u32 *off)
* Note:
* - It is assumed that the socket was locked by the caller.
*/
-static int mptcp_read_sock(struct sock *sk, read_descriptor_t *desc,
- sk_read_actor_t recv_actor)
+int mptcp_read_sock(struct sock *sk, read_descriptor_t *desc,
+ sk_read_actor_t recv_actor)
{
struct mptcp_sock *msk = mptcp_sk(sk);
size_t len = sk->sk_rcvbuf;
@@ -4453,6 +4462,66 @@ static ssize_t mptcp_splice_read(struct socket *sock, loff_t *ppos,
return ret;
}
+/**
+ * mptcp_read_done - consume @len bytes from the MPTCP receive queue
+ * @sk: MPTCP socket, owned (locked) by the caller
+ * @len: number of bytes the caller has already processed
+ *
+ * MPTCP counterpart of tcp_read_done(), used by the TLS stream parser.
+ * Advances the per-skb offset/map_seq and msk->bytes_consumed, frees
+ * skbs that are fully consumed and, when anything was consumed, calls
+ * mptcp_cleanup_rbuf() so that ACKs can be sent.
+ */
+void mptcp_read_done(struct sock *sk, size_t len)
+{
+ struct mptcp_sock *msk = mptcp_sk(sk);
+ struct sk_buff *skb;
+ size_t left;
+ u32 offset;
+
+ msk_owned_by_me(msk);
+
+ if (sk->sk_state == TCP_LISTEN)
+ return;
+
+ left = len;
+ while (left && (skb = mptcp_recv_skb(sk, &offset)) != NULL) {
+ size_t used = min_t(size_t, skb->len - offset, left);
+
+ left -= used;
+ msk->bytes_consumed += used;
+ MPTCP_SKB_CB(skb)->offset += used;
+ MPTCP_SKB_CB(skb)->map_seq += used;
+
+ /* Data remains in this skb: leave it queued. */
+ if (skb->len > offset + used)
+ break;
+
+ mptcp_eat_recv_skb(sk, skb);
+ }
+
+ mptcp_rcv_space_adjust(msk, len - left);
+
+ /* Clean up data we have read: This will do ACK frames. */
+ if (left != len) {
+ /* Result unused; presumably called only for its side
+ * effects on the receive queue - TODO confirm.
+ */
+ mptcp_recv_skb(sk, &offset);
+ mptcp_cleanup_rbuf(msk, len - left);
+ }
+}
+EXPORT_SYMBOL_GPL(mptcp_read_done);
+
+/* These helpers, defined earlier in this file, are called directly by
+ * net/tls on MPTCP sockets. Export them so that those calls also link
+ * when net/tls is built as a module.
+ */
+EXPORT_SYMBOL_GPL(mptcp_inq_hint);
+EXPORT_SYMBOL_GPL(mptcp_recv_skb);
+EXPORT_SYMBOL_GPL(mptcp_read_sock);
+EXPORT_SYMBOL_GPL(mptcp_disconnect);
+
static const struct proto_ops mptcp_stream_ops = {
.family = PF_INET,
.owner = THIS_MODULE,
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 56ce0bc8317b..7d7bde1702c1 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -194,7 +194,9 @@ int tls_push_sg(struct sock *sk,
bvec_set_page(&bvec, p, size, offset);
iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
- ret = tcp_sendmsg_locked(sk, &msg, size);
+ ret = sk->sk_protocol == IPPROTO_MPTCP ?
+ mptcp_sendmsg_locked(sk, &msg, size) :
+ tcp_sendmsg_locked(sk, &msg, size);
if (ret != size) {
if (ret > 0) {
@@ -907,6 +909,10 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
static int tls_disconnect(struct sock *sk, int flags)
{
+ /* Refuse disconnect on MPTCP sockets too, as for TCP: the TLS
+ * context (and possibly stream parser skbs taken from
+ * sk_receive_queue) is not torn down by the transport's disconnect.
+ */
return -EOPNOTSUPP;
}
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 98e12f0ff57e..3985e77f3351 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -132,6 +132,9 @@ int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
tls_strp_anchor_free(strp);
strp->anchor = skb;
- tcp_read_done(strp->sk, strp->stm.full_len);
+ if (strp->sk->sk_protocol == IPPROTO_MPTCP)
+ mptcp_read_done(strp->sk, strp->stm.full_len);
+ else
+ tcp_read_done(strp->sk, strp->stm.full_len);
strp->copy_mode = 1;
@@ -383,6 +386,9 @@ static int tls_strp_read_copyin(struct tls_strparser *strp)
desc.count = 1; /* give more than one skb per call */
/* sk should be locked here, so okay to do read_sock */
- tcp_read_sock(strp->sk, &desc, tls_strp_copyin);
+ if (strp->sk->sk_protocol == IPPROTO_MPTCP)
+ mptcp_read_sock(strp->sk, &desc, tls_strp_copyin);
+ else
+ tcp_read_sock(strp->sk, &desc, tls_strp_copyin);
return desc.error;
@@ -464,8 +470,17 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
struct sk_buff *first;
u32 offset;
- first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
- if (WARN_ON_ONCE(!first))
+ if (strp->sk->sk_protocol == IPPROTO_MPTCP)
+ first = mptcp_recv_skb(strp->sk, &offset);
+ else
+ first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
+ if (!first) {
+ /* For TCP an empty queue here is a bug. The MPTCP inq value
+ * is only a hint and may apparently be non-zero with nothing
+ * queued (TODO confirm), so do not warn for MPTCP.
+ */
+ WARN_ON_ONCE(strp->sk->sk_protocol != IPPROTO_MPTCP);
return;
+ }
/* Bestow the state onto the anchor */
@@ -490,7 +505,11 @@ bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
if (!strp->copy_mode && force_refresh) {
- if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) {
+ int inq = strp->sk->sk_protocol == IPPROTO_MPTCP ?
+ mptcp_inq_hint(strp->sk) :
+ tcp_inq(strp->sk);
+
+ if (unlikely(inq < strp->stm.full_len)) {
WRITE_ONCE(strp->msg_ready, 0);
memset(&strp->stm, 0, sizeof(strp->stm));
return false;
@@ -513,7 +532,9 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
{
int sz, inq;
- inq = tcp_inq(strp->sk);
+ inq = strp->sk->sk_protocol == IPPROTO_MPTCP ?
+ mptcp_inq_hint(strp->sk) :
+ tcp_inq(strp->sk);
if (inq < 1)
return 0;
@@ -586,6 +607,10 @@ void tls_strp_msg_done(struct tls_strparser *strp)
WARN_ON(!strp->stm.full_len);
- if (likely(!strp->copy_mode))
- tcp_read_done(strp->sk, strp->stm.full_len);
- else
+ if (likely(!strp->copy_mode)) {
+ if (strp->sk->sk_protocol == IPPROTO_MPTCP)
+ mptcp_read_done(strp->sk, strp->stm.full_len);
+ else
+ tcp_read_done(strp->sk, strp->stm.full_len);
+ } else {
tls_strp_flush_anchor_copy(strp);
+ }
--
2.51.0
© 2016 - 2025 Red Hat, Inc.