[PATCH v6 2/5] lib: fix memparse() to handle overflow

Dmitry Antipov posted 5 patches 12 hours ago
[PATCH v6 2/5] lib: fix memparse() to handle overflow
Posted by Dmitry Antipov 12 hours ago
Since '_parse_integer_limit()' (and hence 'simple_strtoull()') is now
capable of handling overflow, adjust 'memparse()' to handle the overflow
(denoted by ULLONG_MAX) returned from 'simple_strtoull()'. Also
use 'check_shl_overflow()' to catch an overflow possibly caused
by applying a size suffix, and denote it with ULLONG_MAX as well.

Signed-off-by: Dmitry Antipov <dmantipov@yandex.ru>
---
v6: treat a string consisting only of a valid suffix (such as "k")
    as unrecognized; minor style adjustments
v5: initial version to join the series
---
 lib/cmdline.c | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/lib/cmdline.c b/lib/cmdline.c
index 90ed997d9570..0d8770a0fb67 100644
--- a/lib/cmdline.c
+++ b/lib/cmdline.c
@@ -150,39 +150,57 @@ EXPORT_SYMBOL(get_options);
 unsigned long long memparse(const char *ptr, char **retptr)
 {
 	char *endptr;	/* local pointer to end of parsed string */
-
 	unsigned long long ret = simple_strtoull(ptr, &endptr, 0);
+	unsigned int shl = 0;
 
+	/* Consume valid suffix even in case of overflow. */
 	switch (*endptr) {
 	case 'E':
 	case 'e':
-		ret <<= 10;
+		shl += 10;
 		fallthrough;
 	case 'P':
 	case 'p':
-		ret <<= 10;
+		shl += 10;
 		fallthrough;
 	case 'T':
 	case 't':
-		ret <<= 10;
+		shl += 10;
 		fallthrough;
 	case 'G':
 	case 'g':
-		ret <<= 10;
+		shl += 10;
 		fallthrough;
 	case 'M':
 	case 'm':
-		ret <<= 10;
+		shl += 10;
 		fallthrough;
 	case 'K':
 	case 'k':
-		ret <<= 10;
+		shl += 10;
 		endptr++;
 		fallthrough;
 	default:
 		break;
 	}
 
+	if (shl) {
+		/* Valid suffix without preceding number. */
+		if (unlikely(ptr == endptr - 1)) {
+			endptr--;
+			ret = 0;
+		}
+		/* Apply suffix if no overflow. */
+		else if (likely(ret != ULLONG_MAX)) {
+			unsigned long long val;
+
+			if (unlikely(check_shl_overflow(ret, shl, &val)))
+				ret = ULLONG_MAX;
+			else
+				ret = val;
+		}
+	}
+
 	if (retptr)
 		*retptr = endptr;
 
-- 
2.53.0