[PATCH v4 2/2] net: read header_ops callbacks with READ_ONCE()

Kota Toda posted 2 patches 3 weeks, 6 days ago
[PATCH v4 2/2] net: read header_ops callbacks with READ_ONCE()
Posted by Kota Toda 3 weeks, 6 days ago
Bonding now updates its header_ops callbacks at runtime, so lockless
readers can observe concurrent callback updates.

Load header_ops callbacks with READ_ONCE() and call the loaded
function pointer, instead of re-reading it through dev->header_ops.

Co-developed-by: Yuki Koike <yuki.koike@gmo-cybersecurity.com>
Signed-off-by: Yuki Koike <yuki.koike@gmo-cybersecurity.com>
Signed-off-by: Kota Toda <kota.toda@gmo-cybersecurity.com>
---
 include/linux/netdevice.h | 41 +++++++++++++++++++++++++++------------
 include/net/cfg802154.h   |  2 +-
 net/core/neighbour.c      |  6 +++---
 net/ipv4/arp.c            |  2 +-
 net/ipv6/ndisc.c          |  2 +-
 5 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 77a99c8ab..79fb0864a 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -3150,35 +3150,51 @@ static inline int dev_hard_header(struct sk_buff *skb, struct net_device *dev,
 				  const void *daddr, const void *saddr,
 				  unsigned int len)
 {
-	if (!dev->header_ops || !dev->header_ops->create)
-		return 0;
+	int (*create)(struct sk_buff *skb, struct net_device *dev,
+		      unsigned short type, const void *daddr,
+		      const void *saddr, unsigned int len);
 
-	return dev->header_ops->create(skb, dev, type, daddr, saddr, len);
+	if (!dev->header_ops)
+		return 0;
+	create = READ_ONCE(dev->header_ops->create);
+	if (!create)
+		return 0;
+	return create(skb, dev, type, daddr, saddr, len);
 }
 
 static inline int dev_parse_header(const struct sk_buff *skb,
 				   unsigned char *haddr)
 {
+	int (*parse)(const struct sk_buff *skb, unsigned char *haddr);
 	const struct net_device *dev = skb->dev;
 
-	if (!dev->header_ops || !dev->header_ops->parse)
+	if (!dev->header_ops)
 		return 0;
-	return dev->header_ops->parse(skb, haddr);
+	parse = READ_ONCE(dev->header_ops->parse);
+	if (!parse)
+		return 0;
+	return parse(skb, haddr);
 }
 
 static inline __be16 dev_parse_header_protocol(const struct sk_buff *skb)
 {
+	__be16 (*parse_protocol)(const struct sk_buff *skb);
 	const struct net_device *dev = skb->dev;
 
-	if (!dev->header_ops || !dev->header_ops->parse_protocol)
+	if (!dev->header_ops)
+		return 0;
+	parse_protocol = READ_ONCE(dev->header_ops->parse_protocol);
+	if (!parse_protocol)
 		return 0;
-	return dev->header_ops->parse_protocol(skb);
+	return parse_protocol(skb);
 }
 
 /* ll_header must have at least hard_header_len allocated */
 static inline bool dev_validate_header(const struct net_device *dev,
 				       char *ll_header, int len)
 {
+	bool (*validate)(const char *ll_header, unsigned int len);
+
 	if (likely(len >= dev->hard_header_len))
 		return true;
 	if (len < dev->min_header_len)
@@ -3189,15 +3205,17 @@ static inline bool dev_validate_header(const struct net_device *dev,
 		return true;
 	}
 
-	if (dev->header_ops && dev->header_ops->validate)
-		return dev->header_ops->validate(ll_header, len);
-
-	return false;
+	if (!dev->header_ops)
+		return false;
+	validate = READ_ONCE(dev->header_ops->validate);
+	if (!validate)
+		return false;
+	return validate(ll_header, len);
 }
 
 static inline bool dev_has_header(const struct net_device *dev)
 {
-	return dev->header_ops && dev->header_ops->create;
+	return dev->header_ops && READ_ONCE(dev->header_ops->create);
 }
 
 /*
diff --git a/include/net/cfg802154.h b/include/net/cfg802154.h
index 76d2cd2e2..dec638763 100644
--- a/include/net/cfg802154.h
+++ b/include/net/cfg802154.h
@@ -522,7 +522,7 @@ wpan_dev_hard_header(struct sk_buff *skb, struct net_device *dev,
 {
 	struct wpan_dev *wpan_dev = dev->ieee802154_ptr;
 
-	return wpan_dev->header_ops->create(skb, dev, daddr, saddr, len);
+	return READ_ONCE(wpan_dev->header_ops->create)(skb, dev, daddr, saddr, len);
 }
 #endif
 
diff --git a/net/core/neighbour.c b/net/core/neighbour.c
index 96786016d..ff948e35e 100644
--- a/net/core/neighbour.c
+++ b/net/core/neighbour.c
@@ -1270,7 +1270,7 @@ static void neigh_update_hhs(struct neighbour *neigh)
 		= NULL;
 
 	if (neigh->dev->header_ops)
-		update = neigh->dev->header_ops->cache_update;
+		update = READ_ONCE(neigh->dev->header_ops->cache_update);
 
 	if (update) {
 		hh = &neigh->hh;
@@ -1540,7 +1540,15 @@ static void neigh_hh_init(struct neighbour *n)
 	 * hh_cache entry.
 	 */
-	if (!hh->hh_len)
-		dev->header_ops->cache(n, hh, prot);
+	if (!hh->hh_len) {
+		/* ->cache can change at runtime (e.g. bonding), so the
+		 * caller's check alone does not guarantee it is still set.
+		 */
+		typeof(dev->header_ops->cache) cache =
+			READ_ONCE(dev->header_ops->cache);
+
+		if (cache)
+			cache(n, hh, prot);
+	}
 
 	write_unlock_bh(&n->lock);
 }
@@ -1556,7 +1564,7 @@ int neigh_resolve_output(struct neighbour *neigh, struct sk_buff *skb)
 		struct net_device *dev = neigh->dev;
 		unsigned int seq;
 
-		if (dev->header_ops->cache && !READ_ONCE(neigh->hh.hh_len))
+		if (READ_ONCE(dev->header_ops->cache) && !READ_ONCE(neigh->hh.hh_len))
 			neigh_hh_init(neigh);
 
 		do {
diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c
index 7822b2144..421bea6eb 100644
--- a/net/ipv4/arp.c
+++ b/net/ipv4/arp.c
@@ -278,7 +278,7 @@ static int arp_constructor(struct neighbour *neigh)
 			memcpy(neigh->ha, dev->broadcast, dev->addr_len);
 		}
 
-		if (dev->header_ops->cache)
+		if (READ_ONCE(dev->header_ops->cache))
 			neigh->ops = &arp_hh_ops;
 		else
 			neigh->ops = &arp_generic_ops;
diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c
index d961e6c2d..d81f509ec 100644
--- a/net/ipv6/ndisc.c
+++ b/net/ipv6/ndisc.c
@@ -361,7 +361,7 @@ static int ndisc_constructor(struct neighbour *neigh)
 			neigh->nud_state = NUD_NOARP;
 			memcpy(neigh->ha, dev->broadcast, dev->addr_len);
 		}
-		if (dev->header_ops->cache)
+		if (READ_ONCE(dev->header_ops->cache))
 			neigh->ops = &ndisc_hh_ops;
 		else
 			neigh->ops = &ndisc_generic_ops;
-- 
2.53.0