[PATCH v6 22/26] device/dax: Properly refcount device dax pages when mapping

Alistair Popple posted 26 patches 1 year, 1 month ago
There is a newer version of this series
[PATCH v6 22/26] device/dax: Properly refcount device dax pages when mapping
Posted by Alistair Popple 1 year, 1 month ago
Device DAX pages are currently not reference counted when mapped,
instead relying on the devmap PTE bit to ensure mapping code will not
get/put references. This requires special handling in various page
table walkers, particularly GUP, to manage references on the
underlying pgmap to ensure the pages remain valid.

However there is no reason these pages can't be refcounted properly at
map time. Doing so eliminates the need for the devmap PTE bit,
freeing up a precious PTE bit. It also simplifies GUP as it no longer
needs to manage the special pgmap references and can instead just
treat the pages normally as defined by vm_normal_page().

Signed-off-by: Alistair Popple <apopple@nvidia.com>
---
 drivers/dax/device.c | 15 +++++++++------
 mm/memremap.c        | 13 ++++++-------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/drivers/dax/device.c b/drivers/dax/device.c
index 6d74e62..fd22dbf 100644
--- a/drivers/dax/device.c
+++ b/drivers/dax/device.c
@@ -126,11 +126,12 @@ static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+	pfn = phys_to_pfn_t(phys, 0);
 
 	dax_set_mapping(vmf, pfn, fault_size);
 
-	return vmf_insert_mixed(vmf->vma, vmf->address, pfn);
+	return vmf_insert_page_mkwrite(vmf, pfn_t_to_page(pfn),
+					vmf->flags & FAULT_FLAG_WRITE);
 }
 
 static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
@@ -169,11 +170,12 @@ static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+	pfn = phys_to_pfn_t(phys, 0);
 
 	dax_set_mapping(vmf, pfn, fault_size);
 
-	return vmf_insert_pfn_pmd(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
+	return vmf_insert_folio_pmd(vmf, page_folio(pfn_t_to_page(pfn)),
+				vmf->flags & FAULT_FLAG_WRITE);
 }
 
 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
@@ -214,11 +216,12 @@ static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+	pfn = phys_to_pfn_t(phys, 0);
 
 	dax_set_mapping(vmf, pfn, fault_size);
 
-	return vmf_insert_pfn_pud(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
+	return vmf_insert_folio_pud(vmf, page_folio(pfn_t_to_page(pfn)),
+				vmf->flags & FAULT_FLAG_WRITE);
 }
 #else
 static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
diff --git a/mm/memremap.c b/mm/memremap.c
index 9a8879b..532a52a 100644
--- a/mm/memremap.c
+++ b/mm/memremap.c
@@ -460,11 +460,10 @@ void free_zone_device_folio(struct folio *folio)
 {
 	struct dev_pagemap *pgmap = folio->pgmap;
 
-	if (WARN_ON_ONCE(!pgmap->ops))
-		return;
-
-	if (WARN_ON_ONCE(pgmap->type != MEMORY_DEVICE_FS_DAX &&
-			 !pgmap->ops->page_free))
+	if (WARN_ON_ONCE((!pgmap->ops &&
+			  pgmap->type != MEMORY_DEVICE_GENERIC) ||
+			 (pgmap->ops && !pgmap->ops->page_free &&
+			  pgmap->type != MEMORY_DEVICE_FS_DAX)))
 		return;
 
 	mem_cgroup_uncharge(folio);
@@ -494,7 +493,8 @@ void free_zone_device_folio(struct folio *folio)
 	 * zero which indicating the page has been removed from the file
 	 * system mapping.
 	 */
-	if (pgmap->type != MEMORY_DEVICE_FS_DAX)
+	if (pgmap->type != MEMORY_DEVICE_FS_DAX &&
+	    pgmap->type != MEMORY_DEVICE_GENERIC)
 		folio->mapping = NULL;
 
 	switch (pgmap->type) {
@@ -509,7 +509,6 @@ void free_zone_device_folio(struct folio *folio)
 		 * Reset the refcount to 1 to prepare for handing out the page
 		 * again.
 		 */
-		pgmap->ops->page_free(folio_page(folio, 0));
 		folio_set_count(folio, 1);
 		break;
 
-- 
git-series 0.9.1
Re: [PATCH v6 22/26] device/dax: Properly refcount device dax pages when mapping
Posted by Dan Williams 1 year ago
Alistair Popple wrote:
> Device DAX pages are currently not reference counted when mapped,
> instead relying on the devmap PTE bit to ensure mapping code will not
> get/put references. This requires special handling in various page
> table walkers, particularly GUP, to manage references on the
> underlying pgmap to ensure the pages remain valid.
> 
> However there is no reason these pages can't be refcounted properly at
> map time. Doing so eliminates the need for the devmap PTE bit,
> freeing up a precious PTE bit. It also simplifies GUP as it no longer
> needs to manage the special pgmap references and can instead just
> treat the pages normally as defined by vm_normal_page().
> 
> Signed-off-by: Alistair Popple <apopple@nvidia.com>
> ---
>  drivers/dax/device.c | 15 +++++++++------
>  mm/memremap.c        | 13 ++++++-------
>  2 files changed, 15 insertions(+), 13 deletions(-)
> 
> diff --git a/drivers/dax/device.c b/drivers/dax/device.c
> index 6d74e62..fd22dbf 100644
> --- a/drivers/dax/device.c
> +++ b/drivers/dax/device.c
> @@ -126,11 +126,12 @@ static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
>  		return VM_FAULT_SIGBUS;
>  	}
>  
> -	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
> +	pfn = phys_to_pfn_t(phys, 0);
>  
>  	dax_set_mapping(vmf, pfn, fault_size);
>  
> -	return vmf_insert_mixed(vmf->vma, vmf->address, pfn);
> +	return vmf_insert_page_mkwrite(vmf, pfn_t_to_page(pfn),
> +					vmf->flags & FAULT_FLAG_WRITE);
>  }
>  
>  static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
> @@ -169,11 +170,12 @@ static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
>  		return VM_FAULT_SIGBUS;
>  	}
>  
> -	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
> +	pfn = phys_to_pfn_t(phys, 0);
>  
>  	dax_set_mapping(vmf, pfn, fault_size);
>  
> -	return vmf_insert_pfn_pmd(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
> +	return vmf_insert_folio_pmd(vmf, page_folio(pfn_t_to_page(pfn)),
> +				vmf->flags & FAULT_FLAG_WRITE);

This looks suspect without initializing the compound page metadata.

This might be getting compound pages by default with
CONFIG_ARCH_WANT_OPTIMIZE_DAX_VMEMMAP. The device-dax unit tests are ok
so far, but that is not super comforting until I can think about this a
bit more... but not tonight.

Might as well fix up device-dax refcounts in this series too, but I
won't ask you to do that, will send you something to include.
Re: [PATCH v6 22/26] device/dax: Properly refcount device dax pages when mapping
Posted by Alistair Popple 1 year ago
On Mon, Jan 13, 2025 at 10:12:41PM -0800, Dan Williams wrote:
> Alistair Popple wrote:
> > Device DAX pages are currently not reference counted when mapped,
> > instead relying on the devmap PTE bit to ensure mapping code will not
> > get/put references. This requires special handling in various page
> > table walkers, particularly GUP, to manage references on the
> > underlying pgmap to ensure the pages remain valid.
> > 
> > However there is no reason these pages can't be refcounted properly at
> > map time. Doing so eliminates the need for the devmap PTE bit,
> > freeing up a precious PTE bit. It also simplifies GUP as it no longer
> > needs to manage the special pgmap references and can instead just
> > treat the pages normally as defined by vm_normal_page().
> > 
> > Signed-off-by: Alistair Popple <apopple@nvidia.com>
> > ---
> >  drivers/dax/device.c | 15 +++++++++------
> >  mm/memremap.c        | 13 ++++++-------
> >  2 files changed, 15 insertions(+), 13 deletions(-)
> > 
> > diff --git a/drivers/dax/device.c b/drivers/dax/device.c
> > index 6d74e62..fd22dbf 100644
> > --- a/drivers/dax/device.c
> > +++ b/drivers/dax/device.c
> > @@ -126,11 +126,12 @@ static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
> >  		return VM_FAULT_SIGBUS;
> >  	}
> >  
> > -	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
> > +	pfn = phys_to_pfn_t(phys, 0);
> >  
> >  	dax_set_mapping(vmf, pfn, fault_size);
> >  
> > -	return vmf_insert_mixed(vmf->vma, vmf->address, pfn);
> > +	return vmf_insert_page_mkwrite(vmf, pfn_t_to_page(pfn),
> > +					vmf->flags & FAULT_FLAG_WRITE);
> >  }
> >  
> >  static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
> > @@ -169,11 +170,12 @@ static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
> >  		return VM_FAULT_SIGBUS;
> >  	}
> >  
> > -	pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
> > +	pfn = phys_to_pfn_t(phys, 0);
> >  
> >  	dax_set_mapping(vmf, pfn, fault_size);
> >  
> > -	return vmf_insert_pfn_pmd(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
> > +	return vmf_insert_folio_pmd(vmf, page_folio(pfn_t_to_page(pfn)),
> > +				vmf->flags & FAULT_FLAG_WRITE);
> 
> This looks suspect without initializing the compound page metadata.

I initially wondered about this too; however, I think the compound page metadata
should be initialised by memmap_init_zone_device(). That said, I kind of get lost
in all the namespace/CXL/PMEM/DAX drivers in the stack, so maybe I've overlooked
something.
 
> This might be getting compound pages by default with
> CONFIG_ARCH_WANT_OPTIMIZE_DAX_VMEMMAP. The device-dax unit tests are ok
> so far, but that is not super comforting until I can think about this a
> bit more... but not tonight.

From my reading of the code, I don't _think_
CONFIG_ARCH_WANT_OPTIMIZE_DAX_VMEMMAP would change whether or not we got
compound pages by default, just that, if we did, some of the (tail?) pages may
refer to the same physical struct page.

> Might as well fix up device-dax refcounts in this series too, but I
> won't ask you to do that, will send you something to include.

Eh. That should be relatively straightforward. But then I thought that about FS
DAX too :-)