From: Bobby Eshleman
Date: Wed, 31 May 2023 00:35:11 +0000
Subject: [PATCH RFC net-next v3 7/8] vsock: Add lockless sendmsg() support
Message-Id: <20230413-b4-vsock-dgram-v3-7-c2414413ef6a@bytedance.com>
References: <20230413-b4-vsock-dgram-v3-0-c2414413ef6a@bytedance.com>
In-Reply-To: <20230413-b4-vsock-dgram-v3-0-c2414413ef6a@bytedance.com>
To: Stefan Hajnoczi, Stefano Garzarella, "Michael S. Tsirkin", Jason Wang,
 "David S. Miller", Eric Dumazet, Jakub Kicinski, Paolo Abeni, "K. Y.
 Srinivasan", Haiyang Zhang, Wei Liu, Dexuan Cui, Bryan Tan, Vishnu Dasa,
 VMware PV-Drivers Reviewers
Cc: kvm@vger.kernel.org, virtualization@lists.linux-foundation.org,
 netdev@vger.kernel.org, linux-kernel@vger.kernel.org,
 linux-hyperv@vger.kernel.org, Bobby Eshleman
X-Mailer: b4 0.12.2
X-Mailing-List: linux-kernel@vger.kernel.org

Because the dgram sendmsg() path for AF_VSOCK acquires the socket lock,
it does not scale when many senders share a socket.

Prior to this patch, the socket lock protects both reads and writes of
the local_addr, remote_addr, transport, and buffer size variables of a
vsock socket. What follows are the new protection schemes for these
fields, which ensure a race-free and usually lock-free multi-sender
sendmsg() path for vsock dgrams.

- local_addr
    local_addr changes as a result of binding a socket. The write path
    for local_addr is bind() and the various vsock_auto_bind() call
    sites. After a socket has been bound via vsock_auto_bind() or
    bind(), subsequent calls to bind()/vsock_auto_bind() do not write
    to local_addr again: bind() rejects the user request and
    vsock_auto_bind() exits early. Therefore, local_addr cannot change
    while a parallel thread is in sendmsg(), and lock-free reads of
    local_addr in sendmsg() are safe.

    Change: only acquire the lock for auto-binding, as needed, in
    sendmsg().

- buffer size variables
    Not used by dgram, so they do not need protection. No change.

- remote_addr and transport
    A remote_addr update may result in a changed transport, yet the
    vsock send path should be able to read both fields lock-free and
    coherently. This patch therefore packages the two fields into a new
    struct vsock_remote_info that is referenced by an RCU-protected
    pointer.

    Writes are synchronized as usual by the socket lock. Reads only
    take place in RCU read-side critical sections. When remote_addr or
    transport is updated, a new remote info is allocated.
    Old readers still see the old coherent remote_addr/transport pair,
    and new readers will refer to the new coherent pair. The coherency
    between remote_addr and transport previously provided by the socket
    lock alone is now also preserved by RCU, but with a highly scalable
    lock-free read side.

Helpers are introduced for accessing and updating the new pointer.

The new structure contains an rcu_head so that kfree_rcu() can be
used. This removes the need for writers to call synchronize_rcu()
before freeing old structures, which is simply more efficient and
reduces code churn where remote_addr/transport are already being
updated inside RCU read-side sections.

Only virtio has been tested, but updates were necessary to the VMCI and
hyperv code. Unfortunately the author does not have access to
VMCI/hyperv systems, so those changes are untested.

Perf Tests (results from patch v2)
vCPUS: 16
Threads: 16
Payload: 4KB
Test Runs: 5
Type: SOCK_DGRAM

Before: 245.2 MB/s
After: 509.2 MB/s (+107%)

Notably, on the same test system, vsock dgram even outperforms
multi-threaded UDP over virtio-net with vhost and MQ support enabled.

Throughput metrics for single-threaded SOCK_DGRAM and
single/multi-threaded SOCK_STREAM showed no statistically significant
throughput changes (lowest p-value reaching 0.27), with the mean
difference ranging from -5% to +1%.
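
For readers less familiar with the pattern, the remote_addr/transport
scheme described above can be sketched in userspace C. This is only an
illustration under stated assumptions, not the code in this patch: C11
atomics stand in for rcu_dereference()/rcu_replace_pointer(), a
writer-only retire list stands in for kfree_rcu(), and every name
(struct remote_info, set_remote_info(), read_remote_info(), and so on)
is hypothetical.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdlib.h>

/* Hypothetical userspace stand-ins; none of these names are kernel APIs. */
struct transport {
	const char *name;
};

struct remote_info {
	unsigned int cid;
	unsigned int port;
	const struct transport *transport;
	struct remote_info *next_retired;	/* plays the role of the rcu_head */
};

struct sock_state {
	_Atomic(struct remote_info *) remote_info;	/* the published pair */
	struct remote_info *retired;			/* writer-only list */
};

static void sock_state_init(struct sock_state *s)
{
	atomic_init(&s->remote_info, NULL);
	s->retired = NULL;
}

/* Writer side. Callers must serialise writers, as the socket lock does in
 * the patch. The old pair is parked on a retire list rather than freed,
 * because this sketch has no grace-period mechanism like kfree_rcu().
 */
static int set_remote_info(struct sock_state *s, const struct transport *t,
			   unsigned int cid, unsigned int port)
{
	struct remote_info *new, *old;

	new = malloc(sizeof(*new));
	if (!new)
		return -1;

	/* Fill in every field before the pointer becomes visible. */
	new->cid = cid;
	new->port = port;
	new->transport = t;
	new->next_retired = NULL;

	old = atomic_exchange_explicit(&s->remote_info, new,
				       memory_order_acq_rel);
	if (old) {
		old->next_retired = s->retired;
		s->retired = old;
	}
	return 0;
}

/* Reader side: a single acquire load yields one coherent snapshot. */
static int read_remote_info(struct sock_state *s, struct remote_info *out)
{
	struct remote_info *cur;

	cur = atomic_load_explicit(&s->remote_info, memory_order_acquire);
	if (!cur)
		return -1;

	*out = *cur;
	return 0;
}

/* Teardown, valid only once no reader can still be running. */
static void sock_state_destroy(struct sock_state *s)
{
	struct remote_info *cur = atomic_exchange(&s->remote_info, NULL);

	/* The live pair is never linked into the retire list. */
	assert(!cur || !cur->next_retired);
	free(cur);

	while (s->retired) {
		struct remote_info *next = s->retired->next_retired;

		free(s->retired);
		s->retired = next;
	}
}
```

The point of the sketch is that a reader never observes the address of
one update combined with the transport of another, because both live in
one object that is swapped as a unit. What the sketch does not model is
reclamation while readers are active; that is the part RCU grace
periods provide in the kernel.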

Signed-off-by: Bobby Eshleman
---
 drivers/vhost/vsock.c                   |  12 +-
 include/linux/virtio_vsock.h            |   3 +-
 include/net/af_vsock.h                  |  39 ++-
 net/vmw_vsock/af_vsock.c                | 451 +++++++++++++++++++++++++-------
 net/vmw_vsock/diag.c                    |  10 +-
 net/vmw_vsock/hyperv_transport.c        |  27 +-
 net/vmw_vsock/virtio_transport_common.c |  32 ++-
 net/vmw_vsock/vmci_transport.c          |  84 ++++--
 net/vmw_vsock/vsock_bpf.c               |  10 +-
 9 files changed, 518 insertions(+), 150 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 159c1a22c1a8..b027a780d333 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -297,13 +297,17 @@ static int
 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 {
 	struct vhost_vsock *vsock;
+	unsigned int cid;
 	int cnt = 0;
 	int ret = -ENODEV;
 
 	rcu_read_lock();
+	ret = vsock_remote_addr_cid(vsk, &cid);
+	if (ret < 0)
+		goto out;
 
 	/* Find the vhost_vsock according to guest context id */
-	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+	vsock = vhost_vsock_get(cid);
 	if (!vsock)
 		goto out;
 
@@ -706,6 +710,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock)
 static void vhost_vsock_reset_orphans(struct sock *sk)
 {
 	struct vsock_sock *vsk = vsock_sk(sk);
+	unsigned int cid;
+
+	if (vsock_remote_addr_cid(vsk, &cid) < 0)
+		return;
 
 	/* vmci_transport.c doesn't take sk_lock here either. At least we're
 	 * under vsock_table_lock so the sock cannot disappear while we're
@@ -713,7 +721,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
 	 */
 
 	/* If the peer is still valid, no need to reset connection */
-	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+	if (vhost_vsock_get(cid))
 		return;
 
 	/* If the close timeout is pending, let it expire. This avoids races
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 237ca87a2ecd..97656e83606f 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -231,7 +231,8 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
 				struct msghdr *msg,
 				size_t len);
 int
-virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+virtio_transport_dgram_enqueue(const struct vsock_transport *transport,
+			       struct vsock_sock *vsk,
 			       struct sockaddr_vm *remote_addr,
 			       struct msghdr *msg,
 			       size_t len);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index c115e655b4f5..84f2a9700ebd 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock;
 #define vsock_sk(__sk) ((struct vsock_sock *)__sk)
 #define sk_vsock(__vsk) (&(__vsk)->sk)
 
+struct vsock_remote_info {
+	struct sockaddr_vm addr;
+	struct rcu_head rcu;
+	const struct vsock_transport *transport;
+};
+
 struct vsock_sock {
 	/* sk must be the first member. */
 	struct sock sk;
-	const struct vsock_transport *transport;
 	struct sockaddr_vm local_addr;
-	struct sockaddr_vm remote_addr;
+	struct vsock_remote_info * __rcu remote_info;
 	/* Links for the global tables of bound and connected sockets. */
 	struct list_head bound_table;
 	struct list_head connected_table;
@@ -120,8 +125,8 @@ struct vsock_transport {
 
 	/* DGRAM. */
 	int (*dgram_bind)(struct vsock_sock *, struct sockaddr_vm *);
-	int (*dgram_enqueue)(struct vsock_sock *, struct sockaddr_vm *,
-			     struct msghdr *, size_t len);
+	int (*dgram_enqueue)(const struct vsock_transport *, struct vsock_sock *,
+			     struct sockaddr_vm *, struct msghdr *, size_t len);
 	bool (*dgram_allow)(u32 cid, u32 port);
 	int (*dgram_get_cid)(struct sk_buff *skb, unsigned int *cid);
 	int (*dgram_get_port)(struct sk_buff *skb, unsigned int *port);
@@ -196,6 +201,17 @@ void vsock_core_unregister(const struct vsock_transport *t);
 /* The transport may downcast this to access transport-specific functions */
 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk);
 
+static inline struct vsock_remote_info *
+vsock_core_get_remote_info(struct vsock_sock *vsk)
+{
+
+	/* vsk->remote_info may be accessed if the rcu read lock is held OR the
+	 * socket lock is held
+	 */
+	return rcu_dereference_check(vsk->remote_info,
+				     lockdep_sock_is_held(sk_vsock(vsk)));
+}
+
 /**** UTILS ****/
 
 /* vsock_table_lock must be held */
@@ -214,7 +230,7 @@ void vsock_release_pending(struct sock *pending);
 void vsock_add_pending(struct sock *listener, struct sock *pending);
 void vsock_remove_pending(struct sock *listener, struct sock *pending);
 void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
-void vsock_insert_connected(struct vsock_sock *vsk);
+int vsock_insert_connected(struct vsock_sock *vsk);
 void vsock_remove_bound(struct vsock_sock *vsk);
 void vsock_remove_connected(struct vsock_sock *vsk);
 struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
@@ -223,7 +239,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
 void vsock_remove_sock(struct vsock_sock *vsk);
 void vsock_for_each_connected_socket(struct vsock_transport *transport,
 				     void (*fn)(struct sock *sk));
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+			   struct sockaddr_vm *remote_addr);
 bool vsock_find_cid(unsigned int cid);
 struct sock *vsock_find_bound_dgram_socket(struct sockaddr_vm *addr);
 
@@ -253,4 +270,14 @@ static inline void __init vsock_bpf_build_proto(void)
 {}
 #endif
 
+/* RCU-protected remote addr helpers */
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid);
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port);
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port);
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest);
+bool vsock_remote_addr_bound(struct vsock_sock *vsk);
+bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other);
+int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port);
+
 #endif /* __AF_VSOCK_H__ */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index e8c70069d77d..0520228d2a68 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -114,6 +114,8 @@
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
 static void vsock_sk_destruct(struct sock *sk);
 static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
+static bool vsock_use_local_transport(unsigned int remote_cid);
+static bool sock_type_connectible(u16 type);
 
 /* Protocol family. */
 struct proto vsock_proto = {
@@ -145,6 +147,147 @@ static const struct vsock_transport *transport_local;
 static DEFINE_MUTEX(vsock_register_mutex);
 
 /**** UTILS ****/
+bool vsock_remote_addr_bound(struct vsock_sock *vsk)
+{
+	struct vsock_remote_info *remote_info;
+	bool ret;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return false;
+	}
+
+	ret = vsock_addr_bound(&remote_info->addr);
+	rcu_read_unlock();
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_bound);
+
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest)
+{
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+	memcpy(dest, &remote_info->addr, sizeof(*dest));
+	rcu_read_unlock();
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_copy);
+
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid)
+{
+	return vsock_remote_addr_cid_port(vsk, cid, NULL);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid);
+
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port)
+{
+	return vsock_remote_addr_cid_port(vsk, NULL, port);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_port);
+
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port)
+{
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+
+	if (cid)
+		*cid = remote_info->addr.svm_cid;
+	if (port)
+		*port = remote_info->addr.svm_port;
+
+	rcu_read_unlock();
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port);
+
+/* The socket lock must be held by the caller */
+int vsock_set_remote_info(struct vsock_sock *vsk,
+			  const struct vsock_transport *transport,
+			  struct sockaddr_vm *addr)
+{
+	struct vsock_remote_info *old, *new;
+
+	if (addr || transport) {
+		new = kmalloc(sizeof(*new), GFP_KERNEL);
+		if (!new)
+			return -ENOMEM;
+
+		if (addr)
+			memcpy(&new->addr, addr, sizeof(new->addr));
+
+		if (transport)
+			new->transport = transport;
+	} else {
+		new = NULL;
+	}
+
+	old = rcu_replace_pointer(vsk->remote_info, new, lockdep_sock_is_held(sk_vsock(vsk)));
+	kfree_rcu(old, rcu);
+
+	return 0;
+}
+
+static const struct vsock_transport *
+vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
+{
+	const struct vsock_transport *transport;
+
+	if (vsock_use_local_transport(cid))
+		transport = transport_local;
+	else if (cid <= VMADDR_CID_HOST || !transport_h2g ||
+		 (flags & VMADDR_FLAG_TO_HOST))
+		transport = transport_g2h;
+	else
+		transport = transport_h2g;
+
+	return transport;
+}
+
+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
+{
+	if (transport_dgram)
+		return transport_dgram;
+
+	return vsock_connectible_lookup_transport(cid, flags);
+}
+
+bool vsock_remote_addr_equals(struct vsock_sock *vsk,
+			      struct sockaddr_vm *other)
+{
+	struct vsock_remote_info *remote_info;
+	bool equals;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return false;
+	}
+
+	equals = vsock_addr_equals_addr(&remote_info->addr, other);
+	rcu_read_unlock();
+
+	return equals;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_equals);
 
 /* Each bound VSocket is stored in the bind hash table and each connected
  * VSocket is stored in the connected hash table.
@@ -284,10 +427,16 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
 
 	list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
 			    connected_table) {
-		if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+		struct vsock_remote_info *remote_info;
+
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (vsock_addr_equals_addr(src, &remote_info->addr) &&
 		    dst->svm_port == vsk->local_addr.svm_port) {
+			rcu_read_unlock();
 			return sk_vsock(vsk);
 		}
+		rcu_read_unlock();
 	}
 
 	return NULL;
@@ -300,17 +449,36 @@ static void vsock_insert_unbound(struct vsock_sock *vsk)
 	spin_unlock_bh(&vsock_table_lock);
 }
 
-void vsock_insert_connected(struct vsock_sock *vsk)
+int vsock_insert_connected(struct vsock_sock *vsk)
 {
-	struct list_head *list = vsock_connected_sockets(
-		&vsk->remote_addr, &vsk->local_addr);
+	struct list_head *list;
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+	list = vsock_connected_sockets(&remote_info->addr, &vsk->local_addr);
+	rcu_read_unlock();
 
 	spin_lock_bh(&vsock_table_lock);
 	__vsock_insert_connected(list, vsk);
 	spin_unlock_bh(&vsock_table_lock);
+
+	return 0;
 }
 EXPORT_SYMBOL_GPL(vsock_insert_connected);
 
+void vsock_remove_dgram_bound(struct vsock_sock *vsk)
+{
+	spin_lock_bh(&vsock_dgram_table_lock);
+	if (__vsock_in_bound_table(vsk))
+		__vsock_remove_bound(vsk);
+	spin_unlock_bh(&vsock_dgram_table_lock);
+}
+
 void vsock_remove_bound(struct vsock_sock *vsk)
 {
 	spin_lock_bh(&vsock_table_lock);
@@ -362,7 +530,10 @@ EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
 
 void vsock_remove_sock(struct vsock_sock *vsk)
 {
-	vsock_remove_bound(vsk);
+	if (sock_type_connectible(sk_vsock(vsk)->sk_type))
+		vsock_remove_bound(vsk);
+	else
+		vsock_remove_dgram_bound(vsk);
 	vsock_remove_connected(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_remove_sock);
@@ -378,7 +549,7 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
 		struct vsock_sock *vsk;
 		list_for_each_entry(vsk, &vsock_connected_table[i],
 				    connected_table) {
-			if (vsk->transport != transport)
+			if (vsock_core_get_transport(vsk) != transport)
 				continue;
 
 			fn(sk_vsock(vsk));
@@ -444,59 +615,39 @@ static bool vsock_use_local_transport(unsigned int remote_cid)
 
 static void vsock_deassign_transport(struct vsock_sock *vsk)
 {
-	if (!vsk->transport)
-		return;
-
-	vsk->transport->destruct(vsk);
-	module_put(vsk->transport->module);
-	vsk->transport = NULL;
-}
-
-static const struct vsock_transport *
-vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
-{
-	const struct vsock_transport *transport;
+	struct vsock_remote_info *remote_info;
 
-	if (vsock_use_local_transport(cid))
-		transport = transport_local;
-	else if (cid <= VMADDR_CID_HOST || !transport_h2g ||
-		 (flags & VMADDR_FLAG_TO_HOST))
-		transport = transport_g2h;
-	else
-		transport = transport_h2g;
-
-	return transport;
-}
-
-static const struct vsock_transport *
-vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
-{
-	if (transport_dgram)
-		return transport_dgram;
+	remote_info = rcu_replace_pointer(vsk->remote_info, NULL,
+					  lockdep_sock_is_held(sk_vsock(vsk)));
+	if (!remote_info)
+		return;
 
-	return vsock_connectible_lookup_transport(cid, flags);
+	remote_info->transport->destruct(vsk);
+	module_put(remote_info->transport->module);
+	kfree_rcu(remote_info, rcu);
 }
 
 /* Assign a transport to a socket and call the .init transport callback.
  *
- * Note: for connection oriented socket this must be called when vsk->remote_addr
- * is set (e.g. during the connect() or when a connection request on a listener
- * socket is received).
- * The vsk->remote_addr is used to decide which transport to use:
+ * The remote_addr is used to decide which transport to use:
  *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
  *    g2h is not loaded, will use local transport;
  *  - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
  *    includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
  *  - remote CID > VMADDR_CID_HOST will use host->guest transport;
  */
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+			   struct sockaddr_vm *remote_addr)
 {
 	const struct vsock_transport *new_transport;
+	struct vsock_remote_info *old_info;
 	struct sock *sk = sk_vsock(vsk);
-	unsigned int remote_cid = vsk->remote_addr.svm_cid;
+	unsigned int remote_cid;
 	__u8 remote_flags;
 	int ret;
 
+	remote_cid = remote_addr->svm_cid;
+
 	/* If the packet is coming with the source and destination CIDs higher
 	 * than VMADDR_CID_HOST, then a vsock channel where all the packets are
 	 * forwarded to the host should be established. Then the host will
@@ -506,10 +657,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 	 * the connect path the flag can be set by the user space application.
 	 */
 	if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
-	    vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
-		vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+	    remote_cid > VMADDR_CID_HOST)
+		remote_addr->svm_flags |= VMADDR_FLAG_TO_HOST;
 
-	remote_flags = vsk->remote_addr.svm_flags;
+	remote_flags = remote_addr->svm_flags;
 
 	switch (sk->sk_type) {
 	case SOCK_DGRAM:
@@ -525,8 +676,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		return -ESOCKTNOSUPPORT;
 	}
 
-	if (vsk->transport) {
-		if (vsk->transport == new_transport)
+	old_info = vsock_core_get_remote_info(vsk);
+	if (old_info && old_info->transport) {
+		if (old_info->transport == new_transport)
 			return 0;
 
 		/* transport->release() must be called with sock lock acquired.
@@ -535,7 +687,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		 * function is called on a new socket which is not assigned to
 		 * any transport.
 		 */
-		vsk->transport->release(vsk);
+		old_info->transport->release(vsk);
 		vsock_deassign_transport(vsk);
 	}
 
@@ -553,13 +705,18 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		}
 	}
 
-	ret = new_transport->init(vsk, psk);
+	ret = vsock_set_remote_info(vsk, new_transport, remote_addr);
 	if (ret) {
 		module_put(new_transport->module);
 		return ret;
 	}
 
-	vsk->transport = new_transport;
+	ret = new_transport->init(vsk, psk);
+	if (ret) {
+		vsock_set_remote_info(vsk, NULL, NULL);
+		module_put(new_transport->module);
+		return ret;
+	}
 
 	return 0;
 }
@@ -616,12 +773,14 @@ static bool vsock_is_pending(struct sock *sk)
 
 static int vsock_send_shutdown(struct sock *sk, int mode)
 {
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk = vsock_sk(sk);
 
-	if (!vsk->transport)
+	transport = vsock_core_get_transport(vsk);
+	if (!transport)
 		return -ENODEV;
 
-	return vsk->transport->shutdown(vsk, mode);
+	return transport->shutdown(vsk, mode);
 }
 
 static void vsock_pending_work(struct work_struct *work)
@@ -757,7 +916,10 @@ EXPORT_SYMBOL(vsock_bind_stream);
 static int vsock_bind_dgram(struct vsock_sock *vsk,
 			    struct sockaddr_vm *addr)
 {
-	if (!vsk->transport || !vsk->transport->dgram_bind) {
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+	if (!transport || !transport->dgram_bind) {
 		int retval;
 		spin_lock_bh(&vsock_dgram_table_lock);
 		retval = vsock_bind_common(vsk, addr, vsock_dgram_bind_table,
@@ -767,7 +929,7 @@ static int vsock_bind_dgram(struct vsock_sock *vsk,
 		return retval;
 	}
 
-	return vsk->transport->dgram_bind(vsk, addr);
+	return transport->dgram_bind(vsk, addr);
 }
 
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
@@ -816,6 +978,7 @@ static struct sock *__vsock_create(struct net *net,
 				   unsigned short type,
 				   int kern)
 {
+	struct vsock_remote_info *remote_info;
 	struct sock *sk;
 	struct vsock_sock *psk;
 	struct vsock_sock *vsk;
@@ -835,7 +998,14 @@ static struct sock *__vsock_create(struct net *net,
 
 	vsk = vsock_sk(sk);
 	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+
+	remote_info = kmalloc(sizeof(*remote_info), GFP_KERNEL);
+	if (!remote_info) {
+		sk_free(sk);
+		return NULL;
+	}
+	vsock_addr_init(&remote_info->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+	rcu_assign_pointer(vsk->remote_info, remote_info);
 
 	sk->sk_destruct = vsock_sk_destruct;
 	sk->sk_backlog_rcv = vsock_queue_rcv_skb;
@@ -882,6 +1052,7 @@ static bool sock_type_connectible(u16 type)
 static void __vsock_release(struct sock *sk, int level)
 {
 	if (sk) {
+		const struct vsock_transport *transport;
 		struct sock *pending;
 		struct vsock_sock *vsk;
 
@@ -895,8 +1066,9 @@ static void __vsock_release(struct sock *sk, int level)
 		 */
 		lock_sock_nested(sk, level);
 
-		if (vsk->transport)
-			vsk->transport->release(vsk);
+		transport = vsock_core_get_transport(vsk);
+		if (transport)
+			transport->release(vsk);
 		else if (sock_type_connectible(sk->sk_type))
 			vsock_remove_sock(vsk);
 
@@ -926,8 +1098,6 @@ static void vsock_sk_destruct(struct sock *sk)
 	 * possibly register the address family with the kernel.
 	 */
 	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-
 	put_cred(vsk->owner);
 }
 
@@ -951,16 +1121,22 @@ EXPORT_SYMBOL_GPL(vsock_create_connected);
 
 s64 vsock_stream_has_data(struct vsock_sock *vsk)
 {
-	return vsk->transport->stream_has_data(vsk);
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+
+	return transport->stream_has_data(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
 
 s64 vsock_connectible_has_data(struct vsock_sock *vsk)
 {
+	const struct vsock_transport *transport;
 	struct sock *sk = sk_vsock(vsk);
 
+	transport = vsock_core_get_transport(vsk);
 	if (sk->sk_type == SOCK_SEQPACKET)
-		return vsk->transport->seqpacket_has_data(vsk);
+		return transport->seqpacket_has_data(vsk);
 	else
 		return vsock_stream_has_data(vsk);
 }
@@ -968,7 +1144,10 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data);
 
 s64 vsock_stream_has_space(struct vsock_sock *vsk)
 {
-	return vsk->transport->stream_has_space(vsk);
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+	return transport->stream_has_space(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_space);
 
@@ -1017,6 +1196,7 @@ static int vsock_getname(struct socket *sock,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	struct sockaddr_vm *vm_addr;
+	struct vsock_remote_info *rcu_ptr;
 
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
@@ -1025,11 +1205,17 @@ static int vsock_getname(struct socket *sock,
 	lock_sock(sk);
 
 	if (peer) {
+		rcu_read_lock();
 		if (sock->state != SS_CONNECTED) {
 			err = -ENOTCONN;
 			goto out;
 		}
-		vm_addr = &vsk->remote_addr;
+		rcu_ptr = vsock_core_get_remote_info(vsk);
+		if (!rcu_ptr) {
+			err = -EINVAL;
+			goto out;
+		}
+		vm_addr = &rcu_ptr->addr;
 	} else {
 		vm_addr = &vsk->local_addr;
 	}
@@ -1049,6 +1235,8 @@ static int vsock_getname(struct socket *sock,
 	err = sizeof(*vm_addr);
 
 out:
+	if (peer)
+		rcu_read_unlock();
 	release_sock(sk);
 	return err;
 }
@@ -1153,7 +1341,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 
 		lock_sock(sk);
 
-		transport = vsk->transport;
+		transport = vsock_core_get_transport(vsk);
 
 		/* Listening sockets that have connections in their accept
 		 * queue can be read.
@@ -1224,9 +1412,11 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 
 static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
 {
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk = vsock_sk(sk);
 
-	return vsk->transport->read_skb(vsk, read_actor);
+	transport = vsock_core_get_transport(vsk);
+	return transport->read_skb(vsk, read_actor);
 }
 
 static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
@@ -1235,7 +1425,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 	int err;
 	struct sock *sk;
 	struct vsock_sock *vsk;
-	struct sockaddr_vm *remote_addr;
+	struct sockaddr_vm stack_addr, *remote_addr;
 	const struct vsock_transport *transport;
 
 	if (msg->msg_flags & MSG_OOB)
@@ -1246,7 +1436,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
 
-	lock_sock(sk);
+	/* If auto-binding is required, acquire the slock to avoid potential
+	 * race conditions. Otherwise, do not acquire the lock.
+	 *
+	 * We know that the first check of local_addr is racy (indicated by
+	 * data_race()). By acquiring the lock and then subsequently checking
+	 * again if local_addr is bound (inside vsock_auto_bind()), we can
+	 * ensure there are no real data races.
+	 *
+	 * This technique is borrowed by inet_send_prepare().
+	 */
+	if (data_race(!vsock_addr_bound(&vsk->local_addr))) {
+		lock_sock(sk);
+		err = vsock_auto_bind(vsk);
+		release_sock(sk);
+		if (err)
+			return err;
+	}
 
 	/* If the provided message contains an address, use that. Otherwise
 	 * fall back on the socket's remote handle (if it has been connected).
@@ -1256,6 +1462,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			    &remote_addr) == 0) {
 		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
 							 remote_addr->svm_flags);
+
 		if (!transport) {
 			err = -EINVAL;
 			goto out;
@@ -1286,18 +1493,39 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			goto out;
 		}
 
-		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+		err = transport->dgram_enqueue(transport, vsk, remote_addr, msg, len);
 		module_put(transport->module);
 	} else if (sock->state == SS_CONNECTED) {
-		remote_addr = &vsk->remote_addr;
-		transport = vsk->transport;
+		struct vsock_remote_info *remote_info;
+		const struct vsock_transport *transport;
 
-		err = vsock_auto_bind(vsk);
-		if (err)
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (!remote_info) {
+			err = -EINVAL;
+			rcu_read_unlock();
 			goto out;
+		}
 
-		if (remote_addr->svm_cid == VMADDR_CID_ANY)
+		transport = remote_info->transport;
+		memcpy(&stack_addr, &remote_info->addr, sizeof(stack_addr));
+		rcu_read_unlock();
+
+		remote_addr = &stack_addr;
+
+		if (remote_addr->svm_cid == VMADDR_CID_ANY) {
 			remote_addr->svm_cid = transport->get_local_cid();
+			lock_sock(sk_vsock(vsk));
+			/* Even though the CID has changed, We do not have to
+			 * look up the transport again because the local CID
+			 * will never resolve to a different transport.
+			 */
+			err = vsock_set_remote_info(vsk, transport, remote_addr);
+			release_sock(sk_vsock(vsk));
+
+			if (err)
+				goto out;
+		}
 
 		/* XXX Should connect() or this function ensure remote_addr is
 		 * bound?
@@ -1313,14 +1541,13 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			goto out;
 		}
 
-		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+		err = transport->dgram_enqueue(transport, vsk, &stack_addr, msg, len);
 	} else {
 		err = -EINVAL;
 		goto out;
 	}
 
 out:
-	release_sock(sk);
 	return err;
 }
 
@@ -1331,18 +1558,22 @@ static int vsock_dgram_connect(struct socket *sock,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	struct sockaddr_vm *remote_addr;
+	const struct vsock_transport *transport;
 
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
 
 	err = vsock_addr_cast(addr, addr_len, &remote_addr);
 	if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
+		struct sockaddr_vm addr_any;
+
 		lock_sock(sk);
-		vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
-				VMADDR_PORT_ANY);
+		vsock_addr_init(&addr_any, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+		err = vsock_set_remote_info(vsk, vsock_core_get_transport(vsk),
+					    &addr_any);
 		sock->state = SS_UNCONNECTED;
 		release_sock(sk);
-		return 0;
+		return err;
 	} else if (err != 0)
 		return -EINVAL;
 
@@ -1352,14 +1583,13 @@ static int vsock_dgram_connect(struct socket *sock,
 	if (err)
 		goto out;
 
-	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
-
-	err = vsock_assign_transport(vsk, NULL);
+	err = vsock_assign_transport(vsk, NULL, remote_addr);
 	if (err)
 		goto out;
 
-	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
-					 remote_addr->svm_port)) {
+	transport = vsock_core_get_transport(vsk);
+	if (!transport->dgram_allow(remote_addr->svm_cid,
+				    remote_addr->svm_port)) {
 		err = -EINVAL;
 		goto out;
 	}
@@ -1406,7 +1636,9 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
 	if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
 		return -EOPNOTSUPP;
 
-	transport = vsk->transport;
+	rcu_read_lock();
+	transport = vsock_core_get_transport(vsk);
+	rcu_read_unlock();
 
 	/* Retrieve the head sk_buff from the socket's receive queue. */
 	err = 0;
@@ -1474,7 +1706,7 @@ static const struct proto_ops vsock_dgram_ops = {
 
 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
 {
-	const struct vsock_transport *transport = vsk->transport;
+	const struct vsock_transport *transport = vsock_core_get_transport(vsk);
 
 	if (!transport || !transport->cancel_pkt)
 		return -EOPNOTSUPP;
@@ -1511,6 +1743,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	const struct vsock_transport *transport;
+	struct vsock_remote_info *remote_info;
 	struct sockaddr_vm *remote_addr;
 	long timeout;
 	DEFINE_WAIT(wait);
@@ -1548,14 +1781,20 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
 		}
 
 		/* Set the remote address that we are connecting to. */
-		memcpy(&vsk->remote_addr, remote_addr,
-		       sizeof(vsk->remote_addr));
-
-		err = vsock_assign_transport(vsk, NULL);
+		err = vsock_assign_transport(vsk, NULL, remote_addr);
 		if (err)
 			goto out;
 
-		transport = vsk->transport;
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (!remote_info) {
+			err = -EINVAL;
+			rcu_read_unlock();
+			goto out;
+		}
+
+		transport = remote_info->transport;
+		rcu_read_unlock();
 
 		/* The hypervisor and well-known contexts do not have socket
 		 * endpoints.
@@ -1819,7 +2058,7 @@ static int vsock_connectible_setsockopt(struct socket *sock,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	switch (optname) {
 	case SO_VM_SOCKETS_BUFFER_SIZE:
@@ -1957,7 +2196,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	/* Callers should not provide a destination with connection oriented
 	 * sockets.
@@ -1980,7 +2219,7 @@ static int vsock_connectible_sendmsg(struct socket *s= ock, struct msghdr *msg, goto out; } =20 - if (!vsock_addr_bound(&vsk->remote_addr)) { + if (!vsock_remote_addr_bound(vsk)) { err =3D -EDESTADDRREQ; goto out; } @@ -2101,7 +2340,7 @@ static int vsock_connectible_wait_data(struct sock *s= k, =20 vsk =3D vsock_sk(sk); err =3D 0; - transport =3D vsk->transport; + transport =3D vsock_core_get_transport(vsk); =20 while (1) { prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE); @@ -2169,7 +2408,7 @@ static int __vsock_stream_recvmsg(struct sock *sk, st= ruct msghdr *msg, DEFINE_WAIT(wait); =20 vsk =3D vsock_sk(sk); - transport =3D vsk->transport; + transport =3D vsock_core_get_transport(vsk); =20 /* We must not copy less than target bytes into the user's buffer * before returning successfully, so we wait for the consume queue to @@ -2245,7 +2484,7 @@ static int __vsock_seqpacket_recvmsg(struct sock *sk,= struct msghdr *msg, DEFINE_WAIT(wait); =20 vsk =3D vsock_sk(sk); - transport =3D vsk->transport; + transport =3D vsock_core_get_transport(vsk); =20 timeout =3D sock_rcvtimeo(sk, flags & MSG_DONTWAIT); =20 @@ -2302,7 +2541,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct= msghdr *msg, size_t len, =20 lock_sock(sk); =20 - transport =3D vsk->transport; + transport =3D vsock_core_get_transport(vsk); =20 if (!transport || sk->sk_state !=3D TCP_ESTABLISHED) { /* Recvmsg is supposed to return 0 if a peer performs an @@ -2369,7 +2608,7 @@ static int vsock_set_rcvlowat(struct sock *sk, int va= l) if (val > vsk->buffer_size) return -EINVAL; =20 - transport =3D vsk->transport; + transport =3D vsock_core_get_transport(vsk); =20 if (transport && transport->set_rcvlowat) return transport->set_rcvlowat(vsk, val); @@ -2459,7 +2698,10 @@ static int vsock_create(struct net *net, struct sock= et *sock, vsk =3D vsock_sk(sk); =20 if (sock->type =3D=3D SOCK_DGRAM) { - ret =3D vsock_assign_transport(vsk, NULL); + struct sockaddr_vm remote_addr; + + 
vsock_addr_init(&remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + ret =3D vsock_assign_transport(vsk, NULL, &remote_addr); if (ret < 0) { sock_put(sk); return ret; @@ -2581,7 +2823,18 @@ static void __exit vsock_exit(void) =20 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *= vsk) { - return vsk->transport; + const struct vsock_transport *transport; + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info =3D vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return NULL; + } + transport =3D remote_info->transport; + rcu_read_unlock(); + return transport; } EXPORT_SYMBOL_GPL(vsock_core_get_transport); =20 diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c index a2823b1c5e28..f843bae86b32 100644 --- a/net/vmw_vsock/diag.c +++ b/net/vmw_vsock/diag.c @@ -15,8 +15,14 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff = *skb, u32 portid, u32 seq, u32 flags) { struct vsock_sock *vsk =3D vsock_sk(sk); + struct sockaddr_vm remote_addr; struct vsock_diag_msg *rep; struct nlmsghdr *nlh; + int err; + + err =3D vsock_remote_addr_copy(vsk, &remote_addr); + if (err < 0) + return err; =20 nlh =3D nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), flags); @@ -36,8 +42,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *= skb, rep->vdiag_shutdown =3D sk->sk_shutdown; rep->vdiag_src_cid =3D vsk->local_addr.svm_cid; rep->vdiag_src_port =3D vsk->local_addr.svm_port; - rep->vdiag_dst_cid =3D vsk->remote_addr.svm_cid; - rep->vdiag_dst_port =3D vsk->remote_addr.svm_port; + rep->vdiag_dst_cid =3D remote_addr.svm_cid; + rep->vdiag_dst_port =3D remote_addr.svm_port; rep->vdiag_ino =3D sock_i_ino(sk); =20 sock_diag_save_cookie(sk, rep->vdiag_cookie); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transp= ort.c index c00bc5da769a..84e8c64b3365 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -323,6 +323,8 @@ static void 
hvs_open_connection(struct vmbus_channel *c= han) goto out; =20 if (conn_from_host) { + struct sockaddr_vm remote_addr; + if (sk->sk_ack_backlog >=3D sk->sk_max_ack_backlog) goto out; =20 @@ -336,10 +338,9 @@ static void hvs_open_connection(struct vmbus_channel *= chan) hvs_addr_init(&vnew->local_addr, if_type); =20 /* Remote peer is always the host */ - vsock_addr_init(&vnew->remote_addr, - VMADDR_CID_HOST, VMADDR_PORT_ANY); - vnew->remote_addr.svm_port =3D get_port_by_srv_id(if_instance); - ret =3D vsock_assign_transport(vnew, vsock_sk(sk)); + vsock_addr_init(&remote_addr, VMADDR_CID_HOST, get_port_by_srv_id(if_ins= tance)); + + ret =3D vsock_assign_transport(vnew, vsock_sk(sk), &remote_addr); /* Transport assigned (looking at remote_addr) must be the * same where we received the request. */ @@ -459,13 +460,18 @@ static int hvs_connect(struct vsock_sock *vsk) { union hvs_service_id vm, host; struct hvsock *h =3D vsk->trans; + int err; =20 vm.srv_id =3D srv_id_template; vm.svm_port =3D vsk->local_addr.svm_port; h->vm_srv_id =3D vm.srv_id; =20 host.srv_id =3D srv_id_template; - host.svm_port =3D vsk->remote_addr.svm_port; + + err =3D vsock_remote_addr_port(vsk, &host.svm_port); + if (err < 0) + return err; + h->host_srv_id =3D host.srv_id; =20 return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id); @@ -566,7 +572,8 @@ static int hvs_dgram_get_length(struct sk_buff *skb, si= ze_t *len) return -EOPNOTSUPP; } =20 -static int hvs_dgram_enqueue(struct vsock_sock *vsk, +static int hvs_dgram_enqueue(const struct vsock_transport *transport, + struct vsock_sock *vsk, struct sockaddr_vm *remote, struct msghdr *msg, size_t dgram_len) { @@ -866,7 +873,13 @@ static struct vsock_transport hvs_transport =3D { =20 static bool hvs_check_transport(struct vsock_sock *vsk) { - return vsk->transport =3D=3D &hvs_transport; + bool ret; + + rcu_read_lock(); + ret =3D vsock_core_get_transport(vsk) =3D=3D &hvs_transport; + rcu_read_unlock(); + + return ret; } =20 static int 
hvs_probe(struct hv_device *hdev, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio= _transport_common.c index ab4af21c4f3f..09d35c488902 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -258,8 +258,9 @@ static int virtio_transport_send_pkt_info(struct vsock_= sock *vsk, src_cid =3D t_ops->transport.get_local_cid(); src_port =3D vsk->local_addr.svm_port; if (!info->remote_cid) { - dst_cid =3D vsk->remote_addr.svm_cid; - dst_port =3D vsk->remote_addr.svm_port; + ret =3D vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port); + if (ret < 0) + return ret; } else { dst_cid =3D info->remote_cid; dst_port =3D info->remote_port; @@ -877,12 +878,14 @@ int virtio_transport_shutdown(struct vsock_sock *vsk,= int mode) EXPORT_SYMBOL_GPL(virtio_transport_shutdown); =20 int -virtio_transport_dgram_enqueue(struct vsock_sock *vsk, +virtio_transport_dgram_enqueue(const struct vsock_transport *transport, + struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, size_t dgram_len) { - const struct virtio_transport *t_ops; + const struct virtio_transport *t_ops =3D + (const struct virtio_transport *)transport; struct virtio_vsock_pkt_info info =3D { .op =3D VIRTIO_VSOCK_OP_RW, .msg =3D msg, @@ -896,7 +899,6 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk, if (dgram_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) return -EMSGSIZE; =20 - t_ops =3D virtio_transport_get_ops(vsk); src_cid =3D t_ops->transport.get_local_cid(); src_port =3D vsk->local_addr.svm_port; =20 @@ -1120,7 +1122,9 @@ virtio_transport_recv_connecting(struct sock *sk, case VIRTIO_VSOCK_OP_RESPONSE: sk->sk_state =3D TCP_ESTABLISHED; sk->sk_socket->state =3D SS_CONNECTED; - vsock_insert_connected(vsk); + err =3D vsock_insert_connected(vsk); + if (err) + goto destroy; sk->sk_state_change(sk); break; case VIRTIO_VSOCK_OP_INVALID: @@ -1326,6 +1330,7 @@ virtio_transport_recv_listen(struct sock *sk, struct = sk_buff *skb, struct 
virtio_vsock_hdr *hdr =3D virtio_vsock_hdr(skb); struct vsock_sock *vsk =3D vsock_sk(sk); struct vsock_sock *vchild; + struct sockaddr_vm child_remote; struct sock *child; int ret; =20 @@ -1354,14 +1359,13 @@ virtio_transport_recv_listen(struct sock *sk, struc= t sk_buff *skb, vchild =3D vsock_sk(child); vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid), le32_to_cpu(hdr->dst_port)); - vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid), + vsock_addr_init(&child_remote, le64_to_cpu(hdr->src_cid), le32_to_cpu(hdr->src_port)); - - ret =3D vsock_assign_transport(vchild, vsk); + ret =3D vsock_assign_transport(vchild, vsk, &child_remote); /* Transport assigned (looking at remote_addr) must be the same * where we received the request. */ - if (ret || vchild->transport !=3D &t->transport) { + if (ret || vsock_core_get_transport(vchild) !=3D &t->transport) { release_sock(child); virtio_transport_reset_no_sock(t, skb); sock_put(child); @@ -1371,7 +1375,13 @@ virtio_transport_recv_listen(struct sock *sk, struct= sk_buff *skb, if (virtio_transport_space_update(child, skb)) child->sk_write_space(child); =20 - vsock_insert_connected(vchild); + ret =3D vsock_insert_connected(vchild); + if (ret) { + release_sock(child); + virtio_transport_reset_no_sock(t, skb); + sock_put(child); + return ret; + } vsock_enqueue_accept(sk, child); virtio_transport_send_response(vchild, skb); =20 diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index b6a51afb74b8..b9ba6209e8fc 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -283,18 +283,25 @@ vmci_transport_send_control_pkt(struct sock *sk, u16 proto, struct vmci_handle handle) { + struct sockaddr_vm addr_stack; + struct sockaddr_vm *remote_addr =3D &addr_stack; struct vsock_sock *vsk; + int err; =20 vsk =3D vsock_sk(sk); =20 if (!vsock_addr_bound(&vsk->local_addr)) return -EINVAL; =20 - if (!vsock_addr_bound(&vsk->remote_addr)) + if 
(!vsock_remote_addr_bound(vsk)) return -EINVAL; =20 + err =3D vsock_remote_addr_copy(vsk, remote_addr); + if (err < 0) + return err; + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, - &vsk->remote_addr, + remote_addr, type, size, mode, wait, proto, handle); } @@ -317,6 +324,7 @@ static int vmci_transport_send_reset(struct sock *sk, struct sockaddr_vm *dst_ptr; struct sockaddr_vm dst; struct vsock_sock *vsk; + int err; =20 if (pkt->type =3D=3D VMCI_TRANSPORT_PACKET_TYPE_RST) return 0; @@ -326,13 +334,16 @@ static int vmci_transport_send_reset(struct sock *sk, if (!vsock_addr_bound(&vsk->local_addr)) return -EINVAL; =20 - if (vsock_addr_bound(&vsk->remote_addr)) { - dst_ptr =3D &vsk->remote_addr; + if (vsock_remote_addr_bound(vsk)) { + err =3D vsock_remote_addr_copy(vsk, &dst); + if (err < 0) + return err; } else { vsock_addr_init(&dst, pkt->dg.src.context, pkt->src_port); - dst_ptr =3D &dst; } + dst_ptr =3D &dst; + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr, VMCI_TRANSPORT_PACKET_TYPE_RST, 0, 0, NULL, VSOCK_PROTO_INVALID, @@ -490,7 +501,7 @@ static struct sock *vmci_transport_get_pending( =20 list_for_each_entry(vpending, &vlistener->pending_links, pending_links) { - if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && + if (vsock_remote_addr_equals(vpending, &src) && pkt->dst_port =3D=3D vpending->local_addr.svm_port) { pending =3D sk_vsock(vpending); sock_hold(pending); @@ -940,6 +951,7 @@ static void vmci_transport_recv_pkt_work(struct work_st= ruct *work) static int vmci_transport_recv_listen(struct sock *sk, struct vmci_transport_packet *pkt) { + struct sockaddr_vm remote_addr; struct sock *pending; struct vsock_sock *vpending; int err; @@ -1015,10 +1027,10 @@ static int vmci_transport_recv_listen(struct sock *= sk, =20 vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context, pkt->dst_port); - vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, - pkt->src_port); =20 - err =3D 
vsock_assign_transport(vpending, vsock_sk(sk)); + vsock_addr_init(&remote_addr, pkt->dg.src.context, pkt->src_port); + + err =3D vsock_assign_transport(vpending, vsock_sk(sk), &remote_addr); /* Transport assigned (looking at remote_addr) must be the same * where we received the request. */ @@ -1133,6 +1145,7 @@ vmci_transport_recv_connecting_server(struct sock *li= stener, { struct vsock_sock *vpending; struct vmci_handle handle; + unsigned int vpending_remote_cid; struct vmci_qp *qpair; bool is_local; u32 flags; @@ -1189,8 +1202,13 @@ vmci_transport_recv_connecting_server(struct sock *l= istener, /* vpending->local_addr always has a context id so we do not need to * worry about VMADDR_CID_ANY in this case. */ - is_local =3D - vpending->remote_addr.svm_cid =3D=3D vpending->local_addr.svm_cid; + err =3D vsock_remote_addr_cid(vpending, &vpending_remote_cid); + if (err < 0) { + skerr =3D EPROTO; + goto destroy; + } + + is_local =3D vpending_remote_cid =3D=3D vpending->local_addr.svm_cid; flags =3D VMCI_QPFLAG_ATTACH_ONLY; flags |=3D is_local ? 
VMCI_QPFLAG_LOCAL : 0; =20 @@ -1203,7 +1221,7 @@ vmci_transport_recv_connecting_server(struct sock *li= stener, flags, vmci_transport_is_trusted( vpending, - vpending->remote_addr.svm_cid)); + vpending_remote_cid)); if (err < 0) { vmci_transport_send_reset(pending, pkt); skerr =3D -err; @@ -1277,6 +1295,8 @@ static int vmci_transport_recv_connecting_client(struct sock *sk, struct vmci_transport_packet *pkt) { + struct vsock_remote_info *remote_info; + struct sockaddr_vm *remote_addr; struct vsock_sock *vsk; int err; int skerr; @@ -1306,9 +1326,20 @@ vmci_transport_recv_connecting_client(struct sock *s= k, break; case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + rcu_read_lock(); + remote_info =3D vsock_core_get_remote_info(vsk); + if (!remote_info) { + skerr =3D EPROTO; + err =3D -EINVAL; + rcu_read_unlock(); + goto destroy; + } + + remote_addr =3D &remote_info->addr; + if (pkt->u.size =3D=3D 0 - || pkt->dg.src.context !=3D vsk->remote_addr.svm_cid - || pkt->src_port !=3D vsk->remote_addr.svm_port + || pkt->dg.src.context !=3D remote_addr->svm_cid + || pkt->src_port !=3D remote_addr->svm_port || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle) || vmci_trans(vsk)->qpair || vmci_trans(vsk)->produce_size !=3D 0 @@ -1316,9 +1347,10 @@ vmci_transport_recv_connecting_client(struct sock *s= k, || vmci_trans(vsk)->detach_sub_id !=3D VMCI_INVALID_ID) { skerr =3D EPROTO; err =3D -EINVAL; - + rcu_read_unlock(); goto destroy; } + rcu_read_unlock(); =20 err =3D vmci_transport_recv_connecting_client_negotiate(sk, pkt); if (err) { @@ -1379,6 +1411,7 @@ static int vmci_transport_recv_connecting_client_nego= tiate( int err; struct vsock_sock *vsk; struct vmci_handle handle; + unsigned int remote_cid; struct vmci_qp *qpair; u32 detach_sub_id; bool is_local; @@ -1449,19 +1482,23 @@ static int vmci_transport_recv_connecting_client_ne= gotiate( =20 /* Make VMCI select the handle for us. 
*/ handle =3D VMCI_INVALID_HANDLE; - is_local =3D vsk->remote_addr.svm_cid =3D=3D vsk->local_addr.svm_cid; + + err =3D vsock_remote_addr_cid(vsk, &remote_cid); + if (err < 0) + goto destroy; + + is_local =3D remote_cid =3D=3D vsk->local_addr.svm_cid; flags =3D is_local ? VMCI_QPFLAG_LOCAL : 0; =20 err =3D vmci_transport_queue_pair_alloc(&qpair, &handle, pkt->u.size, pkt->u.size, - vsk->remote_addr.svm_cid, + remote_cid, flags, vmci_transport_is_trusted( vsk, - vsk-> - remote_addr.svm_cid)); + remote_cid)); if (err < 0) goto destroy; =20 @@ -1692,6 +1729,7 @@ static int vmci_transport_dgram_bind(struct vsock_soc= k *vsk, } =20 static int vmci_transport_dgram_enqueue( + const struct vsock_transport *transport, struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, @@ -2052,7 +2090,13 @@ static struct vsock_transport vmci_transport =3D { =20 static bool vmci_check_transport(struct vsock_sock *vsk) { - return vsk->transport =3D=3D &vmci_transport; + bool retval; + + rcu_read_lock(); + retval =3D vsock_core_get_transport(vsk) =3D=3D &vmci_transport; + rcu_read_unlock(); + + return retval; } =20 static void vmci_vsock_transport_cb(bool is_host) diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c index a3c97546ab84..4d811c9cdf6e 100644 --- a/net/vmw_vsock/vsock_bpf.c +++ b/net/vmw_vsock/vsock_bpf.c @@ -148,6 +148,7 @@ static void vsock_bpf_check_needs_rebuild(struct proto = *ops) =20 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool r= estore) { + const struct vsock_transport *transport; struct vsock_sock *vsk; =20 if (restore) { @@ -157,10 +158,15 @@ int vsock_bpf_update_proto(struct sock *sk, struct sk= _psock *psock, bool restore } =20 vsk =3D vsock_sk(sk); - if (!vsk->transport) + + rcu_read_lock(); + transport =3D vsock_core_get_transport(vsk); + rcu_read_unlock(); + + if (!transport) return -ENODEV; =20 - if (!vsk->transport->read_skb) + if (!transport->read_skb) return -EOPNOTSUPP; =20 
vsock_bpf_check_needs_rebuild(psock->sk_proto); --=20 2.30.2