From: Geliang Tang <tanggeliang@kylinos.cn>
To add MPTCP support to "NVMe over TCP", the target side needs to pass
IPPROTO_MPTCP to sock_create() instead of IPPROTO_TCP to create an MPTCP
socket. Additionally, the setsockopt operations for this socket need to
be switched to a set of MPTCP-specific functions.
This patch defines the nvmet_tcp_proto structure, which contains the
protocol of the socket and a set of function pointers for these socket
operations. A "proto" field is also added to struct nvmet_tcp_port.
A TCP-specific instance of struct nvmet_tcp_proto is defined. In
nvmet_tcp_add_port(), port->proto is set to nvmet_tcp_proto when the
port's trtype is NVMF_TRTYPE_TCP; any other transport type is rejected
with -EINVAL. All locations that previously called the TCP setsockopt
helpers directly are updated to call the corresponding function pointers
in the nvmet_tcp_proto structure.
The nvmet_fabrics_ops to use is now selected in nvmet_tcp_done_recv_pdu()
through the port's nvmet_tcp_proto "ops" pointer, i.e. based on the
protocol type of the port.
RCU protection is added when accessing queue->port in the I/O path
(nvmet_tcp_alloc_cmd, nvmet_tcp_done_recv_pdu, nvmet_tcp_set_queue_sock)
to prevent use-after-free when a port is removed while asynchronous
operations (e.g., TLS handshake) are pending. The port structure is
released using kfree_rcu() in nvmet_tcp_remove_port(), and queue->port is
assigned using rcu_assign_pointer() in nvmet_tcp_alloc_queue().
Cc: Hannes Reinecke <hare@suse.de>
Co-developed-by: zhenwei pi <zhenwei.pi@linux.dev>
Signed-off-by: zhenwei pi <zhenwei.pi@linux.dev>
Co-developed-by: Hui Zhu <zhuhui@kylinos.cn>
Signed-off-by: Hui Zhu <zhuhui@kylinos.cn>
Co-developed-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Gang Yan <yangang@kylinos.cn>
Signed-off-by: Geliang Tang <tanggeliang@kylinos.cn>
---
drivers/nvme/target/tcp.c | 66 ++++++++++++++++++++++++++++++++-------
1 file changed, 55 insertions(+), 11 deletions(-)
diff --git a/drivers/nvme/target/tcp.c b/drivers/nvme/target/tcp.c
index acc71a26733f..d8d3d97de8ed 100644
--- a/drivers/nvme/target/tcp.c
+++ b/drivers/nvme/target/tcp.c
@@ -18,6 +18,7 @@
#include <net/handshake.h>
#include <linux/inet.h>
#include <linux/llist.h>
+#include <linux/rcupdate.h>
#include <trace/events/sock.h>
#include "nvmet.h"
@@ -198,12 +199,24 @@ struct nvmet_tcp_queue {
void (*write_space)(struct sock *);
};
+struct nvmet_tcp_proto {
+ int protocol;
+ void (*set_reuseaddr)(struct sock *sk);
+ void (*set_nodelay)(struct sock *sk);
+ void (*set_priority)(struct sock *sk, u32 priority);
+ void (*no_linger)(struct sock *sk);
+ void (*set_tos)(struct sock *sk, int val);
+ const struct nvmet_fabrics_ops *ops;
+};
+
struct nvmet_tcp_port {
+ struct rcu_head rcu;
struct socket *sock;
struct work_struct accept_work;
struct nvmet_port *nport;
struct sockaddr_storage addr;
void (*data_ready)(struct sock *);
+ const struct nvmet_tcp_proto *proto;
};
static DEFINE_IDA(nvmet_tcp_queue_ida);
@@ -1027,6 +1040,7 @@ static int nvmet_tcp_done_recv_pdu(struct nvmet_tcp_queue *queue)
{
struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr;
struct nvme_command *nvme_cmd = &queue->pdu.cmd.cmd;
+ const struct nvmet_fabrics_ops *ops;
struct nvmet_req *req;
int ret;
@@ -1067,7 +1081,10 @@ static int nvmet_tcp_done_recv_pdu(struct nvmet_tcp_queue *queue)
req = &queue->cmd->req;
memcpy(req->cmd, nvme_cmd, sizeof(*nvme_cmd));
- if (unlikely(!nvmet_req_init(req, &queue->nvme_sq, &nvmet_tcp_ops))) {
+ rcu_read_lock();
+ ops = rcu_dereference(queue->port)->proto->ops;
+ rcu_read_unlock();
+ if (unlikely(!nvmet_req_init(req, &queue->nvme_sq, ops))) {
pr_err("failed cmd %p id %d opcode %d, data_len: %d, status: %04x\n",
req->cmd, req->cmd->common.command_id,
req->cmd->common.opcode,
@@ -1686,6 +1703,7 @@ static int nvmet_tcp_set_queue_sock(struct nvmet_tcp_queue *queue)
{
struct socket *sock = queue->sock;
struct inet_sock *inet = inet_sk(sock->sk);
+ const struct nvmet_tcp_proto *proto;
int ret;
ret = kernel_getsockname(sock,
@@ -1698,19 +1716,23 @@ static int nvmet_tcp_set_queue_sock(struct nvmet_tcp_queue *queue)
if (ret < 0)
return ret;
+ rcu_read_lock();
+ proto = rcu_dereference(queue->port)->proto;
+ rcu_read_unlock();
+
/*
* Cleanup whatever is sitting in the TCP transmit queue on socket
* close. This is done to prevent stale data from being sent should
* the network connection be restored before TCP times out.
*/
- sock_no_linger(sock->sk);
+ proto->no_linger(sock->sk);
if (so_priority > 0)
- sock_set_priority(sock->sk, so_priority);
+ proto->set_priority(sock->sk, so_priority);
/* Set socket type of service */
if (inet->rcv_tos > 0)
- ip_sock_set_tos(sock->sk, inet->rcv_tos);
+ proto->set_tos(sock->sk, inet->rcv_tos);
ret = 0;
write_lock_bh(&sock->sk->sk_callback_lock);
@@ -2030,6 +2052,16 @@ static void nvmet_tcp_listen_data_ready(struct sock *sk)
read_unlock_bh(&sk->sk_callback_lock);
}
+static const struct nvmet_tcp_proto nvmet_tcp_proto = {
+ .protocol = IPPROTO_TCP,
+ .set_reuseaddr = sock_set_reuseaddr,
+ .set_nodelay = tcp_sock_set_nodelay,
+ .set_priority = sock_set_priority,
+ .no_linger = sock_no_linger,
+ .set_tos = ip_sock_set_tos,
+ .ops = &nvmet_tcp_ops,
+};
+
static int nvmet_tcp_add_port(struct nvmet_port *nport)
{
struct nvmet_tcp_port *port;
@@ -2054,6 +2086,13 @@ static int nvmet_tcp_add_port(struct nvmet_port *nport)
goto err_port;
}
+ if (nport->disc_addr.trtype == NVMF_TRTYPE_TCP) {
+ port->proto = &nvmet_tcp_proto;
+ } else {
+ ret = -EINVAL;
+ goto err_port;
+ }
+
ret = inet_pton_with_scope(&init_net, af, nport->disc_addr.traddr,
nport->disc_addr.trsvcid, &port->addr);
if (ret) {
@@ -2068,7 +2107,7 @@ static int nvmet_tcp_add_port(struct nvmet_port *nport)
port->nport->inline_data_size = NVMET_TCP_DEF_INLINE_DATA_SIZE;
ret = sock_create(port->addr.ss_family, SOCK_STREAM,
- IPPROTO_TCP, &port->sock);
+ port->proto->protocol, &port->sock);
if (ret) {
pr_err("failed to create a socket\n");
goto err_port;
@@ -2077,10 +2116,10 @@ static int nvmet_tcp_add_port(struct nvmet_port *nport)
port->sock->sk->sk_user_data = port;
port->data_ready = port->sock->sk->sk_data_ready;
port->sock->sk->sk_data_ready = nvmet_tcp_listen_data_ready;
- sock_set_reuseaddr(port->sock->sk);
- tcp_sock_set_nodelay(port->sock->sk);
+ port->proto->set_reuseaddr(port->sock->sk);
+ port->proto->set_nodelay(port->sock->sk);
if (so_priority > 0)
- sock_set_priority(port->sock->sk, so_priority);
+ port->proto->set_priority(port->sock->sk, so_priority);
ret = kernel_bind(port->sock, (struct sockaddr_unsized *)&port->addr,
sizeof(port->addr));
@@ -2111,11 +2150,16 @@ static int nvmet_tcp_add_port(struct nvmet_port *nport)
static void nvmet_tcp_destroy_port_queues(struct nvmet_tcp_port *port)
{
struct nvmet_tcp_queue *queue;
+ struct nvmet_tcp_port *qport;
mutex_lock(&nvmet_tcp_queue_mutex);
- list_for_each_entry(queue, &nvmet_tcp_queue_list, queue_list)
- if (queue->port == port)
+ list_for_each_entry(queue, &nvmet_tcp_queue_list, queue_list) {
+ rcu_read_lock();
+ qport = rcu_dereference(queue->port);
+ rcu_read_unlock();
+ if (qport == port)
kernel_sock_shutdown(queue->sock, SHUT_RDWR);
+ }
mutex_unlock(&nvmet_tcp_queue_mutex);
}
@@ -2135,7 +2179,7 @@ static void nvmet_tcp_remove_port(struct nvmet_port *nport)
nvmet_tcp_destroy_port_queues(port);
sock_release(port->sock);
- kfree(port);
+ kfree_rcu(port, rcu);
}
static void nvmet_tcp_delete_ctrl(struct nvmet_ctrl *ctrl)
--
2.51.0
© 2016 - 2026 Red Hat, Inc.