[PATCH net-next RFC 3/3] netlink: convert to getsockopt_iter

Breno Leitao posted 3 patches 1 week, 2 days ago
[PATCH net-next RFC 3/3] netlink: convert to getsockopt_iter
Posted by Breno Leitao 1 week, 2 days ago
Convert netlink's getsockopt implementation to use the new
getsockopt_iter callback with sockopt_t.

Key changes:
- Replace (char __user *optval, int __user *optlen) with sockopt_t *opt
- Use opt->optlen for buffer length (input) and returned size (output)
- Use copy_to_iter() instead of put_user()/copy_to_user()

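The optlen field allows callbacks to return a specific length value
independent of the bytes written via copy_to_iter(). For example,
NETLINK_LIST_MEMBERSHIPS always sets opt->optlen to the size of the full
group bitmap, regardless of how many bytes were actually copied out.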

This enables io_uring to call netlink's getsockopt with kernel buffers.

Signed-off-by: Breno Leitao <leitao@debian.org>
---
 net/netlink/af_netlink.c | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 8e5151f0c6e46..8a195eb1ef761 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -39,6 +39,7 @@
 #include <linux/fs.h>
 #include <linux/slab.h>
 #include <linux/uaccess.h>
+#include <linux/uio.h>
 #include <linux/skbuff.h>
 #include <linux/netdevice.h>
 #include <linux/rtnetlink.h>
@@ -1716,7 +1717,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
 }
 
 static int netlink_getsockopt(struct socket *sock, int level, int optname,
-			      char __user *optval, int __user *optlen)
+			      sockopt_t *opt)
 {
 	struct sock *sk = sock->sk;
 	struct netlink_sock *nlk = nlk_sk(sk);
@@ -1726,8 +1727,7 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
 	if (level != SOL_NETLINK)
 		return -ENOPROTOOPT;
 
-	if (get_user(len, optlen))
-		return -EFAULT;
+	len = opt->optlen;
 	if (len < 0)
 		return -EINVAL;
 
@@ -1743,6 +1743,8 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
 		break;
 	case NETLINK_LIST_MEMBERSHIPS: {
 		int pos, idx, shift, err = 0;
+		u32 group_val;
+		size_t size;
 
 		netlink_lock_table();
 		for (pos = 0; pos * 8 < nlk->ngroups; pos += sizeof(u32)) {
@@ -1751,14 +1753,14 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
 
 			idx = pos / sizeof(unsigned long);
 			shift = (pos % sizeof(unsigned long)) * 8;
-			if (put_user((u32)(nlk->groups[idx] >> shift),
-				     (u32 __user *)(optval + pos))) {
+			group_val = (u32)(nlk->groups[idx] >> shift);
+			size = copy_to_iter(&group_val, sizeof(group_val), &opt->iter);
+			if (size != sizeof(group_val)) {
 				err = -EFAULT;
 				break;
 			}
 		}
-		if (put_user(ALIGN(BITS_TO_BYTES(nlk->ngroups), sizeof(u32)), optlen))
-			err = -EFAULT;
+		opt->optlen = ALIGN(BITS_TO_BYTES(nlk->ngroups), sizeof(u32));
 		netlink_unlock_table();
 		return err;
 	}
@@ -1784,10 +1786,10 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
 	len = sizeof(int);
 	val = test_bit(flag, &nlk->flags);
 
-	if (put_user(len, optlen) ||
-	    copy_to_user(optval, &val, len))
+	if (copy_to_iter(&val, len, &opt->iter) != len)
 		return -EFAULT;
 
+	opt->optlen = sizeof(int);
 	return 0;
 }
 
@@ -2813,7 +2815,7 @@ static const struct proto_ops netlink_ops = {
 	.listen =	sock_no_listen,
 	.shutdown =	sock_no_shutdown,
 	.setsockopt =	netlink_setsockopt,
-	.getsockopt =	netlink_getsockopt,
+	.getsockopt_iter =	netlink_getsockopt,
 	.sendmsg =	netlink_sendmsg,
 	.recvmsg =	netlink_recvmsg,
 	.mmap =		sock_no_mmap,

-- 
2.47.3