[PATCH v2 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN

Shiyang Ruan posted 8 patches 3 years, 2 months ago
[PATCH v2 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN
Posted by Shiyang Ruan 3 years, 2 months ago
If srcmap contains invalid data, such as HOLE and UNWRITTEN, the dest
page should be zeroed.  Otherwise, since it's pmem, old data may remain
on the dest page and the result of CoW will be incorrect.

The function name is also not easy to understand, so rename it to
"dax_iomap_copy_around()", which means it copies data around the range.

Signed-off-by: Shiyang Ruan <ruansy.fnst@fujitsu.com>
---
 fs/dax.c | 78 ++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 48 insertions(+), 30 deletions(-)

diff --git a/fs/dax.c b/fs/dax.c
index 482dda85ccaf..6b6e07ad8d80 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -1092,7 +1092,7 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
 }
 
 /**
- * dax_iomap_cow_copy - Copy the data from source to destination before write
+ * dax_iomap_copy_around - Copy the data from source to destination before write
  * @pos:	address to do copy from.
  * @length:	size of copy operation.
  * @align_size:	aligned w.r.t align_size (either PMD_SIZE or PAGE_SIZE)
@@ -1101,35 +1101,50 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
  *
  * This can be called from two places. Either during DAX write fault (page
  * aligned), to copy the length size data to daddr. Or, while doing normal DAX
- * write operation, dax_iomap_actor() might call this to do the copy of either
+ * write operation, dax_iomap_iter() might call this to do the copy of either
  * start or end unaligned address. In the latter case the rest of the copy of
- * aligned ranges is taken care by dax_iomap_actor() itself.
+ * aligned ranges is taken care by dax_iomap_iter() itself.
+ * If the srcmap contains invalid data, such as HOLE and UNWRITTEN, zero the
+ * area to make sure no old data remains.
  */
-static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
+static int dax_iomap_copy_around(loff_t pos, uint64_t length, size_t align_size,
 		const struct iomap *srcmap, void *daddr)
 {
 	loff_t head_off = pos & (align_size - 1);
 	size_t size = ALIGN(head_off + length, align_size);
 	loff_t end = pos + length;
 	loff_t pg_end = round_up(end, align_size);
+	/* copy_all is usually in page fault case */
 	bool copy_all = head_off == 0 && end == pg_end;
+	/* zero the edges if srcmap is a HOLE or IOMAP_UNWRITTEN */
+	bool zero_edge = srcmap->flags & IOMAP_F_SHARED ||
+			 srcmap->type == IOMAP_UNWRITTEN;
 	void *saddr = 0;
 	int ret = 0;
 
-	ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
-	if (ret)
-		return ret;
+	if (!zero_edge) {
+		ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
+		if (ret)
+			return ret;
+	}
 
 	if (copy_all) {
-		ret = copy_mc_to_kernel(daddr, saddr, length);
-		return ret ? -EIO : 0;
+		if (zero_edge)
+			memset(daddr, 0, size);
+		else
+			ret = copy_mc_to_kernel(daddr, saddr, length);
+		goto out;
 	}
 
 	/* Copy the head part of the range */
 	if (head_off) {
-		ret = copy_mc_to_kernel(daddr, saddr, head_off);
-		if (ret)
-			return -EIO;
+		if (zero_edge)
+			memset(daddr, 0, head_off);
+		else {
+			ret = copy_mc_to_kernel(daddr, saddr, head_off);
+			if (ret)
+				return -EIO;
+		}
 	}
 
 	/* Copy the tail part of the range */
@@ -1137,12 +1152,19 @@ static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
 		loff_t tail_off = head_off + length;
 		loff_t tail_len = pg_end - end;
 
-		ret = copy_mc_to_kernel(daddr + tail_off, saddr + tail_off,
-					tail_len);
-		if (ret)
-			return -EIO;
+		if (zero_edge)
+			memset(daddr + tail_off, 0, tail_len);
+		else {
+			ret = copy_mc_to_kernel(daddr + tail_off,
+						saddr + tail_off, tail_len);
+			if (ret)
+				return -EIO;
+		}
 	}
-	return 0;
+out:
+	if (zero_edge)
+		dax_flush(srcmap->dax_dev, daddr, size);
+	return ret ? -EIO : 0;
 }
 
 /*
@@ -1241,13 +1263,10 @@ static int dax_memzero(struct iomap_iter *iter, loff_t pos, size_t size)
 	if (ret < 0)
 		return ret;
 	memset(kaddr + offset, 0, size);
-	if (srcmap->addr != iomap->addr) {
-		ret = dax_iomap_cow_copy(pos, size, PAGE_SIZE, srcmap,
-					 kaddr);
-		if (ret < 0)
-			return ret;
-		dax_flush(iomap->dax_dev, kaddr, PAGE_SIZE);
-	} else
+	if (iomap->flags & IOMAP_F_SHARED)
+		ret = dax_iomap_copy_around(pos, size, PAGE_SIZE, srcmap,
+					    kaddr);
+	else
 		dax_flush(iomap->dax_dev, kaddr + offset, size);
 	return ret;
 }
@@ -1401,8 +1420,8 @@ static loff_t dax_iomap_iter(const struct iomap_iter *iomi,
 		}
 
 		if (cow) {
-			ret = dax_iomap_cow_copy(pos, length, PAGE_SIZE, srcmap,
-						 kaddr);
+			ret = dax_iomap_copy_around(pos, length, PAGE_SIZE,
+						    srcmap, kaddr);
 			if (ret)
 				break;
 		}
@@ -1547,7 +1566,7 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
 		struct xa_state *xas, void **entry, bool pmd)
 {
 	const struct iomap *iomap = &iter->iomap;
-	const struct iomap *srcmap = &iter->srcmap;
+	const struct iomap *srcmap = iomap_iter_srcmap(iter);
 	size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
 	loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
 	bool write = iter->flags & IOMAP_WRITE;
@@ -1578,9 +1597,8 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
 
 	*entry = dax_insert_entry(xas, vmf, iter, *entry, pfn, entry_flags);
 
-	if (write &&
-	    srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
-		err = dax_iomap_cow_copy(pos, size, size, srcmap, kaddr);
+	if (write && iomap->flags & IOMAP_F_SHARED) {
+		err = dax_iomap_copy_around(pos, size, size, srcmap, kaddr);
 		if (err)
 			return dax_fault_return(err);
 	}
-- 
2.38.1
Re: [PATCH v2 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN
Posted by Darrick J. Wong 3 years, 2 months ago
On Thu, Dec 01, 2022 at 03:28:53PM +0000, Shiyang Ruan wrote:
> If srcmap contains invalid data, such as HOLE and UNWRITTEN, the dest
> page should be zeroed.  Otherwise, since it's a pmem, old data may
> remains on the dest page, the result of CoW will be incorrect.
> 
> The function name is also not easy to understand, rename it to
> "dax_iomap_copy_around()", which means it copys data around the range.
> 
> Signed-off-by: Shiyang Ruan <ruansy.fnst@fujitsu.com>
> ---
>  fs/dax.c | 78 ++++++++++++++++++++++++++++++++++----------------------
>  1 file changed, 48 insertions(+), 30 deletions(-)
> 
> diff --git a/fs/dax.c b/fs/dax.c
> index 482dda85ccaf..6b6e07ad8d80 100644
> --- a/fs/dax.c
> +++ b/fs/dax.c
> @@ -1092,7 +1092,7 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
>  }
>  
>  /**
> - * dax_iomap_cow_copy - Copy the data from source to destination before write
> + * dax_iomap_copy_around - Copy the data from source to destination before write

 * dax_iomap_copy_around - Prepare for an unaligned write to a
 * shared/cow page by copying the data before and after the range to be
 * written.

Other than that, this makes sense,
Reviewed-by: Darrick J. Wong <djwong@kernel.org>

--D

>   * @pos:	address to do copy from.
>   * @length:	size of copy operation.
>   * @align_size:	aligned w.r.t align_size (either PMD_SIZE or PAGE_SIZE)
> @@ -1101,35 +1101,50 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
>   *
>   * This can be called from two places. Either during DAX write fault (page
>   * aligned), to copy the length size data to daddr. Or, while doing normal DAX
> - * write operation, dax_iomap_actor() might call this to do the copy of either
> + * write operation, dax_iomap_iter() might call this to do the copy of either
>   * start or end unaligned address. In the latter case the rest of the copy of
> - * aligned ranges is taken care by dax_iomap_actor() itself.
> + * aligned ranges is taken care by dax_iomap_iter() itself.
> + * If the srcmap contains invalid data, such as HOLE and UNWRITTEN, zero the
> + * area to make sure no old data remains.
>   */
> -static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
> +static int dax_iomap_copy_around(loff_t pos, uint64_t length, size_t align_size,
>  		const struct iomap *srcmap, void *daddr)
>  {
>  	loff_t head_off = pos & (align_size - 1);
>  	size_t size = ALIGN(head_off + length, align_size);
>  	loff_t end = pos + length;
>  	loff_t pg_end = round_up(end, align_size);
> +	/* copy_all is usually in page fault case */
>  	bool copy_all = head_off == 0 && end == pg_end;
> +	/* zero the edges if srcmap is a HOLE or IOMAP_UNWRITTEN */
> +	bool zero_edge = srcmap->flags & IOMAP_F_SHARED ||
> +			 srcmap->type == IOMAP_UNWRITTEN;
>  	void *saddr = 0;
>  	int ret = 0;
>  
> -	ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
> -	if (ret)
> -		return ret;
> +	if (!zero_edge) {
> +		ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
> +		if (ret)
> +			return ret;
> +	}
>  
>  	if (copy_all) {
> -		ret = copy_mc_to_kernel(daddr, saddr, length);
> -		return ret ? -EIO : 0;
> +		if (zero_edge)
> +			memset(daddr, 0, size);
> +		else
> +			ret = copy_mc_to_kernel(daddr, saddr, length);
> +		goto out;
>  	}
>  
>  	/* Copy the head part of the range */
>  	if (head_off) {
> -		ret = copy_mc_to_kernel(daddr, saddr, head_off);
> -		if (ret)
> -			return -EIO;
> +		if (zero_edge)
> +			memset(daddr, 0, head_off);
> +		else {
> +			ret = copy_mc_to_kernel(daddr, saddr, head_off);
> +			if (ret)
> +				return -EIO;
> +		}
>  	}
>  
>  	/* Copy the tail part of the range */
> @@ -1137,12 +1152,19 @@ static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
>  		loff_t tail_off = head_off + length;
>  		loff_t tail_len = pg_end - end;
>  
> -		ret = copy_mc_to_kernel(daddr + tail_off, saddr + tail_off,
> -					tail_len);
> -		if (ret)
> -			return -EIO;
> +		if (zero_edge)
> +			memset(daddr + tail_off, 0, tail_len);
> +		else {
> +			ret = copy_mc_to_kernel(daddr + tail_off,
> +						saddr + tail_off, tail_len);
> +			if (ret)
> +				return -EIO;
> +		}
>  	}
> -	return 0;
> +out:
> +	if (zero_edge)
> +		dax_flush(srcmap->dax_dev, daddr, size);
> +	return ret ? -EIO : 0;
>  }
>  
>  /*
> @@ -1241,13 +1263,10 @@ static int dax_memzero(struct iomap_iter *iter, loff_t pos, size_t size)
>  	if (ret < 0)
>  		return ret;
>  	memset(kaddr + offset, 0, size);
> -	if (srcmap->addr != iomap->addr) {
> -		ret = dax_iomap_cow_copy(pos, size, PAGE_SIZE, srcmap,
> -					 kaddr);
> -		if (ret < 0)
> -			return ret;
> -		dax_flush(iomap->dax_dev, kaddr, PAGE_SIZE);
> -	} else
> +	if (iomap->flags & IOMAP_F_SHARED)
> +		ret = dax_iomap_copy_around(pos, size, PAGE_SIZE, srcmap,
> +					    kaddr);
> +	else
>  		dax_flush(iomap->dax_dev, kaddr + offset, size);
>  	return ret;
>  }
> @@ -1401,8 +1420,8 @@ static loff_t dax_iomap_iter(const struct iomap_iter *iomi,
>  		}
>  
>  		if (cow) {
> -			ret = dax_iomap_cow_copy(pos, length, PAGE_SIZE, srcmap,
> -						 kaddr);
> +			ret = dax_iomap_copy_around(pos, length, PAGE_SIZE,
> +						    srcmap, kaddr);
>  			if (ret)
>  				break;
>  		}
> @@ -1547,7 +1566,7 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
>  		struct xa_state *xas, void **entry, bool pmd)
>  {
>  	const struct iomap *iomap = &iter->iomap;
> -	const struct iomap *srcmap = &iter->srcmap;
> +	const struct iomap *srcmap = iomap_iter_srcmap(iter);
>  	size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
>  	loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
>  	bool write = iter->flags & IOMAP_WRITE;
> @@ -1578,9 +1597,8 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
>  
>  	*entry = dax_insert_entry(xas, vmf, iter, *entry, pfn, entry_flags);
>  
> -	if (write &&
> -	    srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
> -		err = dax_iomap_cow_copy(pos, size, size, srcmap, kaddr);
> +	if (write && iomap->flags & IOMAP_F_SHARED) {
> +		err = dax_iomap_copy_around(pos, size, size, srcmap, kaddr);
>  		if (err)
>  			return dax_fault_return(err);
>  	}
> -- 
> 2.38.1
>
Re: [PATCH v2 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN
Posted by Andrew Morton 3 years, 2 months ago
On Thu, 1 Dec 2022 15:58:11 -0800 "Darrick J. Wong" <djwong@kernel.org> wrote:

> > --- a/fs/dax.c
> > +++ b/fs/dax.c
> > @@ -1092,7 +1092,7 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
> >  }
> >  
> >  /**
> > - * dax_iomap_cow_copy - Copy the data from source to destination before write
> > + * dax_iomap_copy_around - Copy the data from source to destination before write
> 
>  * dax_iomap_copy_around - Prepare for an unaligned write to a
>  * shared/cow page by copying the data before and after the range to be
>  * written.

Thanks, I added this:

--- a/fs/dax.c~fsdax-zero-the-edges-if-source-is-hole-or-unwritten-fix
+++ a/fs/dax.c
@@ -1092,7 +1092,8 @@ out:
 }
 
 /**
- * dax_iomap_copy_around - Copy the data from source to destination before write
+ * dax_iomap_copy_around - Prepare for an unaligned write to a shared/cow page
+ * by copying the data before and after the range to be written.
  * @pos:	address to do copy from.
  * @length:	size of copy operation.
  * @align_size:	aligned w.r.t align_size (either PMD_SIZE or PAGE_SIZE)
_
[PATCH v2.1 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN
Posted by Shiyang Ruan 3 years, 2 months ago
If srcmap contains invalid data, such as HOLE and UNWRITTEN, the dest
page should be zeroed.  Otherwise, since it's pmem, old data may remain
on the dest page and the result of CoW will be incorrect.

The function name is also not easy to understand, so rename it to
"dax_iomap_copy_around()", which means it copies data around the range.

Signed-off-by: Shiyang Ruan <ruansy.fnst@fujitsu.com>
Reviewed-by: Darrick J. Wong <djwong@kernel.org>
---
 fs/dax.c | 79 +++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 49 insertions(+), 30 deletions(-)

diff --git a/fs/dax.c b/fs/dax.c
index a77739f2abe7..f12645d6f3c8 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -1092,7 +1092,8 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
 }
 
 /**
- * dax_iomap_cow_copy - Copy the data from source to destination before write
+ * dax_iomap_copy_around - Prepare for an unaligned write to a shared/cow page
+ * by copying the data before and after the range to be written.
  * @pos:	address to do copy from.
  * @length:	size of copy operation.
  * @align_size:	aligned w.r.t align_size (either PMD_SIZE or PAGE_SIZE)
@@ -1101,35 +1102,50 @@ static int dax_iomap_direct_access(const struct iomap *iomap, loff_t pos,
  *
  * This can be called from two places. Either during DAX write fault (page
  * aligned), to copy the length size data to daddr. Or, while doing normal DAX
- * write operation, dax_iomap_actor() might call this to do the copy of either
+ * write operation, dax_iomap_iter() might call this to do the copy of either
  * start or end unaligned address. In the latter case the rest of the copy of
- * aligned ranges is taken care by dax_iomap_actor() itself.
+ * aligned ranges is taken care by dax_iomap_iter() itself.
+ * If the srcmap contains invalid data, such as HOLE and UNWRITTEN, zero the
+ * area to make sure no old data remains.
  */
-static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
+static int dax_iomap_copy_around(loff_t pos, uint64_t length, size_t align_size,
 		const struct iomap *srcmap, void *daddr)
 {
 	loff_t head_off = pos & (align_size - 1);
 	size_t size = ALIGN(head_off + length, align_size);
 	loff_t end = pos + length;
 	loff_t pg_end = round_up(end, align_size);
+	/* copy_all is usually in page fault case */
 	bool copy_all = head_off == 0 && end == pg_end;
+	/* zero the edges if srcmap is a HOLE or IOMAP_UNWRITTEN */
+	bool zero_edge = srcmap->flags & IOMAP_F_SHARED ||
+			 srcmap->type == IOMAP_UNWRITTEN;
 	void *saddr = 0;
 	int ret = 0;
 
-	ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
-	if (ret)
-		return ret;
+	if (!zero_edge) {
+		ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
+		if (ret)
+			return ret;
+	}
 
 	if (copy_all) {
-		ret = copy_mc_to_kernel(daddr, saddr, length);
-		return ret ? -EIO : 0;
+		if (zero_edge)
+			memset(daddr, 0, size);
+		else
+			ret = copy_mc_to_kernel(daddr, saddr, length);
+		goto out;
 	}
 
 	/* Copy the head part of the range */
 	if (head_off) {
-		ret = copy_mc_to_kernel(daddr, saddr, head_off);
-		if (ret)
-			return -EIO;
+		if (zero_edge)
+			memset(daddr, 0, head_off);
+		else {
+			ret = copy_mc_to_kernel(daddr, saddr, head_off);
+			if (ret)
+				return -EIO;
+		}
 	}
 
 	/* Copy the tail part of the range */
@@ -1137,12 +1153,19 @@ static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
 		loff_t tail_off = head_off + length;
 		loff_t tail_len = pg_end - end;
 
-		ret = copy_mc_to_kernel(daddr + tail_off, saddr + tail_off,
-					tail_len);
-		if (ret)
-			return -EIO;
+		if (zero_edge)
+			memset(daddr + tail_off, 0, tail_len);
+		else {
+			ret = copy_mc_to_kernel(daddr + tail_off,
+						saddr + tail_off, tail_len);
+			if (ret)
+				return -EIO;
+		}
 	}
-	return 0;
+out:
+	if (zero_edge)
+		dax_flush(srcmap->dax_dev, daddr, size);
+	return ret ? -EIO : 0;
 }
 
 /*
@@ -1241,13 +1264,10 @@ static int dax_memzero(struct iomap_iter *iter, loff_t pos, size_t size)
 	if (ret < 0)
 		return ret;
 	memset(kaddr + offset, 0, size);
-	if (srcmap->addr != iomap->addr) {
-		ret = dax_iomap_cow_copy(pos, size, PAGE_SIZE, srcmap,
-					 kaddr);
-		if (ret < 0)
-			return ret;
-		dax_flush(iomap->dax_dev, kaddr, PAGE_SIZE);
-	} else
+	if (iomap->flags & IOMAP_F_SHARED)
+		ret = dax_iomap_copy_around(pos, size, PAGE_SIZE, srcmap,
+					    kaddr);
+	else
 		dax_flush(iomap->dax_dev, kaddr + offset, size);
 	return ret;
 }
@@ -1401,8 +1421,8 @@ static loff_t dax_iomap_iter(const struct iomap_iter *iomi,
 		}
 
 		if (cow) {
-			ret = dax_iomap_cow_copy(pos, length, PAGE_SIZE, srcmap,
-						 kaddr);
+			ret = dax_iomap_copy_around(pos, length, PAGE_SIZE,
+						    srcmap, kaddr);
 			if (ret)
 				break;
 		}
@@ -1547,7 +1567,7 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
 		struct xa_state *xas, void **entry, bool pmd)
 {
 	const struct iomap *iomap = &iter->iomap;
-	const struct iomap *srcmap = &iter->srcmap;
+	const struct iomap *srcmap = iomap_iter_srcmap(iter);
 	size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
 	loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
 	bool write = iter->flags & IOMAP_WRITE;
@@ -1578,9 +1598,8 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
 
 	*entry = dax_insert_entry(xas, vmf, iter, *entry, pfn, entry_flags);
 
-	if (write &&
-	    srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
-		err = dax_iomap_cow_copy(pos, size, size, srcmap, kaddr);
+	if (write && iomap->flags & IOMAP_F_SHARED) {
+		err = dax_iomap_copy_around(pos, size, size, srcmap, kaddr);
 		if (err)
 			return dax_fault_return(err);
 	}
-- 
2.38.1
Re: [PATCH v2.1 3/8] fsdax: zero the edges if source is HOLE or UNWRITTEN
Posted by Allison Henderson 3 years, 2 months ago
On Fri, 2022-12-02 at 09:25 +0000, Shiyang Ruan wrote:
> If srcmap contains invalid data, such as HOLE and UNWRITTEN, the dest
> page should be zeroed.  Otherwise, since it's a pmem, old data may
> remains on the dest page, the result of CoW will be incorrect.
> 
> The function name is also not easy to understand, rename it to
> "dax_iomap_copy_around()", which means it copys data around the
> range.
> 
> Signed-off-by: Shiyang Ruan <ruansy.fnst@fujitsu.com>
> Reviewed-by: Darrick J. Wong <djwong@kernel.org>
> 
I think the new changes look good
Reviewed-by: Allison Henderson <allison.henderson@oracle.com>

> ---
>  fs/dax.c | 79 +++++++++++++++++++++++++++++++++++-------------------
> --
>  1 file changed, 49 insertions(+), 30 deletions(-)
> 
> diff --git a/fs/dax.c b/fs/dax.c
> index a77739f2abe7..f12645d6f3c8 100644
> --- a/fs/dax.c
> +++ b/fs/dax.c
> @@ -1092,7 +1092,8 @@ static int dax_iomap_direct_access(const struct
> iomap *iomap, loff_t pos,
>  }
>  
>  /**
> - * dax_iomap_cow_copy - Copy the data from source to destination
> before write
> + * dax_iomap_copy_around - Prepare for an unaligned write to a
> shared/cow page
> + * by copying the data before and after the range to be written.
>   * @pos:       address to do copy from.
>   * @length:    size of copy operation.
>   * @align_size:        aligned w.r.t align_size (either PMD_SIZE or
> PAGE_SIZE)
> @@ -1101,35 +1102,50 @@ static int dax_iomap_direct_access(const
> struct iomap *iomap, loff_t pos,
>   *
>   * This can be called from two places. Either during DAX write fault
> (page
>   * aligned), to copy the length size data to daddr. Or, while doing
> normal DAX
> - * write operation, dax_iomap_actor() might call this to do the copy
> of either
> + * write operation, dax_iomap_iter() might call this to do the copy
> of either
>   * start or end unaligned address. In the latter case the rest of
> the copy of
> - * aligned ranges is taken care by dax_iomap_actor() itself.
> + * aligned ranges is taken care by dax_iomap_iter() itself.
> + * If the srcmap contains invalid data, such as HOLE and UNWRITTEN,
> zero the
> + * area to make sure no old data remains.
>   */
> -static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t
> align_size,
> +static int dax_iomap_copy_around(loff_t pos, uint64_t length, size_t
> align_size,
>                 const struct iomap *srcmap, void *daddr)
>  {
>         loff_t head_off = pos & (align_size - 1);
>         size_t size = ALIGN(head_off + length, align_size);
>         loff_t end = pos + length;
>         loff_t pg_end = round_up(end, align_size);
> +       /* copy_all is usually in page fault case */
>         bool copy_all = head_off == 0 && end == pg_end;
> +       /* zero the edges if srcmap is a HOLE or IOMAP_UNWRITTEN */
> +       bool zero_edge = srcmap->flags & IOMAP_F_SHARED ||
> +                        srcmap->type == IOMAP_UNWRITTEN;
>         void *saddr = 0;
>         int ret = 0;
>  
> -       ret = dax_iomap_direct_access(srcmap, pos, size, &saddr,
> NULL);
> -       if (ret)
> -               return ret;
> +       if (!zero_edge) {
> +               ret = dax_iomap_direct_access(srcmap, pos, size,
> &saddr, NULL);
> +               if (ret)
> +                       return ret;
> +       }
>  
>         if (copy_all) {
> -               ret = copy_mc_to_kernel(daddr, saddr, length);
> -               return ret ? -EIO : 0;
> +               if (zero_edge)
> +                       memset(daddr, 0, size);
> +               else
> +                       ret = copy_mc_to_kernel(daddr, saddr,
> length);
> +               goto out;
>         }
>  
>         /* Copy the head part of the range */
>         if (head_off) {
> -               ret = copy_mc_to_kernel(daddr, saddr, head_off);
> -               if (ret)
> -                       return -EIO;
> +               if (zero_edge)
> +                       memset(daddr, 0, head_off);
> +               else {
> +                       ret = copy_mc_to_kernel(daddr, saddr,
> head_off);
> +                       if (ret)
> +                               return -EIO;
> +               }
>         }
>  
>         /* Copy the tail part of the range */
> @@ -1137,12 +1153,19 @@ static int dax_iomap_cow_copy(loff_t pos,
> uint64_t length, size_t align_size,
>                 loff_t tail_off = head_off + length;
>                 loff_t tail_len = pg_end - end;
>  
> -               ret = copy_mc_to_kernel(daddr + tail_off, saddr +
> tail_off,
> -                                       tail_len);
> -               if (ret)
> -                       return -EIO;
> +               if (zero_edge)
> +                       memset(daddr + tail_off, 0, tail_len);
> +               else {
> +                       ret = copy_mc_to_kernel(daddr + tail_off,
> +                                               saddr + tail_off,
> tail_len);
> +                       if (ret)
> +                               return -EIO;
> +               }
>         }
> -       return 0;
> +out:
> +       if (zero_edge)
> +               dax_flush(srcmap->dax_dev, daddr, size);
> +       return ret ? -EIO : 0;
>  }
>  
>  /*
> @@ -1241,13 +1264,10 @@ static int dax_memzero(struct iomap_iter
> *iter, loff_t pos, size_t size)
>         if (ret < 0)
>                 return ret;
>         memset(kaddr + offset, 0, size);
> -       if (srcmap->addr != iomap->addr) {
> -               ret = dax_iomap_cow_copy(pos, size, PAGE_SIZE,
> srcmap,
> -                                        kaddr);
> -               if (ret < 0)
> -                       return ret;
> -               dax_flush(iomap->dax_dev, kaddr, PAGE_SIZE);
> -       } else
> +       if (iomap->flags & IOMAP_F_SHARED)
> +               ret = dax_iomap_copy_around(pos, size, PAGE_SIZE,
> srcmap,
> +                                           kaddr);
> +       else
>                 dax_flush(iomap->dax_dev, kaddr + offset, size);
>         return ret;
>  }
> @@ -1401,8 +1421,8 @@ static loff_t dax_iomap_iter(const struct
> iomap_iter *iomi,
>                 }
>  
>                 if (cow) {
> -                       ret = dax_iomap_cow_copy(pos, length,
> PAGE_SIZE, srcmap,
> -                                                kaddr);
> +                       ret = dax_iomap_copy_around(pos, length,
> PAGE_SIZE,
> +                                                   srcmap, kaddr);
>                         if (ret)
>                                 break;
>                 }
> @@ -1547,7 +1567,7 @@ static vm_fault_t dax_fault_iter(struct
> vm_fault *vmf,
>                 struct xa_state *xas, void **entry, bool pmd)
>  {
>         const struct iomap *iomap = &iter->iomap;
> -       const struct iomap *srcmap = &iter->srcmap;
> +       const struct iomap *srcmap = iomap_iter_srcmap(iter);
>         size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
>         loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
>         bool write = iter->flags & IOMAP_WRITE;
> @@ -1578,9 +1598,8 @@ static vm_fault_t dax_fault_iter(struct
> vm_fault *vmf,
>  
>         *entry = dax_insert_entry(xas, vmf, iter, *entry, pfn,
> entry_flags);
>  
> -       if (write &&
> -           srcmap->type != IOMAP_HOLE && srcmap->addr != iomap-
> >addr) {
> -               err = dax_iomap_cow_copy(pos, size, size, srcmap,
> kaddr);
> +       if (write && iomap->flags & IOMAP_F_SHARED) {
> +               err = dax_iomap_copy_around(pos, size, size, srcmap,
> kaddr);
>                 if (err)
>                         return dax_fault_return(err);
>         }