[PATCH RFC 3/4] x86/ret-thunk: Support CALL-ing to the ret-thunk

Andrew Cooper posted 4 patches 2 years, 3 months ago
[PATCH RFC 3/4] x86/ret-thunk: Support CALL-ing to the ret-thunk
Posted by Andrew Cooper 2 years, 3 months ago
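Add x86_return_thunk_use_call.  When it is set, the sites that are patched to
reach x86_return_thunk CALL the thunk instead of JMP-ing to it.  The affected
sites are return-site patching, ftrace trampolines, static_call RET and the
BPF JIT epilogue.  This will be used to improve the SRSO mitigation.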

Signed-off-by: Andrew Cooper <andrew.cooper3@citrix.com>
---
CC: x86@kernel.org
CC: linux-kernel@vger.kernel.org
CC: Borislav Petkov <bp@alien8.de>
CC: Peter Zijlstra <peterz@infradead.org>
CC: Josh Poimboeuf <jpoimboe@kernel.org>
CC: Babu Moger <babu.moger@amd.com>
CC: David.Kaplan@amd.com
CC: Nikolay Borisov <nik.borisov@suse.com>
CC: gregkh@linuxfoundation.org
CC: Thomas Gleixner <tglx@linutronix.de>

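RFC: __static_call_transform() has a Jcc case in which a NULL func, with
X86_FEATURE_RETHUNK enabled, becomes a conditional jump (0x0f-prefixed opcode)
to x86_return_thunk, i.e. a Jcc interpreted as RET.  A Jcc cannot be turned
into a CALL, so this path isn't safe when x86_return_thunk_use_call is set (see
the XXX below).  Where does this pattern come from?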
---
 arch/x86/include/asm/nospec-branch.h |  1 +
 arch/x86/kernel/alternative.c        |  4 +++-
 arch/x86/kernel/cpu/bugs.c           |  1 +
 arch/x86/kernel/ftrace.c             |  8 +++++---
 arch/x86/kernel/static_call.c        | 10 ++++++----
 arch/x86/net/bpf_jit_comp.c          |  5 ++++-
 6 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/arch/x86/include/asm/nospec-branch.h b/arch/x86/include/asm/nospec-branch.h
index a4c686bc4b1f..5d5677bcf749 100644
--- a/arch/x86/include/asm/nospec-branch.h
+++ b/arch/x86/include/asm/nospec-branch.h
@@ -360,6 +360,7 @@ extern void entry_untrain_ret(void);
 extern void entry_ibpb(void);
 
 extern void (*x86_return_thunk)(void);
+extern bool x86_return_thunk_use_call;
 
 #ifdef CONFIG_CALL_DEPTH_TRACKING
 extern void __x86_return_skl(void);
diff --git a/arch/x86/kernel/alternative.c b/arch/x86/kernel/alternative.c
index 099d58d02a26..215793fa53f5 100644
--- a/arch/x86/kernel/alternative.c
+++ b/arch/x86/kernel/alternative.c
@@ -704,8 +704,10 @@ static int patch_return(void *addr, struct insn *insn, u8 *bytes)
 
 	/* Patch the custom return thunks... */
 	if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
+		u8 op = x86_return_thunk_use_call ? CALL_INSN_OPCODE : JMP32_INSN_OPCODE;
+
 		i = JMP32_INSN_SIZE;
-		__text_gen_insn(bytes, JMP32_INSN_OPCODE, addr, x86_return_thunk, i);
+		__text_gen_insn(bytes, op, addr, x86_return_thunk, i);
 	} else {
 		/* ... or patch them out if not needed. */
 		bytes[i++] = RET_INSN_OPCODE;
diff --git a/arch/x86/kernel/cpu/bugs.c b/arch/x86/kernel/cpu/bugs.c
index 893d14a9f282..de2f84aa526f 100644
--- a/arch/x86/kernel/cpu/bugs.c
+++ b/arch/x86/kernel/cpu/bugs.c
@@ -64,6 +64,7 @@ EXPORT_SYMBOL_GPL(x86_pred_cmd);
 static DEFINE_MUTEX(spec_ctrl_mutex);
 
 void (*x86_return_thunk)(void) __ro_after_init = &__x86_return_thunk;
+bool x86_return_thunk_use_call __ro_after_init;
 
 /* Update SPEC_CTRL MSR and its cached copy unconditionally */
 static void update_spec_ctrl(u64 val)
diff --git a/arch/x86/kernel/ftrace.c b/arch/x86/kernel/ftrace.c
index 12df54ff0e81..f383e4a90ce2 100644
--- a/arch/x86/kernel/ftrace.c
+++ b/arch/x86/kernel/ftrace.c
@@ -363,9 +363,11 @@ create_trampoline(struct ftrace_ops *ops, unsigned int *tramp_size)
 		goto fail;
 
 	ip = trampoline + size;
-	if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
-		__text_gen_insn(ip, JMP32_INSN_OPCODE, ip, x86_return_thunk, JMP32_INSN_SIZE);
-	else
+	if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
+		u8 op = x86_return_thunk_use_call ? CALL_INSN_OPCODE : JMP32_INSN_OPCODE;
+
+		__text_gen_insn(ip, op, ip, x86_return_thunk, JMP32_INSN_SIZE);
+	} else
 		memcpy(ip, retq, sizeof(retq));
 
 	/* No need to test direct calls on created trampolines */
diff --git a/arch/x86/kernel/static_call.c b/arch/x86/kernel/static_call.c
index 77a9316da435..b8ff0fdfa49e 100644
--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -81,9 +81,11 @@ static void __ref __static_call_transform(void *insn, enum insn_type type,
 		break;
 
 	case RET:
-		if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
-			code = text_gen_insn(JMP32_INSN_OPCODE, insn, x86_return_thunk);
-		else
+		if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
+			u8 op = x86_return_thunk_use_call ? CALL_INSN_OPCODE : JMP32_INSN_OPCODE;
+
+			code = text_gen_insn(op, insn, x86_return_thunk);
+		} else
 			code = &retinsn;
 		break;
 
@@ -91,7 +93,7 @@ static void __ref __static_call_transform(void *insn, enum insn_type type,
 		if (!func) {
 			func = __static_call_return;
 			if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
-				func = x86_return_thunk;
+				func = x86_return_thunk; /* XXX */
 		}
 
 		buf[0] = 0x0f;
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 438adb695daa..8e61a97b6d67 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -443,7 +443,10 @@ static void emit_return(u8 **pprog, u8 *ip)
 	u8 *prog = *pprog;
 
 	if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
-		emit_jump(&prog, x86_return_thunk, ip);
+		if (x86_return_thunk_use_call)
+			emit_call(&prog, x86_return_thunk, ip);
+		else
+			emit_jump(&prog, x86_return_thunk, ip);
 	} else {
 		EMIT1(0xC3);		/* ret */
 		if (IS_ENABLED(CONFIG_SLS))
-- 
2.30.2