[PATCH net] net/tls: preserve sk_msg sg.copy when splitting records

Yiming Qian posted 1 patch 3 days, 17 hours ago
include/linux/skmsg.h | 9 +++++++++
net/tls/tls_sw.c      | 7 +++++++
2 files changed, 16 insertions(+)
[PATCH net] net/tls: preserve sk_msg sg.copy when splitting records
Posted by Yiming Qian 3 days, 17 hours ago
tls_split_open_record() copies scatterlist entries from the current
plaintext sk_msg into a newly allocated plaintext sk_msg when an open
record is split.

A scatterlist entry and its corresponding msg->sg.copy bit together form
one ownership record. Splice-backed entries are created by
sk_msg_page_add() with the copy bit set so that
sk_msg_compute_data_pointers() does not expose them as writable BPF
msg->data.

The split path used memcpy() to copy both the partial entry and the whole
tail entries, but left the copy bitmap of the new sk_msg clear. A
subsequent SK_MSG verdict on the split tail could therefore receive a
writable data pointer to a page whose copy bit should have kept it from
being exposed as writable, allowing BPF to overwrite externally owned
page cache.

Add a helper for copying one sg.copy bit and use it for the partial tmp
entry and for each copied tail entry.

Fixes: d3b18ad31f93 ("tls: add bpf support to sk_msg handling")
Reported-by: Yiming Qian <yimingqian591@gmail.com>
Reported-by: Keenan Dong <keenanat2000@gmail.com>
Signed-off-by: Yiming Qian <yimingqian591@gmail.com>
Signed-off-by: Keenan Dong <keenanat2000@gmail.com>
---
 include/linux/skmsg.h | 9 +++++++++
 net/tls/tls_sw.c      | 7 +++++++
 2 files changed, 16 insertions(+)

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 19f4f253b4f90..f3988ce2219db 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -283,6 +283,15 @@ static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
 	} while (1);
 }
 
+static inline void sk_msg_sg_copy_one(struct sk_msg *dst, u32 dst_i,
+				      const struct sk_msg *src, u32 src_i)
+{
+	/* Make bit dst_i of dst->sg.copy match bit src_i of src->sg.copy,
+	 * setting it when the source bit is set and clearing it otherwise.
+	 */
+	__assign_bit(dst_i, dst->sg.copy, test_bit(src_i, src->sg.copy));
+}
+
 static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
 {
 	sk_msg_sg_copy(msg, start, true);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 964ebc268ee46..434753de8aadd 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -623,6 +623,7 @@ static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
 	struct scatterlist *sge, *osge, *nsge;
 	u32 orig_size = msg_opl->sg.size;
 	struct scatterlist tmp = { };
+	u32 tmp_i = NR_MSG_FRAG_IDS;
 	struct sk_msg *msg_npl;
 	struct tls_rec *new;
 	int ret;
@@ -644,6 +645,7 @@ static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
 		if (sge->length > apply) {
 			u32 len = sge->length - apply;
 
+			tmp_i = i;
 			get_page(sg_page(sge));
 			sg_set_page(&tmp, sg_page(sge), len,
 				    sge->offset + apply);
@@ -675,6 +677,10 @@ static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
 	nsge = sk_msg_elem(msg_npl, j);
 	if (tmp.length) {
 		memcpy(nsge, &tmp, sizeof(*nsge));
+		if (WARN_ON_ONCE(tmp_i == NR_MSG_FRAG_IDS))
+			__clear_bit(j, msg_npl->sg.copy);
+		else
+			sk_msg_sg_copy_one(msg_npl, j, msg_opl, tmp_i);
 		sk_msg_iter_var_next(j);
 		nsge = sk_msg_elem(msg_npl, j);
 	}
@@ -682,6 +688,7 @@ static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
 	osge = sk_msg_elem(msg_opl, i);
 	while (osge->length) {
 		memcpy(nsge, osge, sizeof(*nsge));
+		sk_msg_sg_copy_one(msg_npl, j, msg_opl, i);
 		sg_unmark_end(nsge);
 		sk_msg_iter_var_next(i);
 		sk_msg_iter_var_next(j);