[PATCH 17/36] lib/crypto: x86/aes: Add AES-NI optimization

Eric Biggers posted 36 patches 3 days, 11 hours ago
[PATCH 17/36] lib/crypto: x86/aes: Add AES-NI optimization
Posted by Eric Biggers 3 days, 11 hours ago
Optimize the AES library with x86 AES-NI instructions.

The relevant existing assembly functions, aesni_set_key(), aesni_enc(),
and aesni_dec(), are a bit difficult to extract into the library:

- They're coupled to the code for the AES modes.
- They operate on struct crypto_aes_ctx.  The AES library now uses
  different structs.
- They assume the key is 16-byte aligned.  The AES library only
  *prefers* 16-byte alignment; it doesn't require it.

Moreover, they're not all that great in the first place:

- They use unrolled loops, which isn't a great choice on x86.
- They use the 'aeskeygenassist' instruction, which is unnecessary, is
  slow on Intel CPUs, and forces the loop to be unrolled.
- They have special code for AES-192 key expansion, despite that being
  kind of useless.  AES-128 and AES-256 are the ones used in practice.

These are small functions anyway.

Therefore, I opted to just write replacements of these functions for the
library.  They address all the above issues.

Signed-off-by: Eric Biggers <ebiggers@kernel.org>
---
 lib/crypto/Kconfig         |   1 +
 lib/crypto/Makefile        |   1 +
 lib/crypto/x86/aes-aesni.S | 261 +++++++++++++++++++++++++++++++++++++
 lib/crypto/x86/aes.h       |  85 ++++++++++++
 4 files changed, 348 insertions(+)
 create mode 100644 lib/crypto/x86/aes-aesni.S
 create mode 100644 lib/crypto/x86/aes.h

diff --git a/lib/crypto/Kconfig b/lib/crypto/Kconfig
index 222887c04240..e3ee31217988 100644
--- a/lib/crypto/Kconfig
+++ b/lib/crypto/Kconfig
@@ -19,10 +19,11 @@ config CRYPTO_LIB_AES_ARCH
 	default y if PPC && (SPE || (PPC64 && VSX))
 	default y if RISCV && 64BIT && TOOLCHAIN_HAS_VECTOR_CRYPTO && \
 		     RISCV_EFFICIENT_VECTOR_UNALIGNED_ACCESS
 	default y if S390
 	default y if SPARC64
+	default y if X86
 
 config CRYPTO_LIB_AESCFB
 	tristate
 	select CRYPTO_LIB_AES
 	select CRYPTO_LIB_UTILS
diff --git a/lib/crypto/Makefile b/lib/crypto/Makefile
index 761d52d91f92..725eef05b758 100644
--- a/lib/crypto/Makefile
+++ b/lib/crypto/Makefile
@@ -50,10 +50,11 @@ OBJECT_FILES_NON_STANDARD_powerpc/aesp8-ppc.o := y
 endif # !CONFIG_SPE
 endif # CONFIG_PPC
 
 libaes-$(CONFIG_RISCV) += riscv/aes-riscv64-zvkned.o
 libaes-$(CONFIG_SPARC) += sparc/aes_asm.o
+libaes-$(CONFIG_X86) += x86/aes-aesni.o
 endif # CONFIG_CRYPTO_LIB_AES_ARCH
 
 ################################################################################
 
 obj-$(CONFIG_CRYPTO_LIB_AESCFB)			+= libaescfb.o
diff --git a/lib/crypto/x86/aes-aesni.S b/lib/crypto/x86/aes-aesni.S
new file mode 100644
index 000000000000..b8c3e104a3be
--- /dev/null
+++ b/lib/crypto/x86/aes-aesni.S
@@ -0,0 +1,261 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+//
+// AES block cipher using AES-NI instructions
+//
+// Copyright 2026 Google LLC
+//
+// The code in this file supports 32-bit and 64-bit CPUs, and it doesn't require
+// AVX.  It does use up to SSE4.1, which all CPUs with AES-NI have.
+#include <linux/linkage.h>
+
+.section .rodata
+#ifdef __x86_64__
+#define RODATA(label)	label(%rip)
+#else
+#define RODATA(label)	label
+#endif
+
+	// A mask for pshufb that extracts the last dword, rotates it right by 8
+	// bits, and copies the result to all four dwords.
+.p2align 4
+.Lmask:
+	.byte	13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12
+
+	// The AES round constants, used during key expansion
+.Lrcon:
+	.long	0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36
+
+.text
+
+// Transform four dwords [a0, a1, a2, a3] in \a into
+// [a0, a0^a1, a0^a1^a2, a0^a1^a2^a3].  \tmp is a temporary xmm register.
+//
+// Note: this could be done in four instructions, shufps + pxor + shufps + pxor,
+// if the temporary register were zero-initialized ahead of time.  We instead do
+// it in an easier-to-understand way that doesn't require zero-initialization
+// and avoids the unusual shufps instruction.  movdqa is usually "free" anyway.
+.macro	_prefix_sum	a, tmp
+	movdqa		\a, \tmp	// [a0, a1, a2, a3]
+	pslldq		$4, \a		// [0, a0, a1, a2]
+	pxor		\tmp, \a	// [a0, a0^a1, a1^a2, a2^a3]
+	movdqa		\a, \tmp
+	pslldq		$8, \a		// [0, 0, a0, a0^a1]
+	pxor		\tmp, \a	// [a0, a0^a1, a0^a1^a2, a0^a1^a2^a3]
+.endm
+
+.macro	_gen_round_key	a, b
+	// Compute four copies of rcon[i] ^ SubBytes(ror32(w, 8)), where w is
+	// the last dword of the previous round key (given in \b).
+	//
+	// 'aesenclast src, dst' does dst = src XOR SubBytes(ShiftRows(dst)).
+	// It is used here solely for the SubBytes and the XOR.  The ShiftRows
+	// is a no-op because all four columns are the same here.
+	//
+	// Don't use the 'aeskeygenassist' instruction, since:
+	//  - On most Intel CPUs it is microcoded, making it have a much higher
+	//    latency and use more execution ports than 'aesenclast'.
+	//  - It cannot be used in a loop, since it requires an immediate.
+	//  - It doesn't do much more than 'aesenclast' in the first place.
+	movdqa		\b, %xmm2
+	pshufb		MASK, %xmm2
+	aesenclast	RCON, %xmm2
+
+	// XOR in the prefix sum of the four dwords of \a, which is the
+	// previous round key (AES-128) or the first round key in the previous
+	// pair of round keys (AES-256).  The result is the next round key.
+	_prefix_sum	\a, tmp=%xmm3
+	pxor		%xmm2, \a
+
+	// Store the next round key to memory.  Also leave it in \a.
+	movdqu		\a, (RNDKEYS)
+.endm
+
+.macro	_aes_expandkey_aesni	is_aes128
+#ifdef __x86_64__
+	// Arguments
+	.set	RNDKEYS,	%rdi
+	.set	INV_RNDKEYS,	%rsi
+	.set	IN_KEY,		%rdx
+
+	// Other local variables
+	.set	RCON_PTR,	%rcx
+	.set	COUNTER,	%eax
+#else
+	// Arguments, assuming -mregparm=3
+	.set	RNDKEYS,	%eax
+	.set	INV_RNDKEYS,	%edx
+	.set	IN_KEY,		%ecx
+
+	// Other local variables
+	.set	RCON_PTR,	%ebx
+	.set	COUNTER,	%esi
+#endif
+	.set	RCON,		%xmm6
+	.set	MASK,		%xmm7
+
+#ifdef __i386__
+	push		%ebx
+	push		%esi
+#endif
+
+.if \is_aes128
+	// AES-128: the first round key is simply a copy of the raw key.
+	movdqu		(IN_KEY), %xmm0
+	movdqu		%xmm0, (RNDKEYS)
+.else
+	// AES-256: the first two round keys are simply a copy of the raw key.
+	movdqu		(IN_KEY), %xmm0
+	movdqu		%xmm0, (RNDKEYS)
+	movdqu		16(IN_KEY), %xmm1
+	movdqu		%xmm1, 16(RNDKEYS)
+	add		$32, RNDKEYS
+.endif
+
+	// Generate the remaining round keys.
+	movdqa		RODATA(.Lmask), MASK
+.if \is_aes128
+	lea		RODATA(.Lrcon), RCON_PTR
+	mov		$10, COUNTER
+.Lgen_next_aes128_round_key:
+	add		$16, RNDKEYS
+	movd		(RCON_PTR), RCON
+	pshufd		$0x00, RCON, RCON
+	add		$4, RCON_PTR
+	_gen_round_key	%xmm0, %xmm0
+	dec		COUNTER
+	jnz		.Lgen_next_aes128_round_key
+.else
+	// AES-256: only the first 7 round constants are needed, so instead of
+	// loading each one from memory, just start by loading [1, 1, 1, 1] and
+	// then generate the rest by doubling.
+	pshufd		$0x00, RODATA(.Lrcon), RCON
+	pxor		%xmm5, %xmm5	// All-zeroes
+	mov		$7, COUNTER
+.Lgen_next_aes256_round_key_pair:
+	// Generate the next AES-256 round key: either the first of a pair of
+	// two, or the last one.
+	_gen_round_key	%xmm0, %xmm1
+
+	dec		COUNTER
+	jz		.Lgen_aes256_round_keys_done
+
+	// Generate the second AES-256 round key of the pair.  Compared to the
+	// first, there's no rotation and no XOR of a round constant.
+	pshufd		$0xff, %xmm0, %xmm2	// Get four copies of last dword
+	aesenclast	%xmm5, %xmm2		// Just does SubBytes
+	_prefix_sum	%xmm1, tmp=%xmm3
+	pxor		%xmm2, %xmm1
+	movdqu		%xmm1, 16(RNDKEYS)
+	add		$32, RNDKEYS
+	paddd		RCON, RCON		// RCON <<= 1
+	jmp		.Lgen_next_aes256_round_key_pair
+.Lgen_aes256_round_keys_done:
+.endif
+
+	// If INV_RNDKEYS is non-NULL, write the round keys for the Equivalent
+	// Inverse Cipher to it.  To do that, reverse the standard round keys,
+	// and apply aesimc (InvMixColumn) to each except the first and last.
+	test		INV_RNDKEYS, INV_RNDKEYS
+	jz		.Ldone\@
+	movdqu		(RNDKEYS), %xmm0	// Last standard round key
+	movdqu		%xmm0, (INV_RNDKEYS)	// => First inverse round key
+.if \is_aes128
+	mov		$9, COUNTER
+.else
+	mov		$13, COUNTER
+.endif
+.Lgen_next_inv_round_key\@:
+	sub		$16, RNDKEYS
+	add		$16, INV_RNDKEYS
+	movdqu		(RNDKEYS), %xmm0
+	aesimc		%xmm0, %xmm0
+	movdqu		%xmm0, (INV_RNDKEYS)
+	dec		COUNTER
+	jnz		.Lgen_next_inv_round_key\@
+	movdqu		-16(RNDKEYS), %xmm0	// First standard round key
+	movdqu		%xmm0, 16(INV_RNDKEYS)	// => Last inverse round key
+
+.Ldone\@:
+#ifdef __i386__
+	pop		%esi
+	pop		%ebx
+#endif
+	RET
+.endm
+
+// void aes128_expandkey_aesni(u32 rndkeys[], u32 *inv_rndkeys,
+//			       const u8 in_key[AES_KEYSIZE_128]);
+SYM_FUNC_START(aes128_expandkey_aesni)
+	_aes_expandkey_aesni	1
+SYM_FUNC_END(aes128_expandkey_aesni)
+
+// void aes256_expandkey_aesni(u32 rndkeys[], u32 *inv_rndkeys,
+//			       const u8 in_key[AES_KEYSIZE_256]);
+SYM_FUNC_START(aes256_expandkey_aesni)
+	_aes_expandkey_aesni	0
+SYM_FUNC_END(aes256_expandkey_aesni)
+
+.macro	_aes_crypt_aesni	enc
+#ifdef __x86_64__
+	.set	RNDKEYS,	%rdi
+	.set	NROUNDS,	%esi
+	.set	OUT,		%rdx
+	.set	IN,		%rcx
+#else
+	// Assuming -mregparm=3
+	.set	RNDKEYS,	%eax
+	.set	NROUNDS,	%edx
+	.set	OUT,		%ecx
+	.set	IN,		%ebx	// Passed on stack
+#endif
+
+#ifdef __i386__
+	push		%ebx
+	mov		8(%esp), %ebx
+#endif
+
+	// Zero-th round
+	movdqu		(IN), %xmm0
+	movdqu		(RNDKEYS), %xmm1
+	pxor		%xmm1, %xmm0
+
+	// Normal rounds
+	add		$16, RNDKEYS
+	dec		NROUNDS
+.Lnext_round\@:
+	movdqu		(RNDKEYS), %xmm1
+.if \enc
+	aesenc		%xmm1, %xmm0
+.else
+	aesdec		%xmm1, %xmm0
+.endif
+	add		$16, RNDKEYS
+	dec		NROUNDS
+	jne		.Lnext_round\@
+
+	// Last round
+	movdqu		(RNDKEYS), %xmm1
+.if \enc
+	aesenclast	%xmm1, %xmm0
+.else
+	aesdeclast	%xmm1, %xmm0
+.endif
+	movdqu		%xmm0, (OUT)
+
+#ifdef __i386__
+	pop		%ebx
+#endif
+	RET
+.endm
+
+// void aes_encrypt_aesni(const u32 rndkeys[], int nrounds,
+//			  u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
+SYM_FUNC_START(aes_encrypt_aesni)
+	_aes_crypt_aesni	1
+SYM_FUNC_END(aes_encrypt_aesni)
+
+// void aes_decrypt_aesni(const u32 inv_rndkeys[], int nrounds,
+//			  u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
+SYM_FUNC_START(aes_decrypt_aesni)
+	_aes_crypt_aesni	0
+SYM_FUNC_END(aes_decrypt_aesni)
diff --git a/lib/crypto/x86/aes.h b/lib/crypto/x86/aes.h
new file mode 100644
index 000000000000..b047dee94f57
--- /dev/null
+++ b/lib/crypto/x86/aes.h
@@ -0,0 +1,85 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+/*
+ * AES block cipher using AES-NI instructions
+ *
+ * Copyright 2026 Google LLC
+ */
+
+#include <asm/fpu/api.h>
+
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_aes);
+
+void aes128_expandkey_aesni(u32 rndkeys[], u32 *inv_rndkeys,
+			    const u8 in_key[AES_KEYSIZE_128]);
+void aes256_expandkey_aesni(u32 rndkeys[], u32 *inv_rndkeys,
+			    const u8 in_key[AES_KEYSIZE_256]);
+void aes_encrypt_aesni(const u32 rndkeys[], int nrounds,
+		       u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
+void aes_decrypt_aesni(const u32 inv_rndkeys[], int nrounds,
+		       u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
+
+/*
+ * Expand an AES key using AES-NI if supported and usable or generic code
+ * otherwise.  The expanded key format is compatible between the two cases.  The
+ * outputs are @k->rndkeys (required) and @inv_k->inv_rndkeys (optional).
+ *
+ * We could just always use the generic key expansion code.  AES key expansion
+ * is usually less performance-critical than AES en/decryption.  However,
+ * there's still *some* value in speed here, as well as in non-key-dependent
+ * execution time which AES-NI provides.  So, do use AES-NI to expand AES-128
+ * and AES-256 keys.  (Don't bother with AES-192, as it's almost never used.)
+ */
+static void aes_preparekey_arch(union aes_enckey_arch *k,
+				union aes_invkey_arch *inv_k,
+				const u8 *in_key, int key_len, int nrounds)
+{
+	u32 *rndkeys = k->rndkeys;
+	u32 *inv_rndkeys = inv_k ? inv_k->inv_rndkeys : NULL;
+
+	if (static_branch_likely(&have_aes) && key_len != AES_KEYSIZE_192 &&
+	    irq_fpu_usable()) {
+		kernel_fpu_begin();
+		if (key_len == AES_KEYSIZE_128)
+			aes128_expandkey_aesni(rndkeys, inv_rndkeys, in_key);
+		else
+			aes256_expandkey_aesni(rndkeys, inv_rndkeys, in_key);
+		kernel_fpu_end();
+	} else {
+		aes_expandkey_generic(rndkeys, inv_rndkeys, in_key, key_len);
+	}
+}
+
+static void aes_encrypt_arch(const struct aes_enckey *key,
+			     u8 out[AES_BLOCK_SIZE],
+			     const u8 in[AES_BLOCK_SIZE])
+{
+	if (static_branch_likely(&have_aes) && irq_fpu_usable()) {
+		kernel_fpu_begin();
+		aes_encrypt_aesni(key->k.rndkeys, key->nrounds, out, in);
+		kernel_fpu_end();
+	} else {
+		aes_encrypt_generic(key->k.rndkeys, key->nrounds, out, in);
+	}
+}
+
+static void aes_decrypt_arch(const struct aes_key *key,
+			     u8 out[AES_BLOCK_SIZE],
+			     const u8 in[AES_BLOCK_SIZE])
+{
+	if (static_branch_likely(&have_aes) && irq_fpu_usable()) {
+		kernel_fpu_begin();
+		aes_decrypt_aesni(key->inv_k.inv_rndkeys, key->nrounds,
+				  out, in);
+		kernel_fpu_end();
+	} else {
+		aes_decrypt_generic(key->inv_k.inv_rndkeys, key->nrounds,
+				    out, in);
+	}
+}
+
+#define aes_mod_init_arch aes_mod_init_arch
+static void aes_mod_init_arch(void)
+{
+	if (boot_cpu_has(X86_FEATURE_AES))
+		static_branch_enable(&have_aes);
+}
-- 
2.52.0