[PATCH 1/2] overflow: make check_shl_overflow() 128-bit aware

Rafael V. Volkmer posted 2 patches 3 days, 17 hours ago
check_shl_overflow() currently evaluates (@a << @s) in an
unsigned long long accumulator. When callers pass __int128/u128
values, the intermediate is truncated to 64 bits before the
round-trip comparison. As a result, whenever @a or the shifted
result needs more than 64 bits, the helper reports overflow and
leaves a truncated value in *@d, even when *@d is wide enough to
hold the full shift.
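
To make the failure concrete, here is a minimal userspace sketch that
models only the unsigned u128 case; it is not the kernel macro itself.
old_shl_overflow_u128() and new_shl_overflow_u128() are hypothetical
helpers: the first shifts in a 64-bit accumulator as the current code
does, the second in a 128-bit one. It assumes GCC or Clang on a target
that defines __SIZEOF_INT128__.

```c
#include <stdbool.h>

/* GCC/Clang extension; assumes the target defines __SIZEOF_INT128__. */
typedef unsigned __int128 u128;

/*
 * Hypothetical model of the current behaviour for an unsigned u128 @a and
 * *@d: the shift happens in a 64-bit accumulator and is widened afterwards.
 * Shift counts of 64 or more are rejected here because shifting a 64-bit
 * value that far is undefined behaviour in C.
 */
static bool old_shl_overflow_u128(u128 a, unsigned int s, u128 *d)
{
	unsigned long long a_full = (unsigned long long)a; /* drops bits 64..127 */

	if (s >= 64)
		return true;
	*d = a_full << s;          /* 64-bit shift, then widened to 128 bits */
	return (*d >> s) != a;     /* round-trip check, as in the macro */
}

/* Same check with a 128-bit accumulator, the behaviour the patch targets. */
static bool new_shl_overflow_u128(u128 a, unsigned int s, u128 *d)
{
	u128 a_full = a;

	if (s >= 8 * sizeof(*d))
		return true;       /* shift count out of range for *d */
	*d = a_full << s;
	return (*d >> s) != a;
}
```

With these helpers, old_shl_overflow_u128(2, 63, &d) reports overflow
because 2 << 63 wraps to 0 in 64 bits, while new_shl_overflow_u128(2,
63, &d) succeeds and stores 2^64 in d.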

Introduce __shl_eval_type() to derive the internal evaluation type
from @a and *@d. When the compiler provides __int128
(__SIZEOF_INT128__ is defined), it promotes the accumulator to
unsigned __int128 if the promoted type of (@a + *@d) is wider than
64 bits; otherwise the accumulator stays unsigned long long.
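
The selection rule can be tried outside the kernel with a small
sketch. SHL_EVAL_TYPE() below is a hypothetical stand-in that mirrors
the shape of __shl_eval_type() in the diff, and SHL_EVAL_BITS() is an
equally hypothetical helper that reports the chosen accumulator width.
It assumes GCC or Clang, since it relies on __typeof__ and
__builtin_choose_expr.

```c
#include <stdint.h>

/*
 * Hypothetical userspace mirror of __shl_eval_type(): pick unsigned __int128
 * when (a) + *(d) has a type wider than 64 bits, else unsigned long long.
 * Neither argument is evaluated; only their types are inspected.
 */
#if defined(__SIZEOF_INT128__)
#define SHL_EVAL_TYPE(a, d)						\
	__typeof__(__builtin_choose_expr(				\
		sizeof((a) + (__typeof__(*(d)))0) >			\
			sizeof(unsigned long long),			\
		(unsigned __int128)0, 0ULL))
#else
#define SHL_EVAL_TYPE(a, d) unsigned long long
#endif

/* Width in bits of the accumulator chosen for (a, d). */
#define SHL_EVAL_BITS(a, d) ((int)(8 * sizeof(SHL_EVAL_TYPE(a, d))))
```

For a u32 value shifted into a u64 the sketch keeps a 64-bit unsigned
accumulator, consistent with the claim that current 32/64-bit callers
are unaffected, while any pairing that involves a 128-bit operand
moves to 128 bits.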

This keeps the accumulator unsigned (avoiding UB when left-shifting
negative signed values), preserves existing code generation for all
current 32/64-bit users, and fixes the spurious overflow reporting
for 128-bit shift users.
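
The point about keeping the accumulator unsigned can be illustrated
with a signed example. In C, left-shifting a negative signed value is
undefined, but converting it to an unsigned type is well defined
(modulo 2^N), so the shift can run in the unsigned accumulator while
the sign is rejected separately, as check_shl_overflow() does with
is_negative(). shl_s64_to_s128() is a hypothetical helper; the sketch
assumes GCC or Clang with __int128 and relies on GCC's documented
modulo behaviour when an out-of-range unsigned value is converted to
a signed type.

```c
#include <stdbool.h>
#include <stdint.h>

/*
 * Hypothetical signed example: shift a possibly negative int64_t into an
 * __int128 using an unsigned 128-bit accumulator, mirroring the checks in
 * check_shl_overflow(). Requires __SIZEOF_INT128__ (GCC/Clang).
 */
static bool shl_s64_to_s128(int64_t a, unsigned int s, __int128 *d)
{
	/* Well defined even for a < 0: the value is reduced modulo 2^128. */
	unsigned __int128 a_full = (unsigned __int128)a;
	/* Out-of-range shift counts are replaced by 0 and flagged below. */
	unsigned int to_shift = s < 8 * sizeof(*d) ? s : 0;

	/* Unsigned shift: no UB. The conversion back wraps on GCC/Clang. */
	*d = (__int128)(a_full << to_shift);

	/* || short-circuits, so *d >> to_shift only runs when *d >= 0. */
	return to_shift != s || *d < 0 || a < 0 || (*d >> to_shift) != a;
}
```

Because the shift itself always operates on an unsigned value, the
only non-portable step left is the final unsigned-to-signed
conversion, which GCC documents as modulo reduction.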

Signed-off-by: Rafael V. Volkmer <rafael.v.volkmer@gmail.com>
---
 include/linux/overflow.h | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/include/linux/overflow.h b/include/linux/overflow.h
index 725f95f7e416..ca8252e625d5 100644
--- a/include/linux/overflow.h
+++ b/include/linux/overflow.h
@@ -175,6 +175,26 @@ static inline bool __must_check __must_check_overflow(bool overflow)
 		__val;						\
 	})
 
+/**
+ * __shl_eval_type() - Choose evaluation type for shift checks
+ * @a: value to be shifted
+ * @d: destination pointer
+ *
+ * Expands to the internal type used by check_shl_overflow() to evaluate
+ * (@a << @s): unsigned __int128 when the compiler provides it and @a or
+ * *@d promotes to a type wider than 64 bits, otherwise unsigned long long.
+ */
+#if defined(__SIZEOF_INT128__)
+#define __shl_eval_type(a, d)							\
+	typeof(__builtin_choose_expr(						\
+		sizeof((a) + (typeof(*(d)))0) > sizeof(unsigned long long),	\
+		(unsigned __int128)0,						\
+		0ULL))
+#else
+#define __shl_eval_type(a, d)							\
+	typeof(0ULL + (a) + (typeof(*(d)))0)
+#endif
+
 /**
  * check_shl_overflow() - Calculate a left-shifted value and check overflow
  * @a: Value to be shifted
@@ -199,7 +219,7 @@ static inline bool __must_check __must_check_overflow(bool overflow)
 	typeof(a) _a = a;						\
 	typeof(s) _s = s;						\
 	typeof(d) _d = d;						\
-	unsigned long long _a_full = _a;				\
+	__shl_eval_type(_a, _d) _a_full = (__shl_eval_type(_a, _d))_a;	\
 	unsigned int _to_shift =					\
 		is_non_negative(_s) && _s < 8 * sizeof(*d) ? _s : 0;	\
 	*_d = (_a_full << _to_shift);					\
-- 
2.43.0