Forwarded: [PATCH] mm/userfaultfd: re-validate vma in mfill_atomic() loop under CONFIG_PER_VMA_LOCK

syzbot posted 1 patch 3 weeks, 1 day ago
mm/userfaultfd.c | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
For archival purposes, forwarding an incoming command email to
linux-kernel@vger.kernel.org, syzkaller-bugs@googlegroups.com.

***

Subject: [PATCH] mm/userfaultfd: re-validate vma in mfill_atomic() loop under CONFIG_PER_VMA_LOCK
Author: kartikey406@gmail.com

#syz test: git://git.kernel.org/pub/scm/linux/kernel/git/next/linux-next.git master


Under CONFIG_PER_VMA_LOCK, mfill_atomic() holds only a per-VMA read
lock (vma_start_read()) across its page-by-page copy loop. Unlike
mmap_read_lock(), this does not stop a concurrent UFFDIO_UNREGISTER,
running under mmap_write_lock(), from splitting the vma mid-loop.

When the vma is split, state.vma->vm_end is shrunk in place. On the
next iteration, mfill_atomic_install_pte() calls folio_add_new_anon_rmap()
with state.dst_addr >= vma->vm_end, triggering the sanity check:

  address < vma->vm_start || address + (nr << 12) > vma->vm_end
  WARNING: mm/rmap.c:1682 folio_add_new_anon_rmap+0x5fe/0x14b0

Fix this by checking, on each loop iteration, whether state.dst_addr
has fallen outside state.vma. If it has, release the stale vma, update
state.dst_start and state.len to reflect the current position, and look
the vma up again via mfill_get_vma().

Reported-by: syzbot+e24a2e34fad0efbac047@syzkaller.appspotmail.com
Closes: https://syzkaller.appspot.com/bug?extid=e24a2e34fad0efbac047
Signed-off-by: Deepanshu Kartikey <Kartikey406@gmail.com>
---
 mm/userfaultfd.c | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 9ffc80d0a51b..519be02fad38 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -910,8 +910,17 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
 
 	while (state.src_addr < src_start + len) {
 		VM_WARN_ON_ONCE(state.dst_addr >= dst_start + len);
+		if (state.dst_addr < state.vma->vm_start ||
+		    state.dst_addr >= state.vma->vm_end) {
+			mfill_put_vma(&state);
+			state.dst_start = state.dst_addr;
+			state.len = dst_start + len - state.dst_addr;
+			err = mfill_get_vma(&state);
+			if (err)
+				break;
+		}
 
-		err = mfill_get_pmd(&state);
+		err = mfill_get_pmd(&state);
 		if (err)
 			break;
 
-- 
2.43.0