[RFC PATCH 4/4] mm/mempolicy: enhance weighted interleave with socket-aware locality

Rakie Kim posted 4 patches 3 weeks ago
[RFC PATCH 4/4] mm/mempolicy: enhance weighted interleave with socket-aware locality
Posted by Rakie Kim 3 weeks ago
Flat weighted interleave applies one global weight vector regardless of
where a task runs. On multi-socket systems this ignores inter-socket
interconnect costs and can steer allocations to remote sockets even when
local capacity exists, degrading effective bandwidth and increasing
latency.

Consider a dual-socket system:

          node0             node1
        +-------+         +-------+
        | CPU0  |---------| CPU1  |
        +-------+         +-------+
        | DRAM0 |         | DRAM1 |
        +---+---+         +---+---+
            |                 |
        +---+---+         +---+---+
        | CXL0  |         | CXL1  |
        +-------+         +-------+
          node2             node3

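Effective bandwidth (GB/s) from each CPU node (rows) to each memory node
(columns); same-socket entries show the local device capability, the
others the reduced cross-socket bandwidth: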

         0     1     2     3
     0  300   150   100    50
     1  150   300    50   100

A reasonable global weight vector reflecting device capabilities is:

     node0=3 node1=3 node2=1 node3=1

However, applying it flat to all sources yields the effective map:

         0     1     2     3
     0   3     3     1     1
     1   3     3     1     1

This does not account for the interconnect penalty (e.g., node0->node1
drops 300->150, node0->node3 drops 100->50) and thus permits cross-socket
allocations that underutilize local bandwidth.

This patch makes weighted interleave socket-aware. Before weighting is
applied, the candidate nodes are restricted to the current socket. If the
policy allows no node on the current socket, the socket of the first node
in the policy nodemask is used instead; if no socket can be resolved at
all, the full policy nodemask is used. The resulting effective map becomes:

         0     1     2     3
     0   3     0     1     0
     1   0     3     0     1

Now tasks running on node0 prefer DRAM0(3) and CXL0(1), while tasks on
node1 prefer DRAM1(3) and CXL1(1). This aligns allocation with actual
effective bandwidth, preserves NUMA locality, and reduces cross-socket
traffic.

Signed-off-by: Rakie Kim <rakie.kim@sk.com>
---
 mm/mempolicy.c | 94 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 90 insertions(+), 4 deletions(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index a3f0fde6c626..541853ac08bc 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -117,6 +117,7 @@
 #include <asm/tlb.h>
 #include <linux/uaccess.h>
 #include <linux/memory.h>
+#include <linux/memory-tiers.h>
 
 #include "internal.h"
 
@@ -2134,17 +2135,87 @@ bool apply_policy_zone(struct mempolicy *policy, enum zone_type zone)
 	return zone >= dynamic_policy_zone;
 }
 
+/**
+ * policy_resolve_package_nodes - Restrict policy nodes to the current package
+ * @policy: Target mempolicy whose user-selected nodes are in @policy->nodes.
+ * @mask:   Output nodemask. On success, contains policy->nodes limited to
+ *          the package that should be used for the allocation.
+ *
+ * This helper combines two constraints to decide where within a socket/package
+ * memory may be allocated:
+ *
+ *   1) The caller's package: derived via mp_get_package_nodes(numa_node_id()).
+ *   2) The user's preselected set @policy->nodes (cpusets/mempolicy).
+ *
+ * The function obtains the nodemask of the current CPU's package and
+ * intersects it with @policy->nodes. If the intersection is empty (e.g. the
+ * user excluded every node of the current package), it falls back to the
+ * first node in @policy->nodes, derives that node's package, and intersects
+ * again. If no fallback node exists or the fallback intersection is also
+ * empty, @mask stays empty and -ENOENT is returned.
+ *
+ * Examples (packages: P0={CPU:0, MEM:2}, P1={CPU:1, MEM:3}):
+ *   - policy->nodes = {0,1,2,3}
+ *       on P0: mask = {0,2}; on P1: mask = {1,3}.
+ *   - policy->nodes = {0,1,3}
+ *       on P0: mask = {0}      (only node 0 from P0 is allowed).
+ *   - policy->nodes = {1,2,3}
+ *       on P0: mask = {2}      (only node 2 from P0 is allowed).
+ *   - policy->nodes = {1,3}
+ *       on P0: P0 & policy is empty -> fall back to first policy node 1,
+ *               package(1)=P1, mask = {1,3}. (User effectively opted out of P0.)
+ *
+ * Return:
+ *   0 on success with @mask set as above;
+ *   -EINVAL if @policy/@mask is NULL;
+ *   -ENOENT if @policy->nodes is empty or the fallback package has no policy node;
+ *   Propagated error from mp_get_package_nodes() on failure.
+ */
+static int policy_resolve_package_nodes(struct mempolicy *policy, nodemask_t *mask)
+{
+	nodemask_t package_mask;
+	unsigned int node;
+	int ret;
+
+	if (!policy || !mask)
+		return -EINVAL;
+
+	nodes_clear(*mask);
+	ret = mp_get_package_nodes(numa_node_id(), &package_mask);
+	if (ret)
+		return ret;
+
+	nodes_and(*mask, package_mask, policy->nodes);
+	if (!nodes_empty(*mask))
+		return 0;
+
+	/* Current package excluded: use the package of the first policy node. */
+	node = first_node(policy->nodes);
+	if (node >= MAX_NUMNODES)
+		return -ENOENT;
+	ret = mp_get_package_nodes(node, &package_mask);
+	if (ret)
+		return ret;
+	nodes_and(*mask, package_mask, policy->nodes);
+	return nodes_empty(*mask) ? -ENOENT : 0;
+}
+
 static unsigned int weighted_interleave_nodes(struct mempolicy *policy)
 {
 	unsigned int node;
 	unsigned int cpuset_mems_cookie;
+	nodemask_t mask;
 
 retry:
 	/* to prevent miscount use tsk->mems_allowed_seq to detect rebind */
 	cpuset_mems_cookie = read_mems_allowed_begin();
 	node = current->il_prev;
-	if (!current->il_weight || !node_isset(node, policy->nodes)) {
-		node = next_node_in(node, policy->nodes);
+
+	if (policy_resolve_package_nodes(policy, &mask))
+		mask = policy->nodes;
+
+	if (!current->il_weight || !node_isset(node, mask)) {
+		node = next_node_in(node, mask);
 		if (read_mems_allowed_retry(cpuset_mems_cookie))
 			goto retry;
 		if (node == MAX_NUMNODES)
@@ -2237,6 +2308,21 @@ static unsigned int read_once_policy_nodemask(struct mempolicy *pol,
 	return nodes_weight(*mask);
 }
 
+static unsigned int read_once_policy_package_nodemask(struct mempolicy *pol,
+						      nodemask_t *mask)
+{
+	nodemask_t resolved;
+	const nodemask_t *src = &resolved;
+
+	/* Copy the package-limited mask, or all policy nodes if resolving fails. */
+	barrier();
+	if (policy_resolve_package_nodes(pol, &resolved))
+		src = &pol->nodes;
+	memcpy(mask, src, sizeof(*mask));
+	barrier();
+	return nodes_weight(*mask);
+}
+
 static unsigned int weighted_interleave_nid(struct mempolicy *pol, pgoff_t ilx)
 {
 	struct weighted_interleave_state *state;
@@ -2247,7 +2333,7 @@ static unsigned int weighted_interleave_nid(struct mempolicy *pol, pgoff_t ilx)
 	u8 weight;
 	int nid = 0;
 
-	nr_nodes = read_once_policy_nodemask(pol, &nodemask);
+	nr_nodes = read_once_policy_package_nodemask(pol, &nodemask);
 	if (!nr_nodes)
 		return numa_node_id();
 
@@ -2691,7 +2777,7 @@ static unsigned long alloc_pages_bulk_weighted_interleave(gfp_t gfp,
 	/* read the nodes onto the stack, retry if done during rebind */
 	do {
 		cpuset_mems_cookie = read_mems_allowed_begin();
-		nnodes = read_once_policy_nodemask(pol, &nodes);
+		nnodes = read_once_policy_package_nodemask(pol, &nodes);
 	} while (read_mems_allowed_retry(cpuset_mems_cookie));
 
 	/* if the nodemask has become invalid, we cannot do anything */
-- 
2.34.1