Track the kvm pointer and its refcount in the viommu core. The kvm pointer
will later be used to support the TSM Bind feature, which informs the secure
firmware of the association between a vPCI device and a CoCo VM.
There is an existing need to reference the kvm pointer in viommu [1], but in
that series the kvm pointer is used and tracked in the platform IOMMU drivers.
In the Confidential Computing (CC) case, however, the viommu core should provide
a generic routine for TSM Bind, i.e. call pci_tsm_bind(pdev, kvm, tdi_id), so
it is better for the viommu core to keep and track the kvm pointer itself.
[1] https://lore.kernel.org/all/20250319173202.78988-5-shameerali.kolothum.thodi@huawei.com/
Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
Signed-off-by: Xu Yilun <yilun.xu@linux.intel.com>
---
drivers/iommu/iommufd/viommu.c | 62 ++++++++++++++++++++++++++++++++++
include/linux/iommufd.h | 3 ++
2 files changed, 65 insertions(+)
diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
index 488905989b7c..2fcef3f8d1a5 100644
--- a/drivers/iommu/iommufd/viommu.c
+++ b/drivers/iommu/iommufd/viommu.c
@@ -1,8 +1,68 @@
// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
*/
+#if IS_ENABLED(CONFIG_KVM)
+#include <linux/kvm_host.h>
+#endif
+
#include "iommufd_private.h"
+#if IS_ENABLED(CONFIG_KVM)
+/*
+ * viommu_get_kvm_safe - record @kvm in @viommu while holding a reference on it
+ * @viommu: the vIOMMU object that will remember the VM
+ * @kvm: the VM to track; NULL means there is nothing to track
+ *
+ * On success viommu->kvm and viommu->put_kvm are set, and the
+ * kvm_put_kvm symbol stays pinned until viommu_put_kvm() releases it.
+ * On any failure @viommu is left untouched and every symbol taken here
+ * is released again.
+ */
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+	void (*release_fn)(struct kvm *kvm);
+	bool (*acquire_fn)(struct kvm *kvm);
+	bool referenced;
+
+	if (!kvm)
+		return;
+
+	/*
+	 * Resolve the release helper first: once a reference is taken it
+	 * must remain callable until viommu_put_kvm() drops that reference.
+	 */
+	release_fn = symbol_get(kvm_put_kvm);
+	if (WARN_ON(!release_fn))
+		return;
+
+	acquire_fn = symbol_get(kvm_get_kvm_safe);
+	if (WARN_ON(!acquire_fn))
+		goto err_release_symbol;
+
+	/* The acquire helper is only needed for this one call. */
+	referenced = acquire_fn(kvm);
+	symbol_put(kvm_get_kvm_safe);
+	if (!referenced)
+		goto err_release_symbol;
+
+	viommu->put_kvm = release_fn;
+	viommu->kvm = kvm;
+	return;
+
+err_release_symbol:
+	symbol_put(kvm_put_kvm);
+}
+
+/*
+ * viommu_put_kvm - drop the VM reference recorded by viommu_get_kvm_safe()
+ * @viommu: the vIOMMU object whose tracked VM should be released
+ *
+ * Does nothing when no VM is tracked. viommu->kvm is always cleared once
+ * a VM was tracked, even if the release hook is unexpectedly missing.
+ */
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+	void (*release_fn)(struct kvm *kvm) = viommu->put_kvm;
+
+	if (!viommu->kvm)
+		return;
+
+	/*
+	 * A tracked VM without a release hook means the bookkeeping is
+	 * inconsistent: warn, skip the release, but still forget the VM.
+	 */
+	if (!WARN_ON(!release_fn)) {
+		release_fn(viommu->kvm);
+		viommu->put_kvm = NULL;
+		/* Balances the symbol_get(kvm_put_kvm) done when the VM was recorded. */
+		symbol_put(kvm_put_kvm);
+	}
+
+	viommu->kvm = NULL;
+}
+#else
+/* CONFIG_KVM is disabled: no VM reference is ever taken. */
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+}
+
+/* CONFIG_KVM is disabled: the stub above records nothing, so nothing to drop. */
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+}
+#endif
+
void iommufd_viommu_destroy(struct iommufd_object *obj)
{
struct iommufd_viommu *viommu =
@@ -10,6 +70,7 @@ void iommufd_viommu_destroy(struct iommufd_object *obj)
if (viommu->ops && viommu->ops->destroy)
viommu->ops->destroy(viommu);
+ viommu_put_kvm(viommu);
refcount_dec(&viommu->hwpt->common.obj.users);
xa_destroy(&viommu->vdevs);
}
@@ -68,6 +129,7 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
* on its own.
*/
viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev);
+ viommu_get_kvm_safe(viommu, idev->kvm);
cmd->out_viommu_id = viommu->obj.id;
rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
diff --git a/include/linux/iommufd.h b/include/linux/iommufd.h
index 2b2d6095309c..2712421802b9 100644
--- a/include/linux/iommufd.h
+++ b/include/linux/iommufd.h
@@ -104,6 +104,9 @@ struct iommufd_viommu {
struct rw_semaphore veventqs_rwsem;
unsigned int type;
+
+ struct kvm *kvm;
+ void (*put_kvm)(struct kvm *kvm);
};
/**
--
2.25.1