[PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest

Adrian Hunter posted 7 patches 1 year, 2 months ago
Posted by Adrian Hunter 1 year, 2 months ago
From: Kai Huang <kai.huang@intel.com>

Intel TDX protects guest VMs from a malicious host and certain physical
attacks.  TDX introduces a new operation mode, Secure Arbitration Mode
(SEAM), to isolate and protect guest VMs.  A TDX guest VM runs in SEAM and,
unlike with VMX, the host VMM cannot directly control or interact with the
guest.  Instead, the Intel TDX Module, which also runs in SEAM, provides a
SEAMCALL API.

TDH.VP.ENTER is the SEAMCALL used to enter a guest.
The TDX Module processes TDH.VP.ENTER, and enters the guest via VMX
VMLAUNCH/VMRESUME instructions.  When a guest VM-exit requires host VMM
interaction, the TDH.VP.ENTER SEAMCALL returns to the host VMM (KVM).

Add tdh_vp_enter() to wrap the SEAMCALL invocation of TDH.VP.ENTER.

TDH.VP.ENTER differs from other SEAMCALLs in several ways:
 - it may take some time to return, because the guest runs in the meantime
 - it uses more arguments
 - after it returns, some host state may need to be restored

TDH.VP.ENTER arguments are passed through General Purpose Registers (GPRs).
For the special case of the TD guest invoking TDG.VP.VMCALL, nearly any GPR
can be used, as well as XMM0 to XMM15. Notably, RBP is not used, and Linux
mandates the TDX Module feature NO_RBP_MOD, which is enforced elsewhere.
Additionally, XMM registers are not required for the existing Guest
Hypervisor Communication Interface and are handled by existing KVM code
should they be modified by the guest.
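To make the register set concrete, here is a userspace sketch that mirrors the layout we assume for the kernel's struct tdx_module_args: one u64 per GPR, in the order RCX, RDX, R8-R15, RBX, RDI, RSI. RAX is not a member because, as we understand the SEAMCALL convention, it carries the leaf number in and the status out. RBP does not appear at all. The names tdx_args_mirror and tdx_args_reg_name() are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/*
 * Hypothetical mirror of struct tdx_module_args (field order assumed):
 * one u64 per GPR that TDH.VP.ENTER may read or write.  RAX is passed
 * separately, and RSP/RBP are deliberately absent.
 */
struct tdx_args_mirror {
	uint64_t rcx, rdx, r8, r9;	/* callee-clobbered */
	uint64_t r10, r11;		/* extra callee-clobbered */
	uint64_t r12, r13, r14, r15;	/* callee-saved */
	uint64_t rbx, rdi, rsi;
};

/* 13 GPRs of 8 bytes each; u64 members need no padding on x86-64 */
static_assert(sizeof(struct tdx_args_mirror) == 13 * sizeof(uint64_t),
	      "one u64 per GPR");

/* Map a byte offset in the struct back to the register it carries */
static const char *tdx_args_reg_name(size_t offset)
{
	static const char *const names[] = {
		"rcx", "rdx", "r8",  "r9",  "r10", "r11", "r12",
		"r13", "r14", "r15", "rbx", "rdi", "rsi",
	};

	if (offset % sizeof(uint64_t) != 0 ||
	    offset / sizeof(uint64_t) >= sizeof(names) / sizeof(names[0]))
		return NULL;
	return names[offset / sizeof(uint64_t)];
}
```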

There are 2 input formats and 5 output formats for TDH.VP.ENTER arguments.
Input #1 : Initial entry or following a previous async. TD Exit
Input #2 : Following a previous TDCALL(TDG.VP.VMCALL)
Output #1 : On Error (No TD Entry)
Output #2 : Async. Exits with a VMX Architectural Exit Reason
Output #3 : Async. Exits with a non-VMX TD Exit Status
Output #4 : Async. Exits with Cross-TD Exit Details
Output #5 : On TDCALL(TDG.VP.VMCALL)
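One way to read the format list is as a small state machine, in which the input format for the next TDH.VP.ENTER depends on how the previous one returned. The sketch below encodes only what the list states. The enum and function names are hypothetical. The list does not say which input format follows an error (Output #1), so the sketch reports that case as unknown rather than guessing.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical names; the numbers match the format list above */
enum vp_enter_in_fmt {
	IN_FMT_UNKNOWN = 0,		/* case not covered by the list */
	IN_FMT_INITIAL_OR_ASYNC = 1,	/* Input #1 */
	IN_FMT_AFTER_TDCALL = 2,	/* Input #2 */
};

enum vp_enter_out_fmt {
	OUT_FMT_ERROR = 1,		/* #1: error, no TD entry */
	OUT_FMT_VMX_EXIT,		/* #2: async exit, VMX exit reason */
	OUT_FMT_TD_EXIT_STATUS,		/* #3: async exit, non-VMX status */
	OUT_FMT_CROSS_TD,		/* #4: async exit, cross-TD details */
	OUT_FMT_TDCALL,			/* #5: TDCALL(TDG.VP.VMCALL) */
};

/* Choose the input format for the next entry from the previous output */
static enum vp_enter_in_fmt next_in_fmt(bool first_entry,
					enum vp_enter_out_fmt prev)
{
	if (first_entry)
		return IN_FMT_INITIAL_OR_ASYNC;

	switch (prev) {
	case OUT_FMT_TDCALL:
		return IN_FMT_AFTER_TDCALL;
	case OUT_FMT_VMX_EXIT:
	case OUT_FMT_TD_EXIT_STATUS:
	case OUT_FMT_CROSS_TD:
		return IN_FMT_INITIAL_OR_ASYNC;
	default:
		return IN_FMT_UNKNOWN;
	}
}
```

The wrapper in this patch avoids this bookkeeping because it passes every GPR regardless of format.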

Currently, to keep things simple, the wrapper function does not attempt
to distinguish between the formats.  It simply passes all the GPRs that
could be used.  KVM holds the GPR values in the area set aside for guest
GPRs (vcpu->arch.regs[]) and uses that area both to set up for
tdh_vp_enter() and to process its results.

Therefore, changing tdh_vp_enter() to use more complex argument formats
would also change the way KVM code interacts with tdh_vp_enter().
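The copy-in/copy-out implied by "pass all the GPRs" can be sketched in userspace as follows. The sketch assumes that regs[] is indexed in x86 GPR encoding order, as KVM's VCPU_REGS_* constants are, and it mirrors the REG() macros that tdx_vcpu_enter_exit() uses in the KVM tree. The names REGS_*, tdx_args_mirror, regs_to_args() and args_to_regs() are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Indices in x86 GPR encoding order, as assumed for KVM's VCPU_REGS_* */
enum {
	REGS_RAX, REGS_RCX, REGS_RDX, REGS_RBX,
	REGS_RSP, REGS_RBP, REGS_RSI, REGS_RDI,
	REGS_R8,  REGS_R9,  REGS_R10, REGS_R11,
	REGS_R12, REGS_R13, REGS_R14, REGS_R15,
	NR_REGS
};

/* Assumed mirror of struct tdx_module_args: one u64 per GPR */
struct tdx_args_mirror {
	uint64_t rcx, rdx, r8, r9, r10, r11, r12, r13, r14, r15, rbx, rdi, rsi;
};

/* Before entry, RCX carries the TDVPR and every other GPR comes from regs[] */
static void regs_to_args(const uint64_t regs[NR_REGS], uint64_t tdvpr,
			 struct tdx_args_mirror *a)
{
#define COPY_IN(reg, REG) (a->reg = regs[REGS_##REG])
	a->rcx = tdvpr;
	COPY_IN(rdx, RDX);
	COPY_IN(r8, R8);
	COPY_IN(r9, R9);
	COPY_IN(r10, R10);
	COPY_IN(r11, R11);
	COPY_IN(r12, R12);
	COPY_IN(r13, R13);
	COPY_IN(r14, R14);
	COPY_IN(r15, R15);
	COPY_IN(rbx, RBX);
	COPY_IN(rdi, RDI);
	COPY_IN(rsi, RSI);
#undef COPY_IN
}

/* After exit, all 13 GPRs, including RCX, go back to regs[] for exit handling */
static void args_to_regs(const struct tdx_args_mirror *a, uint64_t regs[NR_REGS])
{
#define COPY_OUT(reg, REG) (regs[REGS_##REG] = a->reg)
	COPY_OUT(rcx, RCX);
	COPY_OUT(rdx, RDX);
	COPY_OUT(r8, R8);
	COPY_OUT(r9, R9);
	COPY_OUT(r10, R10);
	COPY_OUT(r11, R11);
	COPY_OUT(r12, R12);
	COPY_OUT(r13, R13);
	COPY_OUT(r14, R14);
	COPY_OUT(r15, R15);
	COPY_OUT(rbx, RBX);
	COPY_OUT(rdi, RDI);
	COPY_OUT(rsi, RSI);
#undef COPY_OUT
}
```

RAX, RSP and RBP are never copied, which matches the register list above.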

Signed-off-by: Kai Huang <kai.huang@intel.com>
Signed-off-by: Adrian Hunter <adrian.hunter@intel.com>
---
 arch/x86/include/asm/tdx.h  | 1 +
 arch/x86/virt/vmx/tdx/tdx.c | 8 ++++++++
 arch/x86/virt/vmx/tdx/tdx.h | 1 +
 3 files changed, 10 insertions(+)

diff --git a/arch/x86/include/asm/tdx.h b/arch/x86/include/asm/tdx.h
index fdc81799171e..77477b905dca 100644
--- a/arch/x86/include/asm/tdx.h
+++ b/arch/x86/include/asm/tdx.h
@@ -123,6 +123,7 @@ int tdx_guest_keyid_alloc(void);
 void tdx_guest_keyid_free(unsigned int keyid);
 
 /* SEAMCALL wrappers for creating/destroying/running TDX guests */
+u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args);
 u64 tdh_mng_addcx(u64 tdr, u64 tdcs);
 u64 tdh_mem_page_add(u64 tdr, u64 gpa, u64 hpa, u64 source, u64 *rcx, u64 *rdx);
 u64 tdh_mem_sept_add(u64 tdr, u64 gpa, u64 level, u64 hpa, u64 *rcx, u64 *rdx);
diff --git a/arch/x86/virt/vmx/tdx/tdx.c b/arch/x86/virt/vmx/tdx/tdx.c
index 04cb2f1d6deb..2a8997eb1ef1 100644
--- a/arch/x86/virt/vmx/tdx/tdx.c
+++ b/arch/x86/virt/vmx/tdx/tdx.c
@@ -1600,6 +1600,14 @@ static inline u64 tdx_seamcall_sept(u64 op, struct tdx_module_args *in)
 	return ret;
 }
 
+u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
+{
+	args->rcx = tdvpr;
+
+	return __seamcall_saved_ret(TDH_VP_ENTER, args);
+}
+EXPORT_SYMBOL_GPL(tdh_vp_enter);
+
 u64 tdh_mng_addcx(u64 tdr, u64 tdcs)
 {
 	struct tdx_module_args args = {
diff --git a/arch/x86/virt/vmx/tdx/tdx.h b/arch/x86/virt/vmx/tdx/tdx.h
index 4919d00025c9..58d5754dcb4d 100644
--- a/arch/x86/virt/vmx/tdx/tdx.h
+++ b/arch/x86/virt/vmx/tdx/tdx.h
@@ -17,6 +17,7 @@
 /*
  * TDX module SEAMCALL leaf functions
  */
+#define TDH_VP_ENTER			0
 #define TDH_MNG_ADDCX			1
 #define TDH_MEM_PAGE_ADD		2
 #define TDH_MEM_SEPT_ADD		3
-- 
2.43.0
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Dave Hansen 1 year, 2 months ago
On 11/21/24 12:14, Adrian Hunter wrote:
> +u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
> +{
> +	args->rcx = tdvpr;
> +
> +	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> +}
> +EXPORT_SYMBOL_GPL(tdh_vp_enter);

I made a similar comment on another series, but it stands here too: the
typing of these wrappers really needs a closer look. Passing u64s around
everywhere means zero type safety.

Type safety is the reason that we have types like pte_t and pgprot_t in
mm code even though they're really just longs (most of the time).

I'd suggest keeping the tdx_td_page type as long as possible, probably
until (for example) the ->rcx assignment, like this:

	args->rcx = td_page.pa;
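The following is a minimal sketch of the suggestion. It assumes that tdx_td_page wraps only a physical address. The real type comes from another series and may differ. vp_enter_setup() and tdx_args_mirror are hypothetical. The address is unwrapped only where it is placed in a register, which is the same struct-wrapping approach used by pte_t.

```c
#include <assert.h>
#include <stdint.h>

/* Assumed shape of a typed TD page handle: only the physical address */
struct tdx_td_page {
	uint64_t pa;
};

/* Assumed mirror of struct tdx_module_args */
struct tdx_args_mirror {
	uint64_t rcx, rdx, r8, r9, r10, r11, r12, r13, r14, r15, rbx, rdi, rsi;
};

/*
 * The typed handle is converted to a raw u64 only at the register
 * assignment, so a caller cannot pass an unrelated u64 by mistake.
 */
static void vp_enter_setup(struct tdx_td_page tdvpr, struct tdx_args_mirror *args)
{
	args->rcx = tdvpr.pa;
}
```

A call such as vp_enter_setup(0x1000, &args), which passes a bare integer, is rejected at compile time with an incompatible-argument-type error. That is the safety the u64-only prototypes lack.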
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Edgecombe, Rick P 1 year, 2 months ago
On Fri, 2024-11-22 at 08:26 -0800, Dave Hansen wrote:
> On 11/21/24 12:14, Adrian Hunter wrote:
> > +u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
> > +{
> > +	args->rcx = tdvpr;
> > +
> > +	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> > +}
> > +EXPORT_SYMBOL_GPL(tdh_vp_enter);
> 
> I made a similar comment on another series, but it stands here too: the
> typing of these wrappers really needs a closer look. Passing u64s around
> everywhere means zero type safety.
> 
> Type safety is the reason that we have types like pte_t and pgprot_t in
> mm code even though they're really just longs (most of the time).
> 
> I'd suggest keeping the tdx_td_page type as long as possible, probably
> until (for example) the ->rcx assignment, like this:
> 
> 	args->rcx = td_page.pa;

Any thoughts on the approach here to the type questions?

https://lore.kernel.org/kvm/20241115202028.1585487-1-rick.p.edgecombe@intel.com/


Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 2 months ago
On 22/11/24 19:29, Edgecombe, Rick P wrote:
> On Fri, 2024-11-22 at 08:26 -0800, Dave Hansen wrote:
>> On 11/21/24 12:14, Adrian Hunter wrote:
>>> +u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
>>> +{
>>> +   args->rcx = tdvpr;
>>> +
>>> +   return __seamcall_saved_ret(TDH_VP_ENTER, args);
>>> +}
>>> +EXPORT_SYMBOL_GPL(tdh_vp_enter);
>>
>> I made a similar comment on another series, but it stands here too: the
>> typing of these wrappers really needs a closer look. Passing u64s around
>> everywhere means zero type safety.
>>
>> Type safety is the reason that we have types like pte_t and pgprot_t in
>> mm code even though they're really just longs (most of the time).
>>
>> I'd suggest keeping the tdx_td_page type as long as possible, probably
>> until (for example) the ->rcx assignment, like this:
>>
>>       args->rcx = td_page.pa;
> 
> Any thoughts on the approach here to the type questions?
> 
> https://lore.kernel.org/kvm/20241115202028.1585487-1-rick.p.edgecombe@intel.com/

For tdh_vp_enter() we will just use the same approach for
tdvpr, whatever that ends up being.
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 2 months ago
On 21/11/24 22:14, Adrian Hunter wrote:
> [...]
> +u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
> +{
> +	args->rcx = tdvpr;
> +
> +	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> +}
> +EXPORT_SYMBOL_GPL(tdh_vp_enter);

One alternative could be to create a union to hold the arguments:

u64 tdh_vp_enter(u64 tdvpr, union tdh_vp_enter_args *vp_enter_args)
{
	struct tdx_module_args *args = (struct tdx_module_args *)vp_enter_args;

	args->rcx = tdvpr;

	return __seamcall_saved_ret(TDH_VP_ENTER, args);
}
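The following self-contained userspace sketch shows the same overlay idea. It assumes the register order of struct tdx_module_args and models only the MMIO part of Output Format #5. It also differs from the diff below in one respect, which is our own assumption: it includes the raw register struct as a union member. This removes the need for a pointer cast and ties the union's size to the register block. All names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Assumed mirror of struct tdx_module_args */
struct tdx_args_mirror {
	uint64_t rcx, rdx, r8, r9, r10, r11, r12, r13, r14, r15, rbx, rdi, rsi;
};

/* Cut-down Output Format #5 view: only the MMIO sub-layout is modelled */
struct vp_enter_tdcall_sketch {
	uint64_t reg_mask;		/* RCX */
	uint64_t data[3];		/* RDX, R8, R9 */
	uint64_t fn;			/* R10 */
	uint64_t subfn;			/* R11 */
	struct {
		uint64_t size;		/* R12 */
		uint64_t direction;	/* R13 */
		uint64_t mmio_addr;	/* R14 */
		uint64_t value;		/* R15 */
	} mmio;
};

/* The raw member lets the SEAMCALL path use &u.raw without a cast */
union vp_enter_args_sketch {
	struct tdx_args_mirror raw;
	struct vp_enter_tdcall_sketch tdcall;
};

static_assert(sizeof(union vp_enter_args_sketch) == sizeof(struct tdx_args_mirror),
	      "named views must not grow the register block");
```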

The diff below shows what that would look like for KVM TDX, based on top
of:

	https://github.com/intel/tdx/tree/tdx_kvm_dev-2024-11-20

Define 'union tdh_vp_enter_args' to hold tdh_vp_enter() arguments
instead of using vcpu->arch.regs[].  For example, in tdexit_exit_qual()

	kvm_rcx_read(vcpu)

becomes:

	to_tdx(vcpu)->vp_enter_args.out.exit_qual

which has the advantage of giving each argument a descriptive name.
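The Output Format #2/#3 bitfields can be tried out in userspace. The sketch below follows the field layout of struct tdh_vp_enter_out and puts an accessor where tdexit_exit_qual() would be. Bitfield allocation order is implementation-defined. The sketch relies on GCC and Clang on x86-64 placing the first field at the least significant bits. The type and function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* RCX..R9 as described for Output Formats #2/#3 (layout per the diff) */
struct vp_enter_out_sketch {
	uint64_t exit_qual	: 32,	/* RCX[31:0] */
		 vm_idx		:  2,	/* RCX[33:32] */
		 reserved_0	: 30;
	uint64_t ext_exit_qual;		/* RDX */
	uint64_t gpa;			/* R8 */
	uint64_t interrupt_info	: 32,	/* R9[31:0] */
		 reserved_1	: 32;
};

/* Each group of bitfields must pack into exactly one register */
static_assert(sizeof(struct vp_enter_out_sketch) == 4 * sizeof(uint64_t),
	      "bitfields pack into RCX and R9");

union vp_enter_out_view {
	uint64_t raw[4];		/* RCX, RDX, R8, R9 as returned */
	struct vp_enter_out_sketch out;
};

/* Named-field equivalent of reading the low 32 bits of RCX */
static uint64_t exit_qual(const union vp_enter_out_view *v)
{
	return v->out.exit_qual;
}
```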

---
 arch/x86/include/asm/tdx.h  | 163 +++++++++++++++++++++++++++-
 arch/x86/kvm/vmx/tdx.c      | 205 +++++++++++++++---------------------
 arch/x86/kvm/vmx/tdx.h      |   1 +
 arch/x86/virt/vmx/tdx/tdx.c |   4 +-
 4 files changed, 249 insertions(+), 124 deletions(-)

diff --git a/arch/x86/include/asm/tdx.h b/arch/x86/include/asm/tdx.h
index 01409a59224d..3568e6b36b77 100644
--- a/arch/x86/include/asm/tdx.h
+++ b/arch/x86/include/asm/tdx.h
@@ -123,8 +123,169 @@ const struct tdx_sys_info *tdx_get_sysinfo(void);
 int tdx_guest_keyid_alloc(void);
 void tdx_guest_keyid_free(unsigned int keyid);
 
+/* TDH.VP.ENTER Input Format #2 : Following a previous TDCALL(TDG.VP.VMCALL) */
+struct tdh_vp_enter_in {
+	u64	__vcpu_handle_and_flags; /* Don't use. tdh_vp_enter() will take care of it */
+	u64	unused[3];
+	u64	ret_code;
+	union {
+		u64 gettdvmcallinfo[4];
+		struct {
+			u64	failed_gpa;
+		} mapgpa;
+		struct {
+			u64	unused;
+			u64	eax;
+			u64	ebx;
+			u64	ecx;
+			u64	edx;
+		} cpuid;
+		/* Value read for IO, MMIO or RDMSR */
+		struct {
+			u64	value;
+		} read;
+	};
+};
+
+/*
+ * TDH.VP.ENTER Output Formats #2 and #3 combined:
+ *	#2 : Async TD exits with a VMX Architectural Exit Reason
+ *	#3 : Async TD exits with a non-VMX TD Exit Status
+ */
+struct tdh_vp_enter_out {
+	u64	exit_qual	: 32,	/* #2 only */
+		vm_idx		:  2,	/* #2 and #3 */
+		reserved_0	: 30;
+	u64	ext_exit_qual;		/* #2 only */
+	u64	gpa;			/* #2 only */
+	u64	interrupt_info	: 32,	/* #2 only */
+		reserved_1	: 32;
+	u64	unused[9];
+};
+
+/*
+ * KVM hypercall: refer to struct tdh_vp_enter_tdcall - fn is the non-zero
+ * hypercall number (nr), subfn is the first parameter (p1), and p2 to p4
+ * below are the remaining parameters.
+ */
+struct tdh_vp_enter_vmcall {
+	u64	p2;
+	u64	p3;
+	u64	p4;
+};
+
+/* TDVMCALL_GET_TD_VM_CALL_INFO */
+struct tdh_vp_enter_gettdvmcallinfo {
+	u64	leaf;
+};
+
+/* TDVMCALL_MAP_GPA */
+struct tdh_vp_enter_mapgpa {
+	u64	gpa;
+	u64	size;
+};
+
+/* TDVMCALL_GET_QUOTE */
+struct tdh_vp_enter_getquote {
+	u64	shared_gpa;
+	u64	size;
+};
+
+#define TDX_ERR_DATA_PART_1 5
+
+/* TDVMCALL_REPORT_FATAL_ERROR */
+struct tdh_vp_enter_reportfatalerror {
+	union {
+		u64	err_codes;
+		struct {
+			u64	err_code	: 32,
+				ext_err_code	: 31,
+				gpa_valid	:  1;
+		};
+	};
+	u64	err_data_gpa;
+	u64	err_data[TDX_ERR_DATA_PART_1];
+};
+
+/* EXIT_REASON_CPUID */
+struct tdh_vp_enter_cpuid {
+	u64	eax;
+	u64	ecx;
+};
+
+/* EXIT_REASON_EPT_VIOLATION */
+struct tdh_vp_enter_mmio {
+	u64	size;
+	u64	direction;
+	u64	mmio_addr;
+	u64	value;
+};
+
+/* EXIT_REASON_HLT */
+struct tdh_vp_enter_hlt {
+	u64	intr_blocked_flag;
+};
+
+/* EXIT_REASON_IO_INSTRUCTION */
+struct tdh_vp_enter_io {
+	u64	size;
+	u64	direction;
+	u64	port;
+	u64	value;
+};
+
+/* EXIT_REASON_MSR_READ */
+struct tdh_vp_enter_rd {
+	u64	msr;
+};
+
+/* EXIT_REASON_MSR_WRITE */
+struct tdh_vp_enter_wr {
+	u64	msr;
+	u64	value;
+};
+
+#define TDX_ERR_DATA_PART_2 3
+
+/* TDH.VP.ENTER  Output Format #5 : On TDCALL(TDG.VP.VMCALL) */
+struct tdh_vp_enter_tdcall {
+	u64	reg_mask	: 32,
+		vm_idx		:  2,
+		reserved_0	: 30;
+	u64	data[TDX_ERR_DATA_PART_2];
+	u64	fn;	/* Non-zero for hypercalls, zero otherwise */
+	u64	subfn;
+	union {
+		struct tdh_vp_enter_vmcall 		vmcall;
+		struct tdh_vp_enter_gettdvmcallinfo	gettdvmcallinfo;
+		struct tdh_vp_enter_mapgpa		mapgpa;
+		struct tdh_vp_enter_getquote		getquote;
+		struct tdh_vp_enter_reportfatalerror	reportfatalerror;
+		struct tdh_vp_enter_cpuid		cpuid;
+		struct tdh_vp_enter_mmio		mmio;
+		struct tdh_vp_enter_hlt			hlt;
+		struct tdh_vp_enter_io			io;
+		struct tdh_vp_enter_rd			rd;
+		struct tdh_vp_enter_wr			wr;
+	};
+};
+
+/* Must be kept exactly in sync with struct tdx_module_args */
+union tdh_vp_enter_args {
+	/* Input Format #2 : Following a previous TDCALL(TDG.VP.VMCALL) */
+	struct tdh_vp_enter_in in;
+	/*
+	 * Output Formats #2 and #3 combined:
+	 *	#2 : Async TD exits with a VMX Architectural Exit Reason
+	 *	#3 : Async TD exits with a non-VMX TD Exit Status
+	 */
+	struct tdh_vp_enter_out out;
+	/* Output Format #5 : On TDCALL(TDG.VP.VMCALL) */
+	struct tdh_vp_enter_tdcall tdcall;
+};
+
 /* SEAMCALL wrappers for creating/destroying/running TDX guests */
-u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args);
+u64 tdh_vp_enter(u64 tdvpr, union tdh_vp_enter_args *tdh_vp_enter_args);
 u64 tdh_mng_addcx(u64 tdr, u64 tdcs);
 u64 tdh_mem_page_add(u64 tdr, u64 gpa, u64 hpa, u64 source, u64 *rcx, u64 *rdx);
 u64 tdh_mem_sept_add(u64 tdr, u64 gpa, u64 level, u64 hpa, u64 *rcx, u64 *rdx);
diff --git a/arch/x86/kvm/vmx/tdx.c b/arch/x86/kvm/vmx/tdx.c
index f5fc1a782b5b..56af7b8c71ab 100644
--- a/arch/x86/kvm/vmx/tdx.c
+++ b/arch/x86/kvm/vmx/tdx.c
@@ -211,57 +211,41 @@ static bool tdx_check_exit_reason(struct kvm_vcpu *vcpu, u16 reason)
 
 static __always_inline unsigned long tdexit_exit_qual(struct kvm_vcpu *vcpu)
 {
-	return kvm_rcx_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.out.exit_qual;
 }
 
 static __always_inline unsigned long tdexit_ext_exit_qual(struct kvm_vcpu *vcpu)
 {
-	return kvm_rdx_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.out.ext_exit_qual;
 }
 
 static __always_inline unsigned long tdexit_gpa(struct kvm_vcpu *vcpu)
 {
-	return kvm_r8_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.out.gpa;
 }
 
 static __always_inline unsigned long tdexit_intr_info(struct kvm_vcpu *vcpu)
 {
-	return kvm_r9_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.out.interrupt_info;
 }
 
-#define BUILD_TDVMCALL_ACCESSORS(param, gpr)				\
-static __always_inline							\
-unsigned long tdvmcall_##param##_read(struct kvm_vcpu *vcpu)		\
-{									\
-	return kvm_##gpr##_read(vcpu);					\
-}									\
-static __always_inline void tdvmcall_##param##_write(struct kvm_vcpu *vcpu, \
-						     unsigned long val)  \
-{									\
-	kvm_##gpr##_write(vcpu, val);					\
-}
-BUILD_TDVMCALL_ACCESSORS(a0, r12);
-BUILD_TDVMCALL_ACCESSORS(a1, r13);
-BUILD_TDVMCALL_ACCESSORS(a2, r14);
-BUILD_TDVMCALL_ACCESSORS(a3, r15);
-
 static __always_inline unsigned long tdvmcall_exit_type(struct kvm_vcpu *vcpu)
 {
-	return kvm_r10_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.tdcall.fn;
 }
 static __always_inline unsigned long tdvmcall_leaf(struct kvm_vcpu *vcpu)
 {
-	return kvm_r11_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_args.tdcall.subfn;
 }
 static __always_inline void tdvmcall_set_return_code(struct kvm_vcpu *vcpu,
 						     long val)
 {
-	kvm_r10_write(vcpu, val);
+	to_tdx(vcpu)->vp_enter_args.in.ret_code = val;
 }
 static __always_inline void tdvmcall_set_return_val(struct kvm_vcpu *vcpu,
 						    unsigned long val)
 {
-	kvm_r11_write(vcpu, val);
+	to_tdx(vcpu)->vp_enter_args.in.read.value = val;
 }
 
 static inline void tdx_hkid_free(struct kvm_tdx *kvm_tdx)
@@ -745,7 +729,7 @@ bool tdx_interrupt_allowed(struct kvm_vcpu *vcpu)
 	    tdvmcall_exit_type(vcpu) || tdvmcall_leaf(vcpu) != EXIT_REASON_HLT)
 	    return true;
 
-	return !tdvmcall_a0_read(vcpu);
+	return !to_tdx(vcpu)->vp_enter_args.tdcall.hlt.intr_blocked_flag;
 }
 
 bool tdx_protected_apic_has_interrupt(struct kvm_vcpu *vcpu)
@@ -899,51 +883,10 @@ static void tdx_restore_host_xsave_state(struct kvm_vcpu *vcpu)
 static noinstr void tdx_vcpu_enter_exit(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_tdx *tdx = to_tdx(vcpu);
-	struct tdx_module_args args;
 
 	guest_state_enter_irqoff();
 
-	/*
-	 * TODO: optimization:
-	 * - Eliminate copy between args and vcpu->arch.regs.
-	 * - copyin/copyout registers only if (tdx->tdvmvall.regs_mask != 0)
-	 *   which means TDG.VP.VMCALL.
-	 */
-	args = (struct tdx_module_args) {
-		.rcx = tdx->tdvpr_pa,
-#define REG(reg, REG)	.reg = vcpu->arch.regs[VCPU_REGS_ ## REG]
-		REG(rdx, RDX),
-		REG(r8,  R8),
-		REG(r9,  R9),
-		REG(r10, R10),
-		REG(r11, R11),
-		REG(r12, R12),
-		REG(r13, R13),
-		REG(r14, R14),
-		REG(r15, R15),
-		REG(rbx, RBX),
-		REG(rdi, RDI),
-		REG(rsi, RSI),
-#undef REG
-	};
-
-	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &args);
-
-#define REG(reg, REG)	vcpu->arch.regs[VCPU_REGS_ ## REG] = args.reg
-	REG(rcx, RCX);
-	REG(rdx, RDX);
-	REG(r8,  R8);
-	REG(r9,  R9);
-	REG(r10, R10);
-	REG(r11, R11);
-	REG(r12, R12);
-	REG(r13, R13);
-	REG(r14, R14);
-	REG(r15, R15);
-	REG(rbx, RBX);
-	REG(rdi, RDI);
-	REG(rsi, RSI);
-#undef REG
+	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &tdx->vp_enter_args);
 
 	if (tdx_check_exit_reason(vcpu, EXIT_REASON_EXCEPTION_NMI) &&
 	    is_nmi(tdexit_intr_info(vcpu)))
@@ -1083,8 +1026,15 @@ static int complete_hypercall_exit(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	int r;
 
+	kvm_r10_write(vcpu, tdx->vp_enter_args.tdcall.fn);
+	kvm_r11_write(vcpu, tdx->vp_enter_args.tdcall.subfn);
+	kvm_r12_write(vcpu, tdx->vp_enter_args.tdcall.vmcall.p2);
+	kvm_r13_write(vcpu, tdx->vp_enter_args.tdcall.vmcall.p3);
+	kvm_r14_write(vcpu, tdx->vp_enter_args.tdcall.vmcall.p4);
+
 	/*
 	 * ABI for KVM tdvmcall argument:
 	 * In Guest-Hypervisor Communication Interface(GHCI) specification,
@@ -1092,13 +1042,12 @@ static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
 	 * vendor-specific.  KVM uses this for KVM hypercall.  NOTE: KVM
 	 * hypercall number starts from one.  Zero isn't used for KVM hypercall
 	 * number.
-	 *
-	 * R10: KVM hypercall number
-	 * arguments: R11, R12, R13, R14.
 	 */
 	r = __kvm_emulate_hypercall(vcpu, r10, r11, r12, r13, r14, true, 0,
 				    R10, complete_hypercall_exit);
 
+	tdvmcall_set_return_code(vcpu, kvm_r10_read(vcpu));
+
 	return r > 0;
 }
 
@@ -1116,7 +1065,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
 
 	if(vcpu->run->hypercall.ret) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
-		kvm_r11_write(vcpu, tdx->map_gpa_next);
+		tdx->vp_enter_args.in.mapgpa.failed_gpa = tdx->map_gpa_next;
 		return 1;
 	}
 
@@ -1137,7 +1086,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
 	if (pi_has_pending_interrupt(vcpu) ||
 	    kvm_test_request(KVM_REQ_NMI, vcpu) || vcpu->arch.nmi_pending) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_RETRY);
-		kvm_r11_write(vcpu, tdx->map_gpa_next);
+		tdx->vp_enter_args.in.mapgpa.failed_gpa = tdx->map_gpa_next;
 		return 1;
 	}
 
@@ -1169,8 +1118,8 @@ static void __tdx_map_gpa(struct vcpu_tdx * tdx)
 static int tdx_map_gpa(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_tdx * tdx = to_tdx(vcpu);
-	u64 gpa = tdvmcall_a0_read(vcpu);
-	u64 size = tdvmcall_a1_read(vcpu);
+	u64 gpa  = tdx->vp_enter_args.tdcall.mapgpa.gpa;
+	u64 size = tdx->vp_enter_args.tdcall.mapgpa.size;
 	u64 ret;
 
 	/*
@@ -1206,14 +1155,19 @@ static int tdx_map_gpa(struct kvm_vcpu *vcpu)
 
 error:
 	tdvmcall_set_return_code(vcpu, ret);
-	kvm_r11_write(vcpu, gpa);
+	tdx->vp_enter_args.in.mapgpa.failed_gpa = gpa;
 	return 1;
 }
 
 static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 {
-	u64 reg_mask = kvm_rcx_read(vcpu);
-	u64* opt_regs;
+	union tdh_vp_enter_args *args = &to_tdx(vcpu)->vp_enter_args;
+	__u64 *data = &vcpu->run->system_event.data[0];
+	u64 reg_mask = args->tdcall.reg_mask;
+	const int mask[] = {14, 15, 3, 7, 6};	/* R14, R15, RBX, RDI, RSI in GHCI order */
+	int cnt = 0;
+
+	BUILD_BUG_ON(ARRAY_SIZE(mask) != TDX_ERR_DATA_PART_1);
 
 	/*
 	 * Skip sanity checks and let userspace decide what to do if sanity
@@ -1221,32 +1175,35 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 	 */
 	vcpu->run->exit_reason = KVM_EXIT_SYSTEM_EVENT;
 	vcpu->run->system_event.type = KVM_SYSTEM_EVENT_TDX_FATAL;
-	vcpu->run->system_event.ndata = 10;
 	/* Error codes. */
-	vcpu->run->system_event.data[0] = tdvmcall_a0_read(vcpu);
+	data[cnt++] = args->tdcall.reportfatalerror.err_codes;
 	/* GPA of additional information page. */
-	vcpu->run->system_event.data[1] = tdvmcall_a1_read(vcpu);
+	data[cnt++] = args->tdcall.reportfatalerror.err_data_gpa;
+
 	/* Information passed via registers (up to 64 bytes). */
-	opt_regs = &vcpu->run->system_event.data[2];
+	for (int i = 0; i < TDX_ERR_DATA_PART_1; i++) {
+		if (reg_mask & BIT_ULL(mask[i]))
+			data[cnt++] = args->tdcall.reportfatalerror.err_data[i];
+		else
+			data[cnt++] = 0;
+	}
 
-#define COPY_REG(REG, MASK)						\
-	do {								\
-		if (reg_mask & MASK)					\
-			*opt_regs = kvm_ ## REG ## _read(vcpu);		\
-		else							\
-			*opt_regs = 0;					\
-		opt_regs++;						\
-	} while (0)
+	if (reg_mask & BIT_ULL(8))
+		data[cnt++] = args->tdcall.data[1];
+	else
+		data[cnt++] = 0;
 
-	/* The order is defined in GHCI. */
-	COPY_REG(r14, BIT_ULL(14));
-	COPY_REG(r15, BIT_ULL(15));
-	COPY_REG(rbx, BIT_ULL(3));
-	COPY_REG(rdi, BIT_ULL(7));
-	COPY_REG(rsi, BIT_ULL(6));
-	COPY_REG(r8, BIT_ULL(8));
-	COPY_REG(r9, BIT_ULL(9));
-	COPY_REG(rdx, BIT_ULL(2));
+	if (reg_mask & BIT_ULL(9))
+		data[cnt++] = args->tdcall.data[2];
+	else
+		data[cnt++] = 0;
+
+	if (reg_mask & BIT_ULL(2))
+		data[cnt++] = args->tdcall.data[0];
+	else
+		data[cnt++] = 0;
+
+	vcpu->run->system_event.ndata = cnt;
 
 	/*
 	 * Set the status code according to GHCI spec, although the vCPU may
@@ -1260,18 +1217,18 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_cpuid(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	u32 eax, ebx, ecx, edx;
 
-	/* EAX and ECX for cpuid is stored in R12 and R13. */
-	eax = tdvmcall_a0_read(vcpu);
-	ecx = tdvmcall_a1_read(vcpu);
+	eax = tdx->vp_enter_args.tdcall.cpuid.eax;
+	ecx = tdx->vp_enter_args.tdcall.cpuid.ecx;
 
 	kvm_cpuid(vcpu, &eax, &ebx, &ecx, &edx, false);
 
-	tdvmcall_a0_write(vcpu, eax);
-	tdvmcall_a1_write(vcpu, ebx);
-	tdvmcall_a2_write(vcpu, ecx);
-	tdvmcall_a3_write(vcpu, edx);
+	tdx->vp_enter_args.in.cpuid.eax = eax;
+	tdx->vp_enter_args.in.cpuid.ebx = ebx;
+	tdx->vp_enter_args.in.cpuid.ecx = ecx;
+	tdx->vp_enter_args.in.cpuid.edx = edx;
 
 	tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
 
@@ -1312,6 +1269,7 @@ static int tdx_complete_pio_in(struct kvm_vcpu *vcpu)
 static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 {
 	struct x86_emulate_ctxt *ctxt = vcpu->arch.emulate_ctxt;
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	unsigned long val = 0;
 	unsigned int port;
 	int size, ret;
@@ -1319,9 +1277,9 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 
 	++vcpu->stat.io_exits;
 
-	size = tdvmcall_a0_read(vcpu);
-	write = tdvmcall_a1_read(vcpu);
-	port = tdvmcall_a2_read(vcpu);
+	size  = tdx->vp_enter_args.tdcall.io.size;
+	write = tdx->vp_enter_args.tdcall.io.direction;
+	port  = tdx->vp_enter_args.tdcall.io.port;
 
 	if (size != 1 && size != 2 && size != 4) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
@@ -1329,7 +1287,7 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 	}
 
 	if (write) {
-		val = tdvmcall_a3_read(vcpu);
+		val = tdx->vp_enter_args.tdcall.io.value;
 		ret = ctxt->ops->pio_out_emulated(ctxt, size, port, &val, 1);
 	} else {
 		ret = ctxt->ops->pio_in_emulated(ctxt, size, port, &val, 1);
@@ -1397,14 +1355,15 @@ static inline int tdx_mmio_read(struct kvm_vcpu *vcpu, gpa_t gpa, int size)
 
 static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	int size, write, r;
 	unsigned long val;
 	gpa_t gpa;
 
-	size = tdvmcall_a0_read(vcpu);
-	write = tdvmcall_a1_read(vcpu);
-	gpa = tdvmcall_a2_read(vcpu);
-	val = write ? tdvmcall_a3_read(vcpu) : 0;
+	size  = tdx->vp_enter_args.tdcall.mmio.size;
+	write = tdx->vp_enter_args.tdcall.mmio.direction;
+	gpa   = tdx->vp_enter_args.tdcall.mmio.mmio_addr;
+	val = write ? tdx->vp_enter_args.tdcall.mmio.value : 0;
 
 	if (size != 1 && size != 2 && size != 4 && size != 8)
 		goto error;
@@ -1456,7 +1415,7 @@ static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
 {
-	u32 index = tdvmcall_a0_read(vcpu);
+	u32 index = to_tdx(vcpu)->vp_enter_args.tdcall.rd.msr;
 	u64 data;
 
 	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_READ) ||
@@ -1474,8 +1433,8 @@ static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
 {
-	u32 index = tdvmcall_a0_read(vcpu);
-	u64 data = tdvmcall_a1_read(vcpu);
+	u32 index = to_tdx(vcpu)->vp_enter_args.tdcall.wr.msr;
+	u64 data  = to_tdx(vcpu)->vp_enter_args.tdcall.wr.value;
 
 	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_WRITE) ||
 	    kvm_set_msr(vcpu, index, data)) {
@@ -1491,14 +1450,16 @@ static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
 
 static int tdx_get_td_vm_call_info(struct kvm_vcpu *vcpu)
 {
-	if (tdvmcall_a0_read(vcpu))
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
+
+	if (tdx->vp_enter_args.tdcall.gettdvmcallinfo.leaf) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
-	else {
+	} else {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
-		kvm_r11_write(vcpu, 0);
-		tdvmcall_a0_write(vcpu, 0);
-		tdvmcall_a1_write(vcpu, 0);
-		tdvmcall_a2_write(vcpu, 0);
+		tdx->vp_enter_args.in.gettdvmcallinfo[0] = 0;
+		tdx->vp_enter_args.in.gettdvmcallinfo[1] = 0;
+		tdx->vp_enter_args.in.gettdvmcallinfo[2] = 0;
+		tdx->vp_enter_args.in.gettdvmcallinfo[3] = 0;
 	}
 	return 1;
 }
diff --git a/arch/x86/kvm/vmx/tdx.h b/arch/x86/kvm/vmx/tdx.h
index c9daf71d358a..a0d33b048b7e 100644
--- a/arch/x86/kvm/vmx/tdx.h
+++ b/arch/x86/kvm/vmx/tdx.h
@@ -71,6 +71,7 @@ struct vcpu_tdx {
 	struct list_head cpu_list;
 
 	u64 vp_enter_ret;
+	union tdh_vp_enter_args vp_enter_args;
 
 	enum vcpu_tdx_state state;
 
diff --git a/arch/x86/virt/vmx/tdx/tdx.c b/arch/x86/virt/vmx/tdx/tdx.c
index 16e0b598c4ec..d5c06c5eeaec 100644
--- a/arch/x86/virt/vmx/tdx/tdx.c
+++ b/arch/x86/virt/vmx/tdx/tdx.c
@@ -1600,8 +1600,10 @@ static inline u64 tdx_seamcall_sept(u64 op, struct tdx_module_args *in)
 	return ret;
 }
 
-noinstr u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
+noinstr u64 tdh_vp_enter(u64 tdvpr, union tdh_vp_enter_args *vp_enter_args)
 {
+	struct tdx_module_args *args = (struct tdx_module_args *)vp_enter_args;
+
 	args->rcx = tdvpr;
 
 	return __seamcall_saved_ret(TDH_VP_ENTER, args);
-- 
2.43.0
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Dave Hansen 1 year, 2 months ago
On 11/22/24 03:10, Adrian Hunter wrote:
> +struct tdh_vp_enter_tdcall {
> +	u64	reg_mask	: 32,
> +		vm_idx		:  2,
> +		reserved_0	: 30;
> +	u64	data[TDX_ERR_DATA_PART_2];
> +	u64	fn;	/* Non-zero for hypercalls, zero otherwise */
> +	u64	subfn;
> +	union {
> +		struct tdh_vp_enter_vmcall 		vmcall;
> +		struct tdh_vp_enter_gettdvmcallinfo	gettdvmcallinfo;
> +		struct tdh_vp_enter_mapgpa		mapgpa;
> +		struct tdh_vp_enter_getquote		getquote;
> +		struct tdh_vp_enter_reportfatalerror	reportfatalerror;
> +		struct tdh_vp_enter_cpuid		cpuid;
> +		struct tdh_vp_enter_mmio		mmio;
> +		struct tdh_vp_enter_hlt			hlt;
> +		struct tdh_vp_enter_io			io;
> +		struct tdh_vp_enter_rd			rd;
> +		struct tdh_vp_enter_wr			wr;
> +	};
> +};

Let's say someone declares this:

struct tdh_vp_enter_mmio {
	u64	size;
	u64	mmio_addr;
	u64	direction;
	u64	value;
};

How long is that going to take you to debug?
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 2 months ago
On 22/11/24 18:33, Dave Hansen wrote:
> On 11/22/24 03:10, Adrian Hunter wrote:
>> +struct tdh_vp_enter_tdcall {
>> +	u64	reg_mask	: 32,
>> +		vm_idx		:  2,
>> +		reserved_0	: 30;
>> +	u64	data[TDX_ERR_DATA_PART_2];
>> +	u64	fn;	/* Non-zero for hypercalls, zero otherwise */
>> +	u64	subfn;
>> +	union {
>> +		struct tdh_vp_enter_vmcall 		vmcall;
>> +		struct tdh_vp_enter_gettdvmcallinfo	gettdvmcallinfo;
>> +		struct tdh_vp_enter_mapgpa		mapgpa;
>> +		struct tdh_vp_enter_getquote		getquote;
>> +		struct tdh_vp_enter_reportfatalerror	reportfatalerror;
>> +		struct tdh_vp_enter_cpuid		cpuid;
>> +		struct tdh_vp_enter_mmio		mmio;
>> +		struct tdh_vp_enter_hlt			hlt;
>> +		struct tdh_vp_enter_io			io;
>> +		struct tdh_vp_enter_rd			rd;
>> +		struct tdh_vp_enter_wr			wr;
>> +	};
>> +};
> 
> Let's say someone declares this:
> 
> struct tdh_vp_enter_mmio {
> 	u64	size;
> 	u64	mmio_addr;
> 	u64	direction;
> 	u64	value;
> };
> 
> How long is that going to take you to debug?

When adding a new hardware definition, it would be sensible
to check the hardware definition first before checking anything
else.

However, to stop existing members being accidentally moved,
could add:

#define CHECK_OFFSETS_EQ(reg, member) \
	BUILD_BUG_ON(offsetof(struct tdx_module_args, reg) != \
		     offsetof(union tdh_vp_enter_args, member))

	CHECK_OFFSETS_EQ(r12, tdcall.mmio.size);
	CHECK_OFFSETS_EQ(r13, tdcall.mmio.direction);
	CHECK_OFFSETS_EQ(r14, tdcall.mmio.mmio_addr);
	CHECK_OFFSETS_EQ(r15, tdcall.mmio.value);
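For illustration only, here is a userspace approximation of that check. struct regs_model, struct mmio_view and union args_model are hypothetical stand-ins for struct tdx_module_args and union tdh_vp_enter_args, and C11 _Static_assert plays the role of BUILD_BUG_ON(). With the field order from Dave's example (mmio_addr before direction), the r13 and r14 checks would fail to compile instead of silently misreading registers.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for struct tdx_module_args: one u64 per GPR. */
struct regs_model {
	uint64_t rcx, rdx, r8, r9, r10, r11, r12, r13, r14, r15;
};

/* Hypothetical MMIO view: the named fields must overlay r12..r15. */
struct mmio_view {
	uint64_t rcx, rdx, r8, r9, r10, r11;
	uint64_t size;		/* r12 */
	uint64_t direction;	/* r13 */
	uint64_t mmio_addr;	/* r14 */
	uint64_t value;		/* r15 */
};

/* Hypothetical stand-in for union tdh_vp_enter_args. */
union args_model {
	struct regs_model regs;
	struct mmio_view mmio;
};

/* Userspace analogue of the BUILD_BUG_ON() based check. */
#define CHECK_OFFSETS_EQ(reg, member)				\
	_Static_assert(offsetof(struct regs_model, reg) ==	\
		       offsetof(union args_model, member),	\
		       #member " does not overlay " #reg)

CHECK_OFFSETS_EQ(r12, mmio.size);
CHECK_OFFSETS_EQ(r13, mmio.direction);
CHECK_OFFSETS_EQ(r14, mmio.mmio_addr);
CHECK_OFFSETS_EQ(r15, mmio.value);
```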
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 2 months ago
On 25/11/24 15:40, Adrian Hunter wrote:
> On 22/11/24 18:33, Dave Hansen wrote:
>> On 11/22/24 03:10, Adrian Hunter wrote:
>>> +struct tdh_vp_enter_tdcall {
>>> +	u64	reg_mask	: 32,
>>> +		vm_idx		:  2,
>>> +		reserved_0	: 30;
>>> +	u64	data[TDX_ERR_DATA_PART_2];
>>> +	u64	fn;	/* Non-zero for hypercalls, zero otherwise */
>>> +	u64	subfn;
>>> +	union {
>>> +		struct tdh_vp_enter_vmcall 		vmcall;
>>> +		struct tdh_vp_enter_gettdvmcallinfo	gettdvmcallinfo;
>>> +		struct tdh_vp_enter_mapgpa		mapgpa;
>>> +		struct tdh_vp_enter_getquote		getquote;
>>> +		struct tdh_vp_enter_reportfatalerror	reportfatalerror;
>>> +		struct tdh_vp_enter_cpuid		cpuid;
>>> +		struct tdh_vp_enter_mmio		mmio;
>>> +		struct tdh_vp_enter_hlt			hlt;
>>> +		struct tdh_vp_enter_io			io;
>>> +		struct tdh_vp_enter_rd			rd;
>>> +		struct tdh_vp_enter_wr			wr;
>>> +	};
>>> +};
>>
>> Let's say someone declares this:
>>
>> struct tdh_vp_enter_mmio {
>> 	u64	size;
>> 	u64	mmio_addr;
>> 	u64	direction;
>> 	u64	value;
>> };
>>
>> How long is that going to take you to debug?
> 
> When adding a new hardware definition, it would be sensible
> to check the hardware definition first before checking anything
> else.
> 
> However, to stop existing members being accidentally moved,
> could add:
> 
> #define CHECK_OFFSETS_EQ(reg, member) \
> 	BUILD_BUG_ON(offsetof(struct tdx_module_args, reg) != \
> 		     offsetof(union tdh_vp_enter_args, member))
> 
> 	CHECK_OFFSETS_EQ(r12, tdcall.mmio.size);
> 	CHECK_OFFSETS_EQ(r13, tdcall.mmio.direction);
> 	CHECK_OFFSETS_EQ(r14, tdcall.mmio.mmio_addr);
> 	CHECK_OFFSETS_EQ(r15, tdcall.mmio.value);
> 

Note, struct tdh_vp_enter_tdcall is an output format.  The tdcall
arguments come directly from the guest with no validation by the
TDX Module.  It could be rubbish, or even malicious rubbish.  The
exit handlers validate the values before using them.

WRT the TDCALL input format (response by the host VMM), 'ret_code'
and 'failed_gpa' could use types other than 'u64', but the other
members are really 'u64'.

/* TDH.VP.ENTER Input Format #2 : Following a previous TDCALL(TDG.VP.VMCALL) */
struct tdh_vp_enter_in {
	u64	__vcpu_handle_and_flags; /* Don't use. tdh_vp_enter() will take care of it */
	u64	unused[3];
	u64	ret_code;
	union {
		u64 gettdvmcallinfo[4];
		struct {
			u64	failed_gpa;
		} mapgpa;
		struct {
			u64	unused;
			u64	eax;
			u64	ebx;
			u64	ecx;
			u64	edx;
		} cpuid;
		/* Value read for IO, MMIO or RDMSR */
		struct {
			u64	value;
		} read;
	};
};

Another alternative could be to use an opaque structure,
not visible to KVM, and then all accesses to it become helper
functions like:

struct tdx_args;

int tdx_args_get_mmio(struct tdx_args *args,
		      enum tdx_access_size *size,
		      enum tdx_access_dir *direction,
		      gpa_t *addr,
		      u64 *value);

void tdx_args_set_failed_gpa(struct tdx_args *args, gpa_t gpa);
void tdx_args_set_ret_code(struct tdx_args *args, enum tdx_ret_code ret_code);
etc

For the 'get' functions, that would tend to imply the helpers
would do some validation.
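A rough sketch of what one such 'get' helper might look like, assuming the guest's MMIO request sits in r12..r15 as in the existing handlers. The struct layout, the plain int byte count for size and the strict direction check are assumptions made for this illustration, not part of the proposal.

```c
#include <assert.h>
#include <errno.h>
#include <stdint.h>

enum tdx_access_dir { TDX_ACCESS_READ, TDX_ACCESS_WRITE };

/*
 * Hypothetical layout. Under the opaque-structure idea only a forward
 * declaration ("struct tdx_args;") would be visible to KVM; the
 * definition would stay private to the TDX host code.
 */
struct tdx_args {
	uint64_t r12, r13, r14, r15;
};

/*
 * Hypothetical getter. The values come straight from the guest, so
 * malformed requests are rejected here rather than in every caller.
 */
int tdx_args_get_mmio(const struct tdx_args *args, int *size,
		      enum tdx_access_dir *direction, uint64_t *addr,
		      uint64_t *value)
{
	uint64_t sz = args->r12;

	if (sz != 1 && sz != 2 && sz != 4 && sz != 8)
		return -EINVAL;
	/* This sketch treats any direction other than 0 or 1 as invalid. */
	if (args->r13 > 1)
		return -EINVAL;

	*size = (int)sz;
	*direction = args->r13 ? TDX_ACCESS_WRITE : TDX_ACCESS_READ;
	*addr = args->r14;
	/* The value is only meaningful for writes. */
	*value = *direction == TDX_ACCESS_WRITE ? args->r15 : 0;
	return 0;
}
```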
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 2 months ago
On 28/11/24 13:13, Adrian Hunter wrote:
> On 25/11/24 15:40, Adrian Hunter wrote:
>> On 22/11/24 18:33, Dave Hansen wrote:
>>> On 11/22/24 03:10, Adrian Hunter wrote:
>>>> +struct tdh_vp_enter_tdcall {
>>>> +	u64	reg_mask	: 32,
>>>> +		vm_idx		:  2,
>>>> +		reserved_0	: 30;
>>>> +	u64	data[TDX_ERR_DATA_PART_2];
>>>> +	u64	fn;	/* Non-zero for hypercalls, zero otherwise */
>>>> +	u64	subfn;
>>>> +	union {
>>>> +		struct tdh_vp_enter_vmcall 		vmcall;
>>>> +		struct tdh_vp_enter_gettdvmcallinfo	gettdvmcallinfo;
>>>> +		struct tdh_vp_enter_mapgpa		mapgpa;
>>>> +		struct tdh_vp_enter_getquote		getquote;
>>>> +		struct tdh_vp_enter_reportfatalerror	reportfatalerror;
>>>> +		struct tdh_vp_enter_cpuid		cpuid;
>>>> +		struct tdh_vp_enter_mmio		mmio;
>>>> +		struct tdh_vp_enter_hlt			hlt;
>>>> +		struct tdh_vp_enter_io			io;
>>>> +		struct tdh_vp_enter_rd			rd;
>>>> +		struct tdh_vp_enter_wr			wr;
>>>> +	};
>>>> +};
>>>
>>> Let's say someone declares this:
>>>
>>> struct tdh_vp_enter_mmio {
>>> 	u64	size;
>>> 	u64	mmio_addr;
>>> 	u64	direction;
>>> 	u64	value;
>>> };
>>>
>>> How long is that going to take you to debug?
>>
>> When adding a new hardware definition, it would be sensible
>> to check the hardware definition first before checking anything
>> else.
>>
>> However, to stop existing members being accidentally moved,
>> could add:
>>
>> #define CHECK_OFFSETS_EQ(reg, member) \
>> 	BUILD_BUG_ON(offsetof(struct tdx_module_args, reg) != \
>> 		     offsetof(union tdh_vp_enter_args, member))
>>
>> 	CHECK_OFFSETS_EQ(r12, tdcall.mmio.size);
>> 	CHECK_OFFSETS_EQ(r13, tdcall.mmio.direction);
>> 	CHECK_OFFSETS_EQ(r14, tdcall.mmio.mmio_addr);
>> 	CHECK_OFFSETS_EQ(r15, tdcall.mmio.value);
>>
> 
> Note, struct tdh_vp_enter_tdcall is an output format.  The tdcall
> arguments come directly from the guest with no validation by the
> TDX Module.  It could be rubbish, or even malicious rubbish.  The
> exit handlers validate the values before using them.
> 
> WRT the TDCALL input format (response by the host VMM), 'ret_code'
> and 'failed_gpa' could use types other than 'u64', but the other
> members are really 'u64'.
> 
> /* TDH.VP.ENTER Input Format #2 : Following a previous TDCALL(TDG.VP.VMCALL) */
> struct tdh_vp_enter_in {
> 	u64	__vcpu_handle_and_flags; /* Don't use. tdh_vp_enter() will take care of it */
> 	u64	unused[3];
> 	u64	ret_code;
> 	union {
> 		u64 gettdvmcallinfo[4];
> 		struct {
> 			u64	failed_gpa;
> 		} mapgpa;
> 		struct {
> 			u64	unused;
> 			u64	eax;
> 			u64	ebx;
> 			u64	ecx;
> 			u64	edx;
> 		} cpuid;
> 		/* Value read for IO, MMIO or RDMSR */
> 		struct {
> 			u64	value;
> 		} read;
> 	};
> };
> 
> Another alternative could be to use an opaque structure,
> not visible to KVM, and then all accesses to it become helper
> functions like:
> 
> struct tdx_args;
> 
> int tdx_args_get_mmio(struct tdx_args *args,
> 		      enum tdx_access_size *size,
> 		      enum tdx_access_dir *direction,
> 		      gpa_t *addr,
> 		      u64 *value);
> 
> void tdx_args_set_failed_gpa(struct tdx_args *args, gpa_t gpa);
> void tdx_args_set_ret_code(struct tdx_args *args, enum tdx_ret_code ret_code);
> etc
> 
> For the 'get' functions, that would tend to imply the helpers
> would do some validation.
> 

IIRC Dave said something like, if the wrapper doesn't add any
value, then it is just as well not to have it at all.

So that option would be to drop patch "x86/virt/tdx: Add SEAMCALL
wrapper to enter/exit TDX guest" with tdh_vp_enter() and instead
just call __seamcall_saved_ret() directly, noting that:

 - __seamcall_saved_ret() is only used for TDH.VP.ENTER
 - KVM seems likely to be the only code that would ever
   need to use TDH.VP.ENTER
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 1 month ago
The diff below shows another alternative, this time using structs
rather than a union.  The structs are easier to read than the union.
They require copying arguments, but that also allows using types
whose size differs from that of a GPR (u64).

diff --git a/arch/x86/include/asm/shared/tdx.h b/arch/x86/include/asm/shared/tdx.h
index 192ae798b214..85f87d90ac89 100644
--- a/arch/x86/include/asm/shared/tdx.h
+++ b/arch/x86/include/asm/shared/tdx.h
@@ -21,20 +21,6 @@
 /* TDCS fields. To be used by TDG.VM.WR and TDG.VM.RD module calls */
 #define TDCS_NOTIFY_ENABLES		0x9100000000000010
 
-/* TDX hypercall Leaf IDs */
-#define TDVMCALL_GET_TD_VM_CALL_INFO	0x10000
-#define TDVMCALL_MAP_GPA		0x10001
-#define TDVMCALL_GET_QUOTE		0x10002
-#define TDVMCALL_REPORT_FATAL_ERROR	0x10003
-
-/*
- * TDG.VP.VMCALL Status Codes (returned in R10)
- */
-#define TDVMCALL_STATUS_SUCCESS		0x0000000000000000ULL
-#define TDVMCALL_STATUS_RETRY		0x0000000000000001ULL
-#define TDVMCALL_STATUS_INVALID_OPERAND	0x8000000000000000ULL
-#define TDVMCALL_STATUS_ALIGN_ERROR	0x8000000000000002ULL
-
 /*
  * Bitmasks of exposed registers (with VMM).
  */
diff --git a/arch/x86/include/asm/tdx.h b/arch/x86/include/asm/tdx.h
index 01409a59224d..e4a45378a84b 100644
--- a/arch/x86/include/asm/tdx.h
+++ b/arch/x86/include/asm/tdx.h
@@ -33,6 +33,7 @@
 
 #ifndef __ASSEMBLY__
 
+#include <linux/kvm_types.h>
 #include <uapi/asm/mce.h>
 #include "tdx_global_metadata.h"
 
@@ -96,6 +97,7 @@ u64 __seamcall_saved_ret(u64 fn, struct tdx_module_args *args);
 void tdx_init(void);
 
 #include <asm/archrandom.h>
+#include <asm/vmx.h>
 
 typedef u64 (*sc_func_t)(u64 fn, struct tdx_module_args *args);
 
@@ -123,8 +125,122 @@ const struct tdx_sys_info *tdx_get_sysinfo(void);
 int tdx_guest_keyid_alloc(void);
 void tdx_guest_keyid_free(unsigned int keyid);
 
+/* TDG.VP.VMCALL Sub-function */
+enum tdvmcall_subfn {
+	TDVMCALL_NONE			= -1, /* Not a TDG.VP.VMCALL */
+	TDVMCALL_GET_TD_VM_CALL_INFO	= 0x10000,
+	TDVMCALL_MAP_GPA		= 0x10001,
+	TDVMCALL_GET_QUOTE		= 0x10002,
+	TDVMCALL_REPORT_FATAL_ERROR	= 0x10003,
+	TDVMCALL_CPUID			= EXIT_REASON_CPUID,
+	TDVMCALL_HLT			= EXIT_REASON_HLT,
+	TDVMCALL_IO			= EXIT_REASON_IO_INSTRUCTION,
+	TDVMCALL_RDMSR			= EXIT_REASON_MSR_READ,
+	TDVMCALL_WRMSR			= EXIT_REASON_MSR_WRITE,
+	TDVMCALL_MMIO			= EXIT_REASON_EPT_VIOLATION,
+};
+
+enum tdx_io_direction {
+	TDX_READ,
+	TDX_WRITE
+};
+
+/* TDG.VP.VMCALL Sub-function Completion Status Codes */
+enum tdvmcall_status {
+	TDVMCALL_STATUS_SUCCESS		= 0x0000000000000000ULL,
+	TDVMCALL_STATUS_RETRY		= 0x0000000000000001ULL,
+	TDVMCALL_STATUS_INVALID_OPERAND	= 0x8000000000000000ULL,
+	TDVMCALL_STATUS_ALIGN_ERROR	= 0x8000000000000002ULL,
+};
+
+struct tdh_vp_enter_in {
+	/* TDG.VP.VMCALL common */
+	enum tdvmcall_status	ret_code;
+
+	/* TDG.VP.VMCALL Sub-function return information */
+
+	/* TDVMCALL_GET_TD_VM_CALL_INFO */
+	u64			gettdvmcallinfo[4];
+
+	/* TDVMCALL_MAP_GPA */
+	gpa_t			failed_gpa;
+
+	/* TDVMCALL_CPUID */
+	u32			eax;
+	u32			ebx;
+	u32			ecx;
+	u32			edx;
+
+	/* TDVMCALL_IO (read), TDVMCALL_RDMSR or TDVMCALL_MMIO (read) */
+	u64			value_read;
+};
+
+#define TDX_ERR_DATA_SZ 8
+
+struct tdh_vp_enter_out {
+	u64			exit_qual;
+	u32			intr_info;
+	u64			ext_exit_qual;
+	gpa_t			gpa;
+
+	/* TDG.VP.VMCALL common */
+	u32			reg_mask;
+	u64			fn;		/* Non-zero for KVM hypercalls, zero otherwise */
+	enum tdvmcall_subfn	subfn;
+
+	/* TDG.VP.VMCALL Sub-function arguments */
+
+	/* KVM hypercall */
+	u64			nr;
+	u64			p1;
+	u64			p2;
+	u64			p3;
+	u64			p4;
+
+	/* TDVMCALL_GET_TD_VM_CALL_INFO */
+	u64			leaf;
+
+	/* TDVMCALL_MAP_GPA */
+	gpa_t			map_gpa;
+	u64			map_gpa_size;
+
+	/* TDVMCALL_GET_QUOTE */
+	gpa_t			shared_gpa;
+	u64			shared_gpa_size;
+
+	/* TDVMCALL_REPORT_FATAL_ERROR */
+	u64			err_codes;
+	gpa_t			err_data_gpa;
+	u64			err_data[TDX_ERR_DATA_SZ];
+
+	/* TDVMCALL_CPUID */
+	u32			cpuid_leaf;
+	u32			cpuid_subleaf;
+
+	/* TDVMCALL_MMIO */
+	int			mmio_size;
+	enum tdx_io_direction	mmio_direction;
+	gpa_t			mmio_addr;
+	u64			mmio_value;
+
+	/* TDVMCALL_HLT */
+	bool			intr_blocked_flag;
+
+	/* TDVMCALL_IO */
+	int			io_size;
+	enum tdx_io_direction	io_direction;
+	u16			io_port;
+	u32			io_value;
+
+	/* TDVMCALL_RDMSR or TDVMCALL_WRMSR */
+	u32			msr;
+
+	/* TDVMCALL_WRMSR */
+	u64			write_value;
+};
+
 /* SEAMCALL wrappers for creating/destroying/running TDX guests */
-u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args);
+u64 tdh_vp_enter(u64 tdvpr, const struct tdh_vp_enter_in *in, struct tdh_vp_enter_out *out);
 u64 tdh_mng_addcx(u64 tdr, u64 tdcs);
 u64 tdh_mem_page_add(u64 tdr, u64 gpa, u64 hpa, u64 source, u64 *rcx, u64 *rdx);
 u64 tdh_mem_sept_add(u64 tdr, u64 gpa, u64 level, u64 hpa, u64 *rcx, u64 *rdx);
diff --git a/arch/x86/kvm/vmx/tdx.c b/arch/x86/kvm/vmx/tdx.c
index 218801618e9a..a8283a03fdd4 100644
--- a/arch/x86/kvm/vmx/tdx.c
+++ b/arch/x86/kvm/vmx/tdx.c
@@ -256,57 +256,41 @@ static __always_inline bool tdx_check_exit_reason(struct kvm_vcpu *vcpu, u16 rea
 
 static __always_inline unsigned long tdexit_exit_qual(struct kvm_vcpu *vcpu)
 {
-	return kvm_rcx_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.exit_qual;
 }
 
 static __always_inline unsigned long tdexit_ext_exit_qual(struct kvm_vcpu *vcpu)
 {
-	return kvm_rdx_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.ext_exit_qual;
 }
 
-static __always_inline unsigned long tdexit_gpa(struct kvm_vcpu *vcpu)
+static __always_inline gpa_t tdexit_gpa(struct kvm_vcpu *vcpu)
 {
-	return kvm_r8_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.gpa;
 }
 
 static __always_inline unsigned long tdexit_intr_info(struct kvm_vcpu *vcpu)
 {
-	return kvm_r9_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.intr_info;
 }
 
-#define BUILD_TDVMCALL_ACCESSORS(param, gpr)				\
-static __always_inline							\
-unsigned long tdvmcall_##param##_read(struct kvm_vcpu *vcpu)		\
-{									\
-	return kvm_##gpr##_read(vcpu);					\
-}									\
-static __always_inline void tdvmcall_##param##_write(struct kvm_vcpu *vcpu, \
-						     unsigned long val)  \
-{									\
-	kvm_##gpr##_write(vcpu, val);					\
-}
-BUILD_TDVMCALL_ACCESSORS(a0, r12);
-BUILD_TDVMCALL_ACCESSORS(a1, r13);
-BUILD_TDVMCALL_ACCESSORS(a2, r14);
-BUILD_TDVMCALL_ACCESSORS(a3, r15);
-
-static __always_inline unsigned long tdvmcall_exit_type(struct kvm_vcpu *vcpu)
+static __always_inline unsigned long tdvmcall_fn(struct kvm_vcpu *vcpu)
 {
-	return kvm_r10_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.fn;
 }
-static __always_inline unsigned long tdvmcall_leaf(struct kvm_vcpu *vcpu)
+static __always_inline enum tdvmcall_subfn tdvmcall_subfn(struct kvm_vcpu *vcpu)
 {
-	return kvm_r11_read(vcpu);
+	return to_tdx(vcpu)->vp_enter_out.subfn;
 }
 static __always_inline void tdvmcall_set_return_code(struct kvm_vcpu *vcpu,
-						     long val)
+						     enum tdvmcall_status val)
 {
-	kvm_r10_write(vcpu, val);
+	to_tdx(vcpu)->vp_enter_in.ret_code = val;
 }
 static __always_inline void tdvmcall_set_return_val(struct kvm_vcpu *vcpu,
 						    unsigned long val)
 {
-	kvm_r11_write(vcpu, val);
+	to_tdx(vcpu)->vp_enter_in.value_read = val;
 }
 
 static inline void tdx_hkid_free(struct kvm_tdx *kvm_tdx)
@@ -786,10 +770,10 @@ bool tdx_interrupt_allowed(struct kvm_vcpu *vcpu)
 	 * passes the interrupt block flag.
 	 */
 	if (!tdx_check_exit_reason(vcpu, EXIT_REASON_TDCALL) ||
-	    tdvmcall_exit_type(vcpu) || tdvmcall_leaf(vcpu) != EXIT_REASON_HLT)
+	    tdvmcall_fn(vcpu) || tdvmcall_subfn(vcpu) != TDVMCALL_HLT)
 	    return true;
 
-	return !tdvmcall_a0_read(vcpu);
+	return !to_tdx(vcpu)->vp_enter_out.intr_blocked_flag;
 }
 
 bool tdx_protected_apic_has_interrupt(struct kvm_vcpu *vcpu)
@@ -945,51 +929,10 @@ static void tdx_restore_host_xsave_state(struct kvm_vcpu *vcpu)
 static noinstr void tdx_vcpu_enter_exit(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_tdx *tdx = to_tdx(vcpu);
-	struct tdx_module_args args;
 
 	guest_state_enter_irqoff();
 
-	/*
-	 * TODO: optimization:
-	 * - Eliminate copy between args and vcpu->arch.regs.
-	 * - copyin/copyout registers only if (tdx->tdvmvall.regs_mask != 0)
-	 *   which means TDG.VP.VMCALL.
-	 */
-	args = (struct tdx_module_args) {
-		.rcx = tdx->tdvpr_pa,
-#define REG(reg, REG)	.reg = vcpu->arch.regs[VCPU_REGS_ ## REG]
-		REG(rdx, RDX),
-		REG(r8,  R8),
-		REG(r9,  R9),
-		REG(r10, R10),
-		REG(r11, R11),
-		REG(r12, R12),
-		REG(r13, R13),
-		REG(r14, R14),
-		REG(r15, R15),
-		REG(rbx, RBX),
-		REG(rdi, RDI),
-		REG(rsi, RSI),
-#undef REG
-	};
-
-	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &args);
-
-#define REG(reg, REG)	vcpu->arch.regs[VCPU_REGS_ ## REG] = args.reg
-	REG(rcx, RCX);
-	REG(rdx, RDX);
-	REG(r8,  R8);
-	REG(r9,  R9);
-	REG(r10, R10);
-	REG(r11, R11);
-	REG(r12, R12);
-	REG(r13, R13);
-	REG(r14, R14);
-	REG(r15, R15);
-	REG(rbx, RBX);
-	REG(rdi, RDI);
-	REG(rsi, RSI);
-#undef REG
+	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &tdx->vp_enter_in, &tdx->vp_enter_out);
 
 	if (tdx_check_exit_reason(vcpu, EXIT_REASON_EXCEPTION_NMI) &&
 	    is_nmi(tdexit_intr_info(vcpu)))
@@ -1128,8 +1071,15 @@ static int complete_hypercall_exit(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	int r;
 
+	kvm_r10_write(vcpu, tdx->vp_enter_out.nr);
+	kvm_r11_write(vcpu, tdx->vp_enter_out.p1);
+	kvm_r12_write(vcpu, tdx->vp_enter_out.p2);
+	kvm_r13_write(vcpu, tdx->vp_enter_out.p3);
+	kvm_r14_write(vcpu, tdx->vp_enter_out.p4);
+
 	/*
 	 * ABI for KVM tdvmcall argument:
 	 * In Guest-Hypervisor Communication Interface(GHCI) specification,
@@ -1137,13 +1087,12 @@ static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
 	 * vendor-specific.  KVM uses this for KVM hypercall.  NOTE: KVM
 	 * hypercall number starts from one.  Zero isn't used for KVM hypercall
 	 * number.
-	 *
-	 * R10: KVM hypercall number
-	 * arguments: R11, R12, R13, R14.
 	 */
 	r = __kvm_emulate_hypercall(vcpu, r10, r11, r12, r13, r14, true, 0,
 				    complete_hypercall_exit);
 
+	tdvmcall_set_return_code(vcpu, kvm_r10_read(vcpu));
+
 	return r > 0;
 }
 
@@ -1161,7 +1110,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
 
 	if(vcpu->run->hypercall.ret) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
-		kvm_r11_write(vcpu, tdx->map_gpa_next);
+		tdx->vp_enter_in.failed_gpa = tdx->map_gpa_next;
 		return 1;
 	}
 
@@ -1182,7 +1131,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
 	if (pi_has_pending_interrupt(vcpu) ||
 	    kvm_test_request(KVM_REQ_NMI, vcpu) || vcpu->arch.nmi_pending) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_RETRY);
-		kvm_r11_write(vcpu, tdx->map_gpa_next);
+		tdx->vp_enter_in.failed_gpa = tdx->map_gpa_next;
 		return 1;
 	}
 
@@ -1214,8 +1163,8 @@ static void __tdx_map_gpa(struct vcpu_tdx * tdx)
 static int tdx_map_gpa(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_tdx * tdx = to_tdx(vcpu);
-	u64 gpa = tdvmcall_a0_read(vcpu);
-	u64 size = tdvmcall_a1_read(vcpu);
+	u64 gpa  = tdx->vp_enter_out.map_gpa;
+	u64 size = tdx->vp_enter_out.map_gpa_size;
 	u64 ret;
 
 	/*
@@ -1251,14 +1200,17 @@ static int tdx_map_gpa(struct kvm_vcpu *vcpu)
 
 error:
 	tdvmcall_set_return_code(vcpu, ret);
-	kvm_r11_write(vcpu, gpa);
+	tdx->vp_enter_in.failed_gpa = gpa;
 	return 1;
 }
 
 static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 {
-	u64 reg_mask = kvm_rcx_read(vcpu);
-	u64* opt_regs;
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
+	__u64 *data = &vcpu->run->system_event.data[0];
+	u64 reg_mask = tdx->vp_enter_out.reg_mask;
+	const int mask[] = {14, 15, 3, 7, 6, 8, 9, 2};
+	int cnt = 0;
 
 	/*
 	 * Skip sanity checks and let userspace decide what to do if sanity
@@ -1266,32 +1218,20 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 	 */
 	vcpu->run->exit_reason = KVM_EXIT_SYSTEM_EVENT;
 	vcpu->run->system_event.type = KVM_SYSTEM_EVENT_TDX_FATAL;
-	vcpu->run->system_event.ndata = 10;
 	/* Error codes. */
-	vcpu->run->system_event.data[0] = tdvmcall_a0_read(vcpu);
+	data[cnt++] = tdx->vp_enter_out.err_codes;
 	/* GPA of additional information page. */
-	vcpu->run->system_event.data[1] = tdvmcall_a1_read(vcpu);
+	data[cnt++] = tdx->vp_enter_out.err_data_gpa;
+
 	/* Information passed via registers (up to 64 bytes). */
-	opt_regs = &vcpu->run->system_event.data[2];
+	for (int i = 0; i < TDX_ERR_DATA_SZ; i++) {
+		if (reg_mask & BIT_ULL(mask[i]))
+			data[cnt++] = tdx->vp_enter_out.err_data[i];
+		else
+			data[cnt++] = 0;
+	}
 
-#define COPY_REG(REG, MASK)						\
-	do {								\
-		if (reg_mask & MASK)					\
-			*opt_regs = kvm_ ## REG ## _read(vcpu);		\
-		else							\
-			*opt_regs = 0;					\
-		opt_regs++;						\
-	} while (0)
-
-	/* The order is defined in GHCI. */
-	COPY_REG(r14, BIT_ULL(14));
-	COPY_REG(r15, BIT_ULL(15));
-	COPY_REG(rbx, BIT_ULL(3));
-	COPY_REG(rdi, BIT_ULL(7));
-	COPY_REG(rsi, BIT_ULL(6));
-	COPY_REG(r8, BIT_ULL(8));
-	COPY_REG(r9, BIT_ULL(9));
-	COPY_REG(rdx, BIT_ULL(2));
+	vcpu->run->system_event.ndata = cnt;
 
 	/*
 	 * Set the status code according to GHCI spec, although the vCPU may
@@ -1305,18 +1245,18 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_cpuid(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	u32 eax, ebx, ecx, edx;
 
-	/* EAX and ECX for cpuid is stored in R12 and R13. */
-	eax = tdvmcall_a0_read(vcpu);
-	ecx = tdvmcall_a1_read(vcpu);
+	eax = tdx->vp_enter_out.cpuid_leaf;
+	ecx = tdx->vp_enter_out.cpuid_subleaf;
 
 	kvm_cpuid(vcpu, &eax, &ebx, &ecx, &edx, false);
 
-	tdvmcall_a0_write(vcpu, eax);
-	tdvmcall_a1_write(vcpu, ebx);
-	tdvmcall_a2_write(vcpu, ecx);
-	tdvmcall_a3_write(vcpu, edx);
+	tdx->vp_enter_in.eax = eax;
+	tdx->vp_enter_in.ebx = ebx;
+	tdx->vp_enter_in.ecx = ecx;
+	tdx->vp_enter_in.edx = edx;
 
 	tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
 
@@ -1356,6 +1296,7 @@ static int tdx_complete_pio_in(struct kvm_vcpu *vcpu)
 static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 {
 	struct x86_emulate_ctxt *ctxt = vcpu->arch.emulate_ctxt;
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	unsigned long val = 0;
 	unsigned int port;
 	int size, ret;
@@ -1363,9 +1304,9 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 
 	++vcpu->stat.io_exits;
 
-	size = tdvmcall_a0_read(vcpu);
-	write = tdvmcall_a1_read(vcpu);
-	port = tdvmcall_a2_read(vcpu);
+	size  = tdx->vp_enter_out.io_size;
+	write = tdx->vp_enter_out.io_direction == TDX_WRITE;
+	port  = tdx->vp_enter_out.io_port;
 
 	if (size != 1 && size != 2 && size != 4) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
@@ -1373,7 +1314,7 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
 	}
 
 	if (write) {
-		val = tdvmcall_a3_read(vcpu);
+		val = tdx->vp_enter_out.io_value;
 		ret = ctxt->ops->pio_out_emulated(ctxt, size, port, &val, 1);
 	} else {
 		ret = ctxt->ops->pio_in_emulated(ctxt, size, port, &val, 1);
@@ -1443,14 +1384,15 @@ static inline int tdx_mmio_read(struct kvm_vcpu *vcpu, gpa_t gpa, int size)
 
 static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
 	int size, write, r;
 	unsigned long val;
 	gpa_t gpa;
 
-	size = tdvmcall_a0_read(vcpu);
-	write = tdvmcall_a1_read(vcpu);
-	gpa = tdvmcall_a2_read(vcpu);
-	val = write ? tdvmcall_a3_read(vcpu) : 0;
+	size  = tdx->vp_enter_out.mmio_size;
+	write = tdx->vp_enter_out.mmio_direction == TDX_WRITE;
+	gpa   = tdx->vp_enter_out.mmio_addr;
+	val = write ? tdx->vp_enter_out.mmio_value : 0;
 
 	if (size != 1 && size != 2 && size != 4 && size != 8)
 		goto error;
@@ -1502,7 +1444,7 @@ static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
 {
-	u32 index = tdvmcall_a0_read(vcpu);
+	u32 index = to_tdx(vcpu)->vp_enter_out.msr;
 	u64 data;
 
 	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_READ) ||
@@ -1520,8 +1462,8 @@ static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
 
 static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
 {
-	u32 index = tdvmcall_a0_read(vcpu);
-	u64 data = tdvmcall_a1_read(vcpu);
+	u32 index = to_tdx(vcpu)->vp_enter_out.msr;
+	u64 data  = to_tdx(vcpu)->vp_enter_out.write_value;
 
 	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_WRITE) ||
 	    kvm_set_msr(vcpu, index, data)) {
@@ -1537,39 +1479,41 @@ static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
 
 static int tdx_get_td_vm_call_info(struct kvm_vcpu *vcpu)
 {
-	if (tdvmcall_a0_read(vcpu))
+	struct vcpu_tdx *tdx = to_tdx(vcpu);
+
+	if (tdx->vp_enter_out.leaf) {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
-	else {
+	} else {
 		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
-		kvm_r11_write(vcpu, 0);
-		tdvmcall_a0_write(vcpu, 0);
-		tdvmcall_a1_write(vcpu, 0);
-		tdvmcall_a2_write(vcpu, 0);
+		tdx->vp_enter_in.gettdvmcallinfo[0] = 0;
+		tdx->vp_enter_in.gettdvmcallinfo[1] = 0;
+		tdx->vp_enter_in.gettdvmcallinfo[2] = 0;
+		tdx->vp_enter_in.gettdvmcallinfo[3] = 0;
 	}
 	return 1;
 }
 
 static int handle_tdvmcall(struct kvm_vcpu *vcpu)
 {
-	if (tdvmcall_exit_type(vcpu))
+	if (tdvmcall_fn(vcpu))
 		return tdx_emulate_vmcall(vcpu);
 
-	switch (tdvmcall_leaf(vcpu)) {
+	switch (tdvmcall_subfn(vcpu)) {
 	case TDVMCALL_MAP_GPA:
 		return tdx_map_gpa(vcpu);
 	case TDVMCALL_REPORT_FATAL_ERROR:
 		return tdx_report_fatal_error(vcpu);
-	case EXIT_REASON_CPUID:
+	case TDVMCALL_CPUID:
 		return tdx_emulate_cpuid(vcpu);
-	case EXIT_REASON_HLT:
+	case TDVMCALL_HLT:
 		return tdx_emulate_hlt(vcpu);
-	case EXIT_REASON_IO_INSTRUCTION:
+	case TDVMCALL_IO:
 		return tdx_emulate_io(vcpu);
-	case EXIT_REASON_EPT_VIOLATION:
+	case TDVMCALL_MMIO:
 		return tdx_emulate_mmio(vcpu);
-	case EXIT_REASON_MSR_READ:
+	case TDVMCALL_RDMSR:
 		return tdx_emulate_rdmsr(vcpu);
-	case EXIT_REASON_MSR_WRITE:
+	case TDVMCALL_WRMSR:
 		return tdx_emulate_wrmsr(vcpu);
 	case TDVMCALL_GET_TD_VM_CALL_INFO:
 		return tdx_get_td_vm_call_info(vcpu);
diff --git a/arch/x86/kvm/vmx/tdx.h b/arch/x86/kvm/vmx/tdx.h
index 008180c0c30f..63d8b3359b10 100644
--- a/arch/x86/kvm/vmx/tdx.h
+++ b/arch/x86/kvm/vmx/tdx.h
@@ -69,6 +69,8 @@ struct vcpu_tdx {
 	struct list_head cpu_list;
 
 	u64 vp_enter_ret;
+	struct tdh_vp_enter_in vp_enter_in;
+	struct tdh_vp_enter_out vp_enter_out;
 
 	enum vcpu_tdx_state state;
 
diff --git a/arch/x86/virt/vmx/tdx/tdx.c b/arch/x86/virt/vmx/tdx/tdx.c
index 16e0b598c4ec..895d9ea4aeba 100644
--- a/arch/x86/virt/vmx/tdx/tdx.c
+++ b/arch/x86/virt/vmx/tdx/tdx.c
@@ -33,6 +33,7 @@
 #include <asm/msr-index.h>
 #include <asm/msr.h>
 #include <asm/cpufeature.h>
+#include <asm/vmx.h>
 #include <asm/tdx.h>
 #include <asm/cpu_device_id.h>
 #include <asm/processor.h>
@@ -1600,11 +1601,122 @@ static inline u64 tdx_seamcall_sept(u64 op, struct tdx_module_args *in)
 	return ret;
 }
 
-noinstr u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
+noinstr u64 tdh_vp_enter(u64 tdvpr, const struct tdh_vp_enter_in *in, struct tdh_vp_enter_out *out)
 {
-	args->rcx = tdvpr;
+	struct tdx_module_args args = {
+		.rcx = tdvpr,
+		.r10 = in->ret_code,
+	};
+	u64 ret;
 
-	return __seamcall_saved_ret(TDH_VP_ENTER, args);
+	/* If previous exit was TDG.VP.VMCALL */
+	switch (out->subfn) {
+	case TDVMCALL_GET_TD_VM_CALL_INFO:
+		args.r11 = in->gettdvmcallinfo[0];
+		args.r12 = in->gettdvmcallinfo[1];
+		args.r13 = in->gettdvmcallinfo[2];
+		args.r14 = in->gettdvmcallinfo[3];
+		break;
+	case TDVMCALL_MAP_GPA:
+		args.r11 = in->failed_gpa;
+		break;
+	case TDVMCALL_CPUID:
+		args.r12 = in->eax;
+		args.r13 = in->ebx;
+		args.r14 = in->ecx;
+		args.r15 = in->edx;
+		break;
+	case TDVMCALL_IO:
+	case TDVMCALL_RDMSR:
+	case TDVMCALL_MMIO:
+		args.r11 = in->value_read;
+		break;
+	case TDVMCALL_NONE:
+	case TDVMCALL_GET_QUOTE:
+	case TDVMCALL_REPORT_FATAL_ERROR:
+	case TDVMCALL_HLT:
+	case TDVMCALL_WRMSR:
+		break;
+	}
+
+	ret = __seamcall_saved_ret(TDH_VP_ENTER, &args);
+
+	if ((u16)ret == EXIT_REASON_TDCALL) {
+		out->reg_mask		= args.rcx;
+		out->fn = args.r10;
+		if (out->fn) {
+			out->nr		= args.r10;
+			out->p1		= args.r11;
+			out->p2		= args.r12;
+			out->p3		= args.r13;
+			out->p4		= args.r14;
+			out->subfn	= TDVMCALL_NONE;
+		} else {
+			out->subfn	= args.r11;
+		}
+	} else {
+		out->exit_qual		= args.rcx;
+		out->ext_exit_qual	= args.rdx;
+		out->gpa		= args.r8;
+		out->intr_info		= args.r9;
+		out->subfn		= TDVMCALL_NONE;
+	}
+
+	switch (out->subfn) {
+	case TDVMCALL_GET_TD_VM_CALL_INFO:
+		out->leaf		= args.r12;
+		break;
+	case TDVMCALL_MAP_GPA:
+		out->map_gpa		= args.r12;
+		out->map_gpa_size	= args.r13;
+		break;
+	case TDVMCALL_CPUID:
+		out->cpuid_leaf		= args.r12;
+		out->cpuid_subleaf	= args.r13;
+		break;
+	case TDVMCALL_IO:
+		out->io_size		= args.r12;
+		out->io_direction	= args.r13 ? TDX_WRITE : TDX_READ;
+		out->io_port		= args.r14;
+		out->io_value		= args.r15;
+		break;
+	case TDVMCALL_RDMSR:
+		out->msr		= args.r12;
+		break;
+	case TDVMCALL_MMIO:
+		out->mmio_size		= args.r12;
+		out->mmio_direction	= args.r13 ? TDX_WRITE : TDX_READ;
+		out->mmio_addr		= args.r14;
+		out->mmio_value		= args.r15;
+		break;
+	case TDVMCALL_NONE:
+		break;
+	case TDVMCALL_GET_QUOTE:
+		out->shared_gpa		= args.r12;
+		out->shared_gpa_size	= args.r13;
+		break;
+	case TDVMCALL_REPORT_FATAL_ERROR:
+		out->err_codes		= args.r12;
+		out->err_data_gpa	= args.r13;
+		out->err_data[0]	= args.r14;
+		out->err_data[1]	= args.r15;
+		out->err_data[2]	= args.rbx;
+		out->err_data[3]	= args.rdi;
+		out->err_data[4]	= args.rsi;
+		out->err_data[5]	= args.r8;
+		out->err_data[6]	= args.r9;
+		out->err_data[7]	= args.rdx;
+		break;
+	case TDVMCALL_HLT:
+		out->intr_blocked_flag	= args.r12;
+		break;
+	case TDVMCALL_WRMSR:
+		out->msr		= args.r12;
+		out->write_value	= args.r13;
+		break;
+	}
+
+	return ret;
 }
 EXPORT_SYMBOL_GPL(tdh_vp_enter);
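The least obvious part of this conversion is the register-mask handling in tdx_report_fatal_error(). The standalone sketch below mirrors that loop under hypothetical names (err_data_regs, collect_err_data): err_data[i] holds the register captured in slot i, and the GPR number at the same index decides whether reg_mask exposes it.

```c
#include <assert.h>
#include <stdint.h>

#define ERR_DATA_SZ 8

/*
 * GPR number backing each err_data[] slot, in the order used by the
 * diff: r14, r15, rbx, rdi, rsi, r8, r9, rdx.
 */
static const int err_data_regs[ERR_DATA_SZ] = { 14, 15, 3, 7, 6, 8, 9, 2 };

/*
 * Copy the captured register values into out[], zeroing any slot whose
 * register bit is clear in reg_mask. Returns the number of slots filled.
 */
static int collect_err_data(uint64_t reg_mask,
			    const uint64_t err_data[ERR_DATA_SZ],
			    uint64_t out[ERR_DATA_SZ])
{
	int cnt = 0;

	for (int i = 0; i < ERR_DATA_SZ; i++) {
		if (reg_mask & (1ULL << err_data_regs[i]))
			out[cnt++] = err_data[i];
		else
			out[cnt++] = 0;
	}
	return cnt;
}
```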
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Dave Hansen 1 year, 1 month ago
On 12/11/24 10:43, Adrian Hunter wrote:
...
> -	size = tdvmcall_a0_read(vcpu);
> -	write = tdvmcall_a1_read(vcpu);
> -	port = tdvmcall_a2_read(vcpu);
> +	size  = tdx->vp_enter_out.io_size;
> +	write = tdx->vp_enter_out.io_direction == TDX_WRITE;
> +	port  = tdx->vp_enter_out.io_port;
...
> +	case TDVMCALL_IO:
> +		out->io_size		= args.r12;
> +		out->io_direction	= args.r13 ? TDX_WRITE : TDX_READ;
> +		out->io_port		= args.r14;
> +		out->io_value		= args.r15;
> +		break;

I honestly don't understand the need for the abstracted structure to sit
in the middle. It doesn't get stored or serialized or anything, right?
So why have _another_ structure?

Why can't this just be (for instance):

	size = tdx->foo.r12;

?

Basically, you hand around the raw arguments until you need to use them.
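As an illustration of that approach, here is a hypothetical userspace sketch (struct raw_args, handle_io() and the fake port helpers are made up for the example). It decodes an IO request straight from r12..r15 where it is used and writes the response back into the same block, r10 for the status and r11 for a read value, as the pre-diff KVM code did.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

#define TDVMCALL_STATUS_SUCCESS		0x0000000000000000ULL
#define TDVMCALL_STATUS_INVALID_OPERAND	0x8000000000000000ULL

/* Hypothetical subset of the raw register block kept after the exit. */
struct raw_args {
	uint64_t r10, r11, r12, r13, r14, r15;
};

/* Fake device backends for the sketch. */
static uint64_t last_port_write;

static uint64_t fake_port_read(uint16_t port, unsigned int size)
{
	(void)port;
	/* All-ones pattern of the requested width (1, 2 or 4 bytes). */
	return 0xffffffffULL >> (32 - 8 * size);
}

static void fake_port_write(uint16_t port, unsigned int size, uint64_t val)
{
	(void)port;
	last_port_write = val & (0xffffffffULL >> (32 - 8 * size));
}

/*
 * Decode the IO request at the point of use (r12 = size, r13 = direction,
 * r14 = port, r15 = value to write) and answer in place.
 */
static void handle_io(struct raw_args *a)
{
	uint64_t size = a->r12;
	bool write = a->r13 != 0;
	uint16_t port = (uint16_t)a->r14;

	/* Guest-controlled: validate the full 64-bit value before narrowing. */
	if (size != 1 && size != 2 && size != 4) {
		a->r10 = TDVMCALL_STATUS_INVALID_OPERAND;
		return;
	}

	if (write)
		fake_port_write(port, (unsigned int)size, a->r15);
	else
		a->r11 = fake_port_read(port, (unsigned int)size);

	a->r10 = TDVMCALL_STATUS_SUCCESS;
}
```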
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 1 month ago
On 13/12/24 18:16, Dave Hansen wrote:
> On 12/11/24 10:43, Adrian Hunter wrote:
> ...
>> -	size = tdvmcall_a0_read(vcpu);
>> -	write = tdvmcall_a1_read(vcpu);
>> -	port = tdvmcall_a2_read(vcpu);
>> +	size  = tdx->vp_enter_out.io_size;
>> +	write = tdx->vp_enter_out.io_direction == TDX_WRITE;
>> +	port  = tdx->vp_enter_out.io_port;
> ...
>> +	case TDVMCALL_IO:
>> +		out->io_size		= args.r12;
>> +		out->io_direction	= args.r13 ? TDX_WRITE : TDX_READ;
>> +		out->io_port		= args.r14;
>> +		out->io_value		= args.r15;
>> +		break;
> 
> I honestly don't understand the need for the abstracted structure to sit
> in the middle. It doesn't get stored or serialized or anything, right?
> So why have _another_ structure?
> 
> Why can't this just be (for instance):
> 
> 	size = tdx->foo.r12;
> 
> ?
> 
> Basically, you hand around the raw arguments until you need to use them.

That sounds like what we have at present?  That is:

u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
{
	args->rcx = tdvpr;

	return __seamcall_saved_ret(TDH_VP_ENTER, args);
}

And then either add Rick's struct tdx_vp?  Like so:

u64 tdh_vp_enter(struct tdx_vp *vp, struct tdx_module_args *args)
{
	args->rcx = tdx_tdvpr_pa(vp);

	return __seamcall_saved_ret(TDH_VP_ENTER, args);
}

Or leave it to the caller:

u64 tdh_vp_enter(struct tdx_module_args *args)
{
	return __seamcall_saved_ret(TDH_VP_ENTER, args);
}

Or forget the wrapper altogether, and let KVM call
__seamcall_saved_ret() ?
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Dave Hansen 1 year, 1 month ago
On 12/13/24 08:30, Adrian Hunter wrote:
> On 13/12/24 18:16, Dave Hansen wrote:
>> On 12/11/24 10:43, Adrian Hunter wrote:
>> ...
>>> -	size = tdvmcall_a0_read(vcpu);
>>> -	write = tdvmcall_a1_read(vcpu);
>>> -	port = tdvmcall_a2_read(vcpu);
>>> +	size  = tdx->vp_enter_out.io_size;
>>> +	write = tdx->vp_enter_out.io_direction == TDX_WRITE;
>>> +	port  = tdx->vp_enter_out.io_port;
>> ...
>>> +	case TDVMCALL_IO:
>>> +		out->io_size		= args.r12;
>>> +		out->io_direction	= args.r13 ? TDX_WRITE : TDX_READ;
>>> +		out->io_port		= args.r14;
>>> +		out->io_value		= args.r15;
>>> +		break;
>>
>> I honestly don't understand the need for the abstracted structure to sit
>> in the middle. It doesn't get stored or serialized or anything, right?
>> So why have _another_ structure?
>>
>> Why can't this just be (for instance):
>>
>> 	size = tdx->foo.r12;
>>
>> ?
>>
>> Basically, you hand around the raw arguments until you need to use them.
> 
> That sounds like what we have at present?  That is:
> 
> u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
> {
> 	args->rcx = tdvpr;
> 
> 	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> }
> 
> And then either add Rick's struct tdx_vp?  Like so:
> 
> u64 tdh_vp_enter(struct tdx_vp *vp, struct tdx_module_args *args)
> {
> 	args->rcx = tdx_tdvpr_pa(vp);
> 
> 	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> }
> 
> Or leave it to the caller:
> 
> u64 tdh_vp_enter(struct tdx_module_args *args)
> {
> 	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> }
> 
> Or forget the wrapper altogether, and let KVM call
> __seamcall_saved_ret() ?

Rick's version, please.

I don't want __seamcall_saved_ret() exported to modules. I want to at
least have a clean boundary beyond which __seamcall_saved_ret() is not
exposed.

My nit with the "u64 tdvpr" version was that there's zero type safety.

The tdvp-less tdh_vp_enter() is an even *less* safe calling convention,
and it also requires each caller to do tdx_tdvpr_pa() or equivalent.

But I feel like I'm repeating myself a bit at this point.
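This sketch illustrates the type-safety argument. If the wrapper takes a `struct tdx_vp *` and derives the TDVPR physical address itself, then passing a bare integer or an unrelated pointer produces a compiler diagnostic instead of a silent bug. The sketch is standalone and makes several assumptions for illustration: that `struct tdx_vp` has a single `tdvpr_pa` field, and the definitions of `struct vp_enter_regs`, `tdx_tdvpr_pa()`, and `prepare_vp_enter()`. The real `__seamcall_saved_ret()` call is left out.

```c
#include <stdint.h>

typedef uint64_t u64;

/* Hypothetical layout: the real struct tdx_vp is not shown in this thread. */
struct tdx_vp {
	u64 tdvpr_pa;	/* physical address of the TDVPR page */
};

/* Hypothetical stand-in for the RCX slot of struct tdx_module_args. */
struct vp_enter_regs {
	u64 rcx;
};

/* Hypothetical helper: turn the vCPU object into the raw TDVPR address. */
static u64 tdx_tdvpr_pa(const struct tdx_vp *vp)
{
	return vp->tdvpr_pa;
}

/* Only the wrapper converts the typed object into the raw RCX value. */
static void prepare_vp_enter(const struct tdx_vp *vp, struct vp_enter_regs *regs)
{
	regs->rcx = tdx_tdvpr_pa(vp);
}
```

A call such as `prepare_vp_enter(0x1000, &regs)` draws an integer-to-pointer conversion diagnostic. A `u64 tdvpr` parameter would accept any integer without complaint.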
Re: [PATCH RFC 1/7] x86/virt/tdx: Add SEAMCALL wrapper to enter/exit TDX guest
Posted by Adrian Hunter 1 year, 1 month ago
On 11/12/24 20:43, Adrian Hunter wrote:
> The diff below shows another alternative, this time using
> structs rather than a union.  The structs are easier to read than
> the union and require copying arguments, which also allows
> using types whose size differs from a GPR's (u64) size.

Dave, any comments on this one?

> 
> diff --git a/arch/x86/include/asm/shared/tdx.h b/arch/x86/include/asm/shared/tdx.h
> index 192ae798b214..85f87d90ac89 100644
> --- a/arch/x86/include/asm/shared/tdx.h
> +++ b/arch/x86/include/asm/shared/tdx.h
> @@ -21,20 +21,6 @@
>  /* TDCS fields. To be used by TDG.VM.WR and TDG.VM.RD module calls */
>  #define TDCS_NOTIFY_ENABLES		0x9100000000000010
>  
> -/* TDX hypercall Leaf IDs */
> -#define TDVMCALL_GET_TD_VM_CALL_INFO	0x10000
> -#define TDVMCALL_MAP_GPA		0x10001
> -#define TDVMCALL_GET_QUOTE		0x10002
> -#define TDVMCALL_REPORT_FATAL_ERROR	0x10003
> -
> -/*
> - * TDG.VP.VMCALL Status Codes (returned in R10)
> - */
> -#define TDVMCALL_STATUS_SUCCESS		0x0000000000000000ULL
> -#define TDVMCALL_STATUS_RETRY		0x0000000000000001ULL
> -#define TDVMCALL_STATUS_INVALID_OPERAND	0x8000000000000000ULL
> -#define TDVMCALL_STATUS_ALIGN_ERROR	0x8000000000000002ULL
> -
>  /*
>   * Bitmasks of exposed registers (with VMM).
>   */
> diff --git a/arch/x86/include/asm/tdx.h b/arch/x86/include/asm/tdx.h
> index 01409a59224d..e4a45378a84b 100644
> --- a/arch/x86/include/asm/tdx.h
> +++ b/arch/x86/include/asm/tdx.h
> @@ -33,6 +33,7 @@
>  
>  #ifndef __ASSEMBLY__
>  
> +#include <linux/kvm_types.h>
>  #include <uapi/asm/mce.h>
>  #include "tdx_global_metadata.h"
>  
> @@ -96,6 +97,7 @@ u64 __seamcall_saved_ret(u64 fn, struct tdx_module_args *args);
>  void tdx_init(void);
>  
>  #include <asm/archrandom.h>
> +#include <asm/vmx.h>
>  
>  typedef u64 (*sc_func_t)(u64 fn, struct tdx_module_args *args);
>  
> @@ -123,8 +125,122 @@ const struct tdx_sys_info *tdx_get_sysinfo(void);
>  int tdx_guest_keyid_alloc(void);
>  void tdx_guest_keyid_free(unsigned int keyid);
>  
> +/* TDG.VP.VMCALL Sub-function */
> +enum tdvmcall_subfn {
> +	TDVMCALL_NONE			= -1, /* Not a TDG.VP.VMCALL */
> +	TDVMCALL_GET_TD_VM_CALL_INFO	= 0x10000,
> +	TDVMCALL_MAP_GPA		= 0x10001,
> +	TDVMCALL_GET_QUOTE		= 0x10002,
> +	TDVMCALL_REPORT_FATAL_ERROR	= 0x10003,
> +	TDVMCALL_CPUID			= EXIT_REASON_CPUID,
> +	TDVMCALL_HLT			= EXIT_REASON_HLT,
> +	TDVMCALL_IO			= EXIT_REASON_IO_INSTRUCTION,
> +	TDVMCALL_RDMSR			= EXIT_REASON_MSR_READ,
> +	TDVMCALL_WRMSR			= EXIT_REASON_MSR_WRITE,
> +	TDVMCALL_MMIO			= EXIT_REASON_EPT_VIOLATION,
> +};
> +
> +enum tdx_io_direction {
> +	TDX_READ,
> +	TDX_WRITE
> +};
> +
> +/* TDG.VP.VMCALL Sub-function Completion Status Codes */
> +enum tdvmcall_status {
> +	TDVMCALL_STATUS_SUCCESS		= 0x0000000000000000ULL,
> +	TDVMCALL_STATUS_RETRY		= 0x0000000000000001ULL,
> +	TDVMCALL_STATUS_INVALID_OPERAND	= 0x8000000000000000ULL,
> +	TDVMCALL_STATUS_ALIGN_ERROR	= 0x8000000000000002ULL,
> +};
> +
> +struct tdh_vp_enter_in {
> +	/* TDG.VP.VMCALL common */
> +	enum tdvmcall_status	ret_code;
> +
> +	/* TDG.VP.VMCALL Sub-function return information */
> +
> +	/* TDVMCALL_GET_TD_VM_CALL_INFO */
> +	u64			gettdvmcallinfo[4];
> +
> +	/* TDVMCALL_MAP_GPA */
> +	gpa_t			failed_gpa;
> +
> +	/* TDVMCALL_CPUID */
> +	u32			eax;
> +	u32			ebx;
> +	u32			ecx;
> +	u32			edx;
> +
> +	/* TDVMCALL_IO (read), TDVMCALL_RDMSR or TDVMCALL_MMIO (read) */
> +	u64			value_read;
> +};
> +
> +#define TDX_ERR_DATA_SZ 8
> +
> +struct tdh_vp_enter_out {
> +	u64			exit_qual;
> +	u32			intr_info;
> +	u64			ext_exit_qual;
> +	gpa_t			gpa;
> +
> +	/* TDG.VP.VMCALL common */
> +	u32			reg_mask;
> +	u64			fn;		/* Non-zero for KVM hypercalls, zero otherwise */
> +	enum tdvmcall_subfn	subfn;
> +
> +	/* TDG.VP.VMCALL Sub-function arguments */
> +
> +	/* KVM hypercall */
> +	u64			nr;
> +	u64			p1;
> +	u64			p2;
> +	u64			p3;
> +	u64			p4;
> +
> +	/* TDVMCALL_GET_TD_VM_CALL_INFO */
> +	u64			leaf;
> +
> +	/* TDVMCALL_MAP_GPA */
> +	gpa_t			map_gpa;
> +	u64			map_gpa_size;
> +
> +	/* TDVMCALL_GET_QUOTE */
> +	gpa_t			shared_gpa;
> +	u64			shared_gpa_size;
> +
> +	/* TDVMCALL_REPORT_FATAL_ERROR */
> +	u64			err_codes;
> +	gpa_t			err_data_gpa;
> +	u64			err_data[TDX_ERR_DATA_SZ];
> +
> +	/* TDVMCALL_CPUID */
> +	u32			cpuid_leaf;
> +	u32			cpuid_subleaf;
> +
> +	/* TDVMCALL_MMIO */
> +	int			mmio_size;
> +	enum tdx_io_direction	mmio_direction;
> +	gpa_t			mmio_addr;
> +	u64			mmio_value;
> +
> +	/* TDVMCALL_HLT */
> +	bool			intr_blocked_flag;
> +
> +	/* TDVMCALL_IO */
> +	int			io_size;
> +	enum tdx_io_direction	io_direction;
> +	u16			io_port;
> +	u32			io_value;
> +
> +	/* TDVMCALL_RDMSR or TDVMCALL_WRMSR */
> +	u32			msr;
> +
> +	/* TDVMCALL_WRMSR */
> +	u64			write_value;
> +};
> +
>  /* SEAMCALL wrappers for creating/destroying/running TDX guests */
> -u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args);
> +u64 tdh_vp_enter(u64 tdvpr, const struct tdh_vp_enter_in *in, struct tdh_vp_enter_out *out);
>  u64 tdh_mng_addcx(u64 tdr, u64 tdcs);
>  u64 tdh_mem_page_add(u64 tdr, u64 gpa, u64 hpa, u64 source, u64 *rcx, u64 *rdx);
>  u64 tdh_mem_sept_add(u64 tdr, u64 gpa, u64 level, u64 hpa, u64 *rcx, u64 *rdx);
> diff --git a/arch/x86/kvm/vmx/tdx.c b/arch/x86/kvm/vmx/tdx.c
> index 218801618e9a..a8283a03fdd4 100644
> --- a/arch/x86/kvm/vmx/tdx.c
> +++ b/arch/x86/kvm/vmx/tdx.c
> @@ -256,57 +256,41 @@ static __always_inline bool tdx_check_exit_reason(struct kvm_vcpu *vcpu, u16 rea
>  
>  static __always_inline unsigned long tdexit_exit_qual(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_rcx_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.exit_qual;
>  }
>  
>  static __always_inline unsigned long tdexit_ext_exit_qual(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_rdx_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.ext_exit_qual;
>  }
>  
> -static __always_inline unsigned long tdexit_gpa(struct kvm_vcpu *vcpu)
> +static __always_inline gpa_t tdexit_gpa(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_r8_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.gpa;
>  }
>  
>  static __always_inline unsigned long tdexit_intr_info(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_r9_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.intr_info;
>  }
>  
> -#define BUILD_TDVMCALL_ACCESSORS(param, gpr)				\
> -static __always_inline							\
> -unsigned long tdvmcall_##param##_read(struct kvm_vcpu *vcpu)		\
> -{									\
> -	return kvm_##gpr##_read(vcpu);					\
> -}									\
> -static __always_inline void tdvmcall_##param##_write(struct kvm_vcpu *vcpu, \
> -						     unsigned long val)  \
> -{									\
> -	kvm_##gpr##_write(vcpu, val);					\
> -}
> -BUILD_TDVMCALL_ACCESSORS(a0, r12);
> -BUILD_TDVMCALL_ACCESSORS(a1, r13);
> -BUILD_TDVMCALL_ACCESSORS(a2, r14);
> -BUILD_TDVMCALL_ACCESSORS(a3, r15);
> -
> -static __always_inline unsigned long tdvmcall_exit_type(struct kvm_vcpu *vcpu)
> +static __always_inline unsigned long tdvmcall_fn(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_r10_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.fn;
>  }
> -static __always_inline unsigned long tdvmcall_leaf(struct kvm_vcpu *vcpu)
> +static __always_inline enum tdvmcall_subfn tdvmcall_subfn(struct kvm_vcpu *vcpu)
>  {
> -	return kvm_r11_read(vcpu);
> +	return to_tdx(vcpu)->vp_enter_out.subfn;
>  }
>  static __always_inline void tdvmcall_set_return_code(struct kvm_vcpu *vcpu,
> -						     long val)
> +						     enum tdvmcall_status val)
>  {
> -	kvm_r10_write(vcpu, val);
> +	to_tdx(vcpu)->vp_enter_in.ret_code = val;
>  }
>  static __always_inline void tdvmcall_set_return_val(struct kvm_vcpu *vcpu,
>  						    unsigned long val)
>  {
> -	kvm_r11_write(vcpu, val);
> +	to_tdx(vcpu)->vp_enter_in.value_read = val;
>  }
>  
>  static inline void tdx_hkid_free(struct kvm_tdx *kvm_tdx)
> @@ -786,10 +770,10 @@ bool tdx_interrupt_allowed(struct kvm_vcpu *vcpu)
>  	 * passes the interrupt block flag.
>  	 */
>  	if (!tdx_check_exit_reason(vcpu, EXIT_REASON_TDCALL) ||
> -	    tdvmcall_exit_type(vcpu) || tdvmcall_leaf(vcpu) != EXIT_REASON_HLT)
> +	    tdvmcall_fn(vcpu) || tdvmcall_subfn(vcpu) != TDVMCALL_HLT)
>  	    return true;
>  
> -	return !tdvmcall_a0_read(vcpu);
> +	return !to_tdx(vcpu)->vp_enter_out.intr_blocked_flag;
>  }
>  
>  bool tdx_protected_apic_has_interrupt(struct kvm_vcpu *vcpu)
> @@ -945,51 +929,10 @@ static void tdx_restore_host_xsave_state(struct kvm_vcpu *vcpu)
>  static noinstr void tdx_vcpu_enter_exit(struct kvm_vcpu *vcpu)
>  {
>  	struct vcpu_tdx *tdx = to_tdx(vcpu);
> -	struct tdx_module_args args;
>  
>  	guest_state_enter_irqoff();
>  
> -	/*
> -	 * TODO: optimization:
> -	 * - Eliminate copy between args and vcpu->arch.regs.
> -	 * - copyin/copyout registers only if (tdx->tdvmvall.regs_mask != 0)
> -	 *   which means TDG.VP.VMCALL.
> -	 */
> -	args = (struct tdx_module_args) {
> -		.rcx = tdx->tdvpr_pa,
> -#define REG(reg, REG)	.reg = vcpu->arch.regs[VCPU_REGS_ ## REG]
> -		REG(rdx, RDX),
> -		REG(r8,  R8),
> -		REG(r9,  R9),
> -		REG(r10, R10),
> -		REG(r11, R11),
> -		REG(r12, R12),
> -		REG(r13, R13),
> -		REG(r14, R14),
> -		REG(r15, R15),
> -		REG(rbx, RBX),
> -		REG(rdi, RDI),
> -		REG(rsi, RSI),
> -#undef REG
> -	};
> -
> -	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &args);
> -
> -#define REG(reg, REG)	vcpu->arch.regs[VCPU_REGS_ ## REG] = args.reg
> -	REG(rcx, RCX);
> -	REG(rdx, RDX);
> -	REG(r8,  R8);
> -	REG(r9,  R9);
> -	REG(r10, R10);
> -	REG(r11, R11);
> -	REG(r12, R12);
> -	REG(r13, R13);
> -	REG(r14, R14);
> -	REG(r15, R15);
> -	REG(rbx, RBX);
> -	REG(rdi, RDI);
> -	REG(rsi, RSI);
> -#undef REG
> +	tdx->vp_enter_ret = tdh_vp_enter(tdx->tdvpr_pa, &tdx->vp_enter_in, &tdx->vp_enter_out);
>  
>  	if (tdx_check_exit_reason(vcpu, EXIT_REASON_EXCEPTION_NMI) &&
>  	    is_nmi(tdexit_intr_info(vcpu)))
> @@ -1128,8 +1071,15 @@ static int complete_hypercall_exit(struct kvm_vcpu *vcpu)
>  
>  static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
>  {
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
>  	int r;
>  
> +	kvm_r10_write(vcpu, tdx->vp_enter_out.nr);
> +	kvm_r11_write(vcpu, tdx->vp_enter_out.p1);
> +	kvm_r12_write(vcpu, tdx->vp_enter_out.p2);
> +	kvm_r13_write(vcpu, tdx->vp_enter_out.p3);
> +	kvm_r14_write(vcpu, tdx->vp_enter_out.p4);
> +
>  	/*
>  	 * ABI for KVM tdvmcall argument:
>  	 * In Guest-Hypervisor Communication Interface(GHCI) specification,
> @@ -1137,13 +1087,12 @@ static int tdx_emulate_vmcall(struct kvm_vcpu *vcpu)
>  	 * vendor-specific.  KVM uses this for KVM hypercall.  NOTE: KVM
>  	 * hypercall number starts from one.  Zero isn't used for KVM hypercall
>  	 * number.
> -	 *
> -	 * R10: KVM hypercall number
> -	 * arguments: R11, R12, R13, R14.
>  	 */
>  	r = __kvm_emulate_hypercall(vcpu, r10, r11, r12, r13, r14, true, 0,
>  				    complete_hypercall_exit);
>  
> +	tdvmcall_set_return_code(vcpu, kvm_r10_read(vcpu));
> +
>  	return r > 0;
>  }
>  
> @@ -1161,7 +1110,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
>  
>  	if(vcpu->run->hypercall.ret) {
>  		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
> -		kvm_r11_write(vcpu, tdx->map_gpa_next);
> +		tdx->vp_enter_in.failed_gpa = tdx->map_gpa_next;
>  		return 1;
>  	}
>  
> @@ -1182,7 +1131,7 @@ static int tdx_complete_vmcall_map_gpa(struct kvm_vcpu *vcpu)
>  	if (pi_has_pending_interrupt(vcpu) ||
>  	    kvm_test_request(KVM_REQ_NMI, vcpu) || vcpu->arch.nmi_pending) {
>  		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_RETRY);
> -		kvm_r11_write(vcpu, tdx->map_gpa_next);
> +		tdx->vp_enter_in.failed_gpa = tdx->map_gpa_next;
>  		return 1;
>  	}
>  
> @@ -1214,8 +1163,8 @@ static void __tdx_map_gpa(struct vcpu_tdx * tdx)
>  static int tdx_map_gpa(struct kvm_vcpu *vcpu)
>  {
>  	struct vcpu_tdx * tdx = to_tdx(vcpu);
> -	u64 gpa = tdvmcall_a0_read(vcpu);
> -	u64 size = tdvmcall_a1_read(vcpu);
> +	u64 gpa  = tdx->vp_enter_out.map_gpa;
> +	u64 size = tdx->vp_enter_out.map_gpa_size;
>  	u64 ret;
>  
>  	/*
> @@ -1251,14 +1200,17 @@ static int tdx_map_gpa(struct kvm_vcpu *vcpu)
>  
>  error:
>  	tdvmcall_set_return_code(vcpu, ret);
> -	kvm_r11_write(vcpu, gpa);
> +	tdx->vp_enter_in.failed_gpa = gpa;
>  	return 1;
>  }
>  
>  static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
>  {
> -	u64 reg_mask = kvm_rcx_read(vcpu);
> -	u64* opt_regs;
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
> +	__u64 *data = &vcpu->run->system_event.data[0];
> +	u64 reg_mask = tdx->vp_enter_out.reg_mask;
> +	const int mask[] = {14, 15, 3, 7, 6, 8, 9, 2};
> +	int cnt = 0;
>  
>  	/*
>  	 * Skip sanity checks and let userspace decide what to do if sanity
> @@ -1266,32 +1218,20 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
>  	 */
>  	vcpu->run->exit_reason = KVM_EXIT_SYSTEM_EVENT;
>  	vcpu->run->system_event.type = KVM_SYSTEM_EVENT_TDX_FATAL;
> -	vcpu->run->system_event.ndata = 10;
>  	/* Error codes. */
> -	vcpu->run->system_event.data[0] = tdvmcall_a0_read(vcpu);
> +	data[cnt++] = tdx->vp_enter_out.err_codes;
>  	/* GPA of additional information page. */
> -	vcpu->run->system_event.data[1] = tdvmcall_a1_read(vcpu);
> +	data[cnt++] = tdx->vp_enter_out.err_data_gpa;
> +
>  	/* Information passed via registers (up to 64 bytes). */
> -	opt_regs = &vcpu->run->system_event.data[2];
> +	for (int i = 0; i < TDX_ERR_DATA_SZ; i++) {
> +		if (reg_mask & BIT_ULL(mask[i]))
> +			data[cnt++] = tdx->vp_enter_out.err_data[i];
> +		else
> +			data[cnt++] = 0;
> +	}
>  
> -#define COPY_REG(REG, MASK)						\
> -	do {								\
> -		if (reg_mask & MASK)					\
> -			*opt_regs = kvm_ ## REG ## _read(vcpu);		\
> -		else							\
> -			*opt_regs = 0;					\
> -		opt_regs++;						\
> -	} while (0)
> -
> -	/* The order is defined in GHCI. */
> -	COPY_REG(r14, BIT_ULL(14));
> -	COPY_REG(r15, BIT_ULL(15));
> -	COPY_REG(rbx, BIT_ULL(3));
> -	COPY_REG(rdi, BIT_ULL(7));
> -	COPY_REG(rsi, BIT_ULL(6));
> -	COPY_REG(r8, BIT_ULL(8));
> -	COPY_REG(r9, BIT_ULL(9));
> -	COPY_REG(rdx, BIT_ULL(2));
> +	vcpu->run->system_event.ndata = cnt;
>  
>  	/*
>  	 * Set the status code according to GHCI spec, although the vCPU may
> @@ -1305,18 +1245,18 @@ static int tdx_report_fatal_error(struct kvm_vcpu *vcpu)
>  
>  static int tdx_emulate_cpuid(struct kvm_vcpu *vcpu)
>  {
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
>  	u32 eax, ebx, ecx, edx;
>  
> -	/* EAX and ECX for cpuid is stored in R12 and R13. */
> -	eax = tdvmcall_a0_read(vcpu);
> -	ecx = tdvmcall_a1_read(vcpu);
> +	eax = tdx->vp_enter_out.cpuid_leaf;
> +	ecx = tdx->vp_enter_out.cpuid_subleaf;
>  
>  	kvm_cpuid(vcpu, &eax, &ebx, &ecx, &edx, false);
>  
> -	tdvmcall_a0_write(vcpu, eax);
> -	tdvmcall_a1_write(vcpu, ebx);
> -	tdvmcall_a2_write(vcpu, ecx);
> -	tdvmcall_a3_write(vcpu, edx);
> +	tdx->vp_enter_in.eax = eax;
> +	tdx->vp_enter_in.ebx = ebx;
> +	tdx->vp_enter_in.ecx = ecx;
> +	tdx->vp_enter_in.edx = edx;
>  
>  	tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
>  
> @@ -1356,6 +1296,7 @@ static int tdx_complete_pio_in(struct kvm_vcpu *vcpu)
>  static int tdx_emulate_io(struct kvm_vcpu *vcpu)
>  {
>  	struct x86_emulate_ctxt *ctxt = vcpu->arch.emulate_ctxt;
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
>  	unsigned long val = 0;
>  	unsigned int port;
>  	int size, ret;
> @@ -1363,9 +1304,9 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
>  
>  	++vcpu->stat.io_exits;
>  
> -	size = tdvmcall_a0_read(vcpu);
> -	write = tdvmcall_a1_read(vcpu);
> -	port = tdvmcall_a2_read(vcpu);
> +	size  = tdx->vp_enter_out.io_size;
> +	write = tdx->vp_enter_out.io_direction == TDX_WRITE;
> +	port  = tdx->vp_enter_out.io_port;
>  
>  	if (size != 1 && size != 2 && size != 4) {
>  		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
> @@ -1373,7 +1314,7 @@ static int tdx_emulate_io(struct kvm_vcpu *vcpu)
>  	}
>  
>  	if (write) {
> -		val = tdvmcall_a3_read(vcpu);
> +		val = tdx->vp_enter_out.io_value;
>  		ret = ctxt->ops->pio_out_emulated(ctxt, size, port, &val, 1);
>  	} else {
>  		ret = ctxt->ops->pio_in_emulated(ctxt, size, port, &val, 1);
> @@ -1443,14 +1384,15 @@ static inline int tdx_mmio_read(struct kvm_vcpu *vcpu, gpa_t gpa, int size)
>  
>  static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
>  {
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
>  	int size, write, r;
>  	unsigned long val;
>  	gpa_t gpa;
>  
> -	size = tdvmcall_a0_read(vcpu);
> -	write = tdvmcall_a1_read(vcpu);
> -	gpa = tdvmcall_a2_read(vcpu);
> -	val = write ? tdvmcall_a3_read(vcpu) : 0;
> +	size  = tdx->vp_enter_out.mmio_size;
> +	write = tdx->vp_enter_out.mmio_direction == TDX_WRITE;
> +	gpa   = tdx->vp_enter_out.mmio_addr;
> +	val = write ? tdx->vp_enter_out.mmio_value : 0;
>  
>  	if (size != 1 && size != 2 && size != 4 && size != 8)
>  		goto error;
> @@ -1502,7 +1444,7 @@ static int tdx_emulate_mmio(struct kvm_vcpu *vcpu)
>  
>  static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
>  {
> -	u32 index = tdvmcall_a0_read(vcpu);
> +	u32 index = to_tdx(vcpu)->vp_enter_out.msr;
>  	u64 data;
>  
>  	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_READ) ||
> @@ -1520,8 +1462,8 @@ static int tdx_emulate_rdmsr(struct kvm_vcpu *vcpu)
>  
>  static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
>  {
> -	u32 index = tdvmcall_a0_read(vcpu);
> -	u64 data = tdvmcall_a1_read(vcpu);
> +	u32 index = to_tdx(vcpu)->vp_enter_out.msr;
> +	u64 data  = to_tdx(vcpu)->vp_enter_out.write_value;
>  
>  	if (!kvm_msr_allowed(vcpu, index, KVM_MSR_FILTER_WRITE) ||
>  	    kvm_set_msr(vcpu, index, data)) {
> @@ -1537,39 +1479,41 @@ static int tdx_emulate_wrmsr(struct kvm_vcpu *vcpu)
>  
>  static int tdx_get_td_vm_call_info(struct kvm_vcpu *vcpu)
>  {
> -	if (tdvmcall_a0_read(vcpu))
> +	struct vcpu_tdx *tdx = to_tdx(vcpu);
> +
> +	if (tdx->vp_enter_out.leaf) {
>  		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_INVALID_OPERAND);
> -	else {
> +	} else {
>  		tdvmcall_set_return_code(vcpu, TDVMCALL_STATUS_SUCCESS);
> -		kvm_r11_write(vcpu, 0);
> -		tdvmcall_a0_write(vcpu, 0);
> -		tdvmcall_a1_write(vcpu, 0);
> -		tdvmcall_a2_write(vcpu, 0);
> +		tdx->vp_enter_in.gettdvmcallinfo[0] = 0;
> +		tdx->vp_enter_in.gettdvmcallinfo[1] = 0;
> +		tdx->vp_enter_in.gettdvmcallinfo[2] = 0;
> +		tdx->vp_enter_in.gettdvmcallinfo[3] = 0;
>  	}
>  	return 1;
>  }
>  
>  static int handle_tdvmcall(struct kvm_vcpu *vcpu)
>  {
> -	if (tdvmcall_exit_type(vcpu))
> +	if (tdvmcall_fn(vcpu))
>  		return tdx_emulate_vmcall(vcpu);
>  
> -	switch (tdvmcall_leaf(vcpu)) {
> +	switch (tdvmcall_subfn(vcpu)) {
>  	case TDVMCALL_MAP_GPA:
>  		return tdx_map_gpa(vcpu);
>  	case TDVMCALL_REPORT_FATAL_ERROR:
>  		return tdx_report_fatal_error(vcpu);
> -	case EXIT_REASON_CPUID:
> +	case TDVMCALL_CPUID:
>  		return tdx_emulate_cpuid(vcpu);
> -	case EXIT_REASON_HLT:
> +	case TDVMCALL_HLT:
>  		return tdx_emulate_hlt(vcpu);
> -	case EXIT_REASON_IO_INSTRUCTION:
> +	case TDVMCALL_IO:
>  		return tdx_emulate_io(vcpu);
> -	case EXIT_REASON_EPT_VIOLATION:
> +	case TDVMCALL_MMIO:
>  		return tdx_emulate_mmio(vcpu);
> -	case EXIT_REASON_MSR_READ:
> +	case TDVMCALL_RDMSR:
>  		return tdx_emulate_rdmsr(vcpu);
> -	case EXIT_REASON_MSR_WRITE:
> +	case TDVMCALL_WRMSR:
>  		return tdx_emulate_wrmsr(vcpu);
>  	case TDVMCALL_GET_TD_VM_CALL_INFO:
>  		return tdx_get_td_vm_call_info(vcpu);
> diff --git a/arch/x86/kvm/vmx/tdx.h b/arch/x86/kvm/vmx/tdx.h
> index 008180c0c30f..63d8b3359b10 100644
> --- a/arch/x86/kvm/vmx/tdx.h
> +++ b/arch/x86/kvm/vmx/tdx.h
> @@ -69,6 +69,8 @@ struct vcpu_tdx {
>  	struct list_head cpu_list;
>  
>  	u64 vp_enter_ret;
> +	struct tdh_vp_enter_in vp_enter_in;
> +	struct tdh_vp_enter_out vp_enter_out;
>  
>  	enum vcpu_tdx_state state;
>  
> diff --git a/arch/x86/virt/vmx/tdx/tdx.c b/arch/x86/virt/vmx/tdx/tdx.c
> index 16e0b598c4ec..895d9ea4aeba 100644
> --- a/arch/x86/virt/vmx/tdx/tdx.c
> +++ b/arch/x86/virt/vmx/tdx/tdx.c
> @@ -33,6 +33,7 @@
>  #include <asm/msr-index.h>
>  #include <asm/msr.h>
>  #include <asm/cpufeature.h>
> +#include <asm/vmx.h>
>  #include <asm/tdx.h>
>  #include <asm/cpu_device_id.h>
>  #include <asm/processor.h>
> @@ -1600,11 +1601,122 @@ static inline u64 tdx_seamcall_sept(u64 op, struct tdx_module_args *in)
>  	return ret;
>  }
>  
> -noinstr u64 tdh_vp_enter(u64 tdvpr, struct tdx_module_args *args)
> +noinstr u64 tdh_vp_enter(u64 tdvpr, const struct tdh_vp_enter_in *in, struct tdh_vp_enter_out *out)
>  {
> -	args->rcx = tdvpr;
> +	struct tdx_module_args args = {
> +		.rcx = tdvpr,
> +		.r10 = in->ret_code,
> +	};
> +	u64 ret;
>  
> -	return __seamcall_saved_ret(TDH_VP_ENTER, args);
> +	/* If previous exit was TDG.VP.VMCALL */
> +	switch (out->subfn) {
> +	case TDVMCALL_GET_TD_VM_CALL_INFO:
> +		args.r11 = in->gettdvmcallinfo[0];
> +		args.r12 = in->gettdvmcallinfo[1];
> +		args.r13 = in->gettdvmcallinfo[2];
> +		args.r14 = in->gettdvmcallinfo[3];
> +		break;
> +	case TDVMCALL_MAP_GPA:
> +		args.r11 = in->failed_gpa;
> +		break;
> +	case TDVMCALL_CPUID:
> +		args.r12 = in->eax;
> +		args.r13 = in->ebx;
> +		args.r14 = in->ecx;
> +		args.r15 = in->edx;
> +		break;
> +	case TDVMCALL_IO:
> +	case TDVMCALL_RDMSR:
> +	case TDVMCALL_MMIO:
> +		args.r11 = in->value_read;
> +		break;
> +	case TDVMCALL_NONE:
> +	case TDVMCALL_GET_QUOTE:
> +	case TDVMCALL_REPORT_FATAL_ERROR:
> +	case TDVMCALL_HLT:
> +	case TDVMCALL_WRMSR:
> +		break;
> +	}
> +
> +	ret = __seamcall_saved_ret(TDH_VP_ENTER, &args);
> +
> +	if ((u16)ret == EXIT_REASON_TDCALL) {
> +		out->reg_mask		= args.rcx;
> +		out->fn = args.r10;
> +		if (out->fn) {
> +			out->nr		= args.r10;
> +			out->p1		= args.r11;
> +			out->p2		= args.r12;
> +			out->p3		= args.r13;
> +			out->p4		= args.r14;
> +			out->subfn	= TDVMCALL_NONE;
> +		} else {
> +			out->subfn	= args.r11;
> +		}
> +	} else {
> +		out->exit_qual		= args.rcx;
> +		out->ext_exit_qual	= args.rdx;
> +		out->gpa		= args.r8;
> +		out->intr_info		= args.r9;
> +		out->subfn		= TDVMCALL_NONE;
> +	}
> +
> +	switch (out->subfn) {
> +	case TDVMCALL_GET_TD_VM_CALL_INFO:
> +		out->leaf		= args.r12;
> +		break;
> +	case TDVMCALL_MAP_GPA:
> +		out->map_gpa		= args.r12;
> +		out->map_gpa_size	= args.r13;
> +		break;
> +	case TDVMCALL_CPUID:
> +		out->cpuid_leaf		= args.r12;
> +		out->cpuid_subleaf	= args.r13;
> +		break;
> +	case TDVMCALL_IO:
> +		out->io_size		= args.r12;
> +		out->io_direction	= args.r13 ? TDX_WRITE : TDX_READ;
> +		out->io_port		= args.r14;
> +		out->io_value		= args.r15;
> +		break;
> +	case TDVMCALL_RDMSR:
> +		out->msr		= args.r12;
> +		break;
> +	case TDVMCALL_MMIO:
> +		out->mmio_size		= args.r12;
> +		out->mmio_direction	= args.r13 ? TDX_WRITE : TDX_READ;
> +		out->mmio_addr		= args.r14;
> +		out->mmio_value		= args.r15;
> +		break;
> +	case TDVMCALL_NONE:
> +		break;
> +	case TDVMCALL_GET_QUOTE:
> +		out->shared_gpa		= args.r12;
> +		out->shared_gpa_size	= args.r13;
> +		break;
> +	case TDVMCALL_REPORT_FATAL_ERROR:
> +		out->err_codes		= args.r12;
> +		out->err_data_gpa	= args.r13;
> +		out->err_data[0]	= args.r14;
> +		out->err_data[1]	= args.r15;
> +		out->err_data[2]	= args.rbx;
> +		out->err_data[3]	= args.rdi;
> +		out->err_data[4]	= args.rsi;
> +		out->err_data[5]	= args.r8;
> +		out->err_data[6]	= args.r9;
> +		out->err_data[7]	= args.rdx;
> +		break;
> +	case TDVMCALL_HLT:
> +		out->intr_blocked_flag	= args.r12;
> +		break;
> +	case TDVMCALL_WRMSR:
> +		out->msr		= args.r12;
> +		out->write_value	= args.r13;
> +		break;
> +	}
> +
> +	return ret;
>  }
>  EXPORT_SYMBOL_GPL(tdh_vp_enter);
>
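The register-mask packing in the reworked tdx_report_fatal_error() can be checked in isolation. This sketch mirrors that loop. It writes the error codes and the info-page GPA, then the eight optional registers in the GHCI order (R14, R15, RBX, RDI, RSI, R8, R9, RDX). Any register whose bit is clear in the mask is reported as 0. `pack_fatal_error()` is a hypothetical helper, and `BIT_ULL` is redefined locally so the example is self-contained.

```c
#include <stdint.h>

typedef uint64_t u64;

#define TDX_ERR_DATA_SZ	8
#define BIT_ULL(n)	(1ULL << (n))

/* GPR numbers in the GHCI order used by the diff: R14, R15, RBX, RDI, RSI, R8, R9, RDX */
static const int err_reg_bits[TDX_ERR_DATA_SZ] = { 14, 15, 3, 7, 6, 8, 9, 2 };

/*
 * Pack error codes, the info-page GPA and the optional registers into data[].
 * Registers not set in reg_mask are reported as 0, so the count written is
 * always 2 + TDX_ERR_DATA_SZ.
 */
static int pack_fatal_error(u64 reg_mask, u64 err_codes, u64 err_data_gpa,
			    const u64 err_data[TDX_ERR_DATA_SZ], u64 *data)
{
	int cnt = 0;

	data[cnt++] = err_codes;
	data[cnt++] = err_data_gpa;
	for (int i = 0; i < TDX_ERR_DATA_SZ; i++)
		data[cnt++] = (reg_mask & BIT_ULL(err_reg_bits[i])) ? err_data[i] : 0;

	return cnt;
}
```

Every slot is written whether or not its mask bit is set, so `cnt` always ends at 10. The computed `ndata` therefore matches the hard-coded `ndata = 10` that the diff removes.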