[PATCH 04/12] rust: pin-init: rewrite `derive(Zeroable)` and `derive(MaybeZeroable)` using `syn`

Benno Lossin posted 13 patches 1 month ago
There is a newer version of this series
[PATCH 04/12] rust: pin-init: rewrite `derive(Zeroable)` and `derive(MaybeZeroable)` using `syn`
Posted by Benno Lossin 1 month ago
Rewrite the two derive macros for `Zeroable` using `syn`. One positive
side effect of this change is that tuple structs are now supported by
them. Additionally, syntax errors and the error emitted when trying to
use one of the derive macros on an `enum` are improved. Otherwise no
functional changes intended.

For example:

    #[derive(Zeroable)]
    enum Num {
        A(u32),
        B(i32),
    }

Produced this error before this commit:

    error: no rules expected keyword `enum`
     --> tests/ui/compile-fail/zeroable/enum.rs:5:1
      |
    5 | enum Num {
      | ^^^^ no rules expected this token in macro call
      |
    note: while trying to match keyword `struct`
     --> src/macros.rs
      |
      |             $vis:vis struct $name:ident
      |                      ^^^^^^

Now the error is:

    error: cannot derive `Zeroable` for an enum
     --> tests/ui/compile-fail/zeroable/enum.rs:5:1
      |
    5 | enum Num {
      | ^^^^

    error: cannot derive `Zeroable` for an enum

Signed-off-by: Benno Lossin <lossin@kernel.org>
---
 rust/pin-init/internal/src/lib.rs      |   5 +-
 rust/pin-init/internal/src/zeroable.rs | 149 ++++++++++---------------
 rust/pin-init/src/macros.rs            | 124 --------------------
 3 files changed, 65 insertions(+), 213 deletions(-)

diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs
index 4c4dc639ce82..ec593362c5ac 100644
--- a/rust/pin-init/internal/src/lib.rs
+++ b/rust/pin-init/internal/src/lib.rs
@@ -11,6 +11,7 @@
 #![allow(missing_docs)]
 
 use proc_macro::TokenStream;
+use syn::parse_macro_input;
 
 mod helpers;
 mod pin_data;
@@ -29,10 +30,10 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream {
 
 #[proc_macro_derive(Zeroable)]
 pub fn derive_zeroable(input: TokenStream) -> TokenStream {
-    zeroable::derive(input.into()).into()
+    zeroable::derive(parse_macro_input!(input)).into()
 }
 
 #[proc_macro_derive(MaybeZeroable)]
 pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream {
-    zeroable::maybe_derive(input.into()).into()
+    zeroable::maybe_derive(parse_macro_input!(input)).into()
 }
diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs
index d8a5ef3883f4..0328c3bdfceb 100644
--- a/rust/pin-init/internal/src/zeroable.rs
+++ b/rust/pin-init/internal/src/zeroable.rs
@@ -1,99 +1,74 @@
 // SPDX-License-Identifier: GPL-2.0
 
-use crate::helpers::{parse_generics, Generics};
-use proc_macro2::{TokenStream, TokenTree};
-use quote::quote;
+use proc_macro2::TokenStream;
+use quote::{quote, quote_spanned};
+use syn::{parse_quote, Data, DeriveInput, Field, Fields};
 
-pub(crate) fn parse_zeroable_derive_input(
-    input: TokenStream,
-) -> (
-    Vec<TokenTree>,
-    Vec<TokenTree>,
-    Vec<TokenTree>,
-    Option<TokenTree>,
-) {
-    let (
-        Generics {
-            impl_generics,
-            decl_generics: _,
-            ty_generics,
-        },
-        mut rest,
-    ) = parse_generics(input);
-    // This should be the body of the struct `{...}`.
-    let last = rest.pop();
-    // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
-    let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
-    // Are we inside of a generic where we want to add `Zeroable`?
-    let mut in_generic = !impl_generics.is_empty();
-    // Have we already inserted `Zeroable`?
-    let mut inserted = false;
-    // Level of `<>` nestings.
-    let mut nested = 0;
-    for tt in impl_generics {
-        match &tt {
-            // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
-            TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
-                if in_generic && !inserted {
-                    new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
-                }
-                in_generic = true;
-                inserted = false;
-                new_impl_generics.push(tt);
-            }
-            // If we find `'`, then we are entering a lifetime.
-            TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
-                in_generic = false;
-                new_impl_generics.push(tt);
-            }
-            TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
-                new_impl_generics.push(tt);
-                if in_generic {
-                    new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
-                    inserted = true;
-                }
-            }
-            TokenTree::Punct(p) if p.as_char() == '<' => {
-                nested += 1;
-                new_impl_generics.push(tt);
-            }
-            TokenTree::Punct(p) if p.as_char() == '>' => {
-                assert!(nested > 0);
-                nested -= 1;
-                new_impl_generics.push(tt);
-            }
-            _ => new_impl_generics.push(tt),
+pub(crate) fn derive(input: DeriveInput) -> TokenStream {
+    let fields = match input.data {
+        Data::Struct(data_struct) => data_struct.fields,
+        Data::Union(data_union) => Fields::Named(data_union.fields),
+        Data::Enum(data_enum) => {
+            return quote_spanned! {data_enum.enum_token.span=>
+                ::core::compile_error!("cannot derive `Zeroable` for an enum");
+            };
         }
+    };
+    let name = input.ident;
+    let mut generics = input.generics;
+    for param in generics.type_params_mut() {
+        param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
     }
-    assert_eq!(nested, 0);
-    if in_generic && !inserted {
-        new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
-    }
-    (rest, new_impl_generics, ty_generics, last)
-}
-
-pub(crate) fn derive(input: TokenStream) -> TokenStream {
-    let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
+    let (impl_gen, ty_gen, whr) = generics.split_for_impl();
+    let field_type = fields.iter().map(|field| &field.ty);
     quote! {
-        ::pin_init::__derive_zeroable!(
-            parse_input:
-                @sig(#(#rest)*),
-                @impl_generics(#(#new_impl_generics)*),
-                @ty_generics(#(#ty_generics)*),
-                @body(#last),
-        );
+        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
+        #[automatically_derived]
+        unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
+            #whr
+        {}
+        const _: () = {
+            fn assert_zeroable<T: ?::core::marker::Sized + ::pin_init::Zeroable>() {}
+            fn ensure_zeroable #impl_gen ()
+                #whr
+            {
+                #(
+                    assert_zeroable::<#field_type>();
+                )*
+            }
+        };
     }
 }
 
-pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream {
-    let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
+pub(crate) fn maybe_derive(input: DeriveInput) -> TokenStream {
+    let fields = match input.data {
+        Data::Struct(data_struct) => data_struct.fields,
+        Data::Union(data_union) => Fields::Named(data_union.fields),
+        Data::Enum(data_enum) => {
+            return quote_spanned! {data_enum.enum_token.span=>
+                ::core::compile_error!("cannot derive `MaybeZeroable` for an enum");
+            };
+        }
+    };
+    let name = input.ident;
+    let mut generics = input.generics;
+    for param in generics.type_params_mut() {
+        param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
+    }
+    for Field { ty, .. } in fields {
+        generics
+            .make_where_clause()
+            .predicates
+            // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds`
+            // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>.
+            .push(parse_quote!(#ty: for<'__dummy> ::pin_init::Zeroable));
+    }
+    let (impl_gen, ty_gen, whr) = generics.split_for_impl();
     quote! {
-        ::pin_init::__maybe_derive_zeroable!(
-            parse_input:
-                @sig(#(#rest)*),
-                @impl_generics(#(#new_impl_generics)*),
-                @ty_generics(#(#ty_generics)*),
-                @body(#last),
-        );
+        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
+        #[automatically_derived]
+        unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
+            #whr
+        {}
     }
 }
diff --git a/rust/pin-init/src/macros.rs b/rust/pin-init/src/macros.rs
index 682c61a587a0..53ed5ce860fc 100644
--- a/rust/pin-init/src/macros.rs
+++ b/rust/pin-init/src/macros.rs
@@ -1551,127 +1551,3 @@ fn assert_zeroable<T: $crate::Zeroable>(_: *mut T) {}
         );
     };
 }
-
-#[doc(hidden)]
-#[macro_export]
-macro_rules! __derive_zeroable {
-    (parse_input:
-        @sig(
-            $(#[$($struct_attr:tt)*])*
-            $vis:vis struct $name:ident
-            $(where $($whr:tt)*)?
-        ),
-        @impl_generics($($impl_generics:tt)*),
-        @ty_generics($($ty_generics:tt)*),
-        @body({
-            $(
-                $(#[$($field_attr:tt)*])*
-                $field_vis:vis $field:ident : $field_ty:ty
-            ),* $(,)?
-        }),
-    ) => {
-        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
-        #[automatically_derived]
-        unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*>
-        where
-            $($($whr)*)?
-        {}
-        const _: () = {
-            fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {}
-            fn ensure_zeroable<$($impl_generics)*>()
-                where $($($whr)*)?
-            {
-                $(assert_zeroable::<$field_ty>();)*
-            }
-        };
-    };
-    (parse_input:
-        @sig(
-            $(#[$($struct_attr:tt)*])*
-            $vis:vis union $name:ident
-            $(where $($whr:tt)*)?
-        ),
-        @impl_generics($($impl_generics:tt)*),
-        @ty_generics($($ty_generics:tt)*),
-        @body({
-            $(
-                $(#[$($field_attr:tt)*])*
-                $field_vis:vis $field:ident : $field_ty:ty
-            ),* $(,)?
-        }),
-    ) => {
-        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
-        #[automatically_derived]
-        unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*>
-        where
-            $($($whr)*)?
-        {}
-        const _: () = {
-            fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {}
-            fn ensure_zeroable<$($impl_generics)*>()
-                where $($($whr)*)?
-            {
-                $(assert_zeroable::<$field_ty>();)*
-            }
-        };
-    };
-}
-
-#[doc(hidden)]
-#[macro_export]
-macro_rules! __maybe_derive_zeroable {
-    (parse_input:
-        @sig(
-            $(#[$($struct_attr:tt)*])*
-            $vis:vis struct $name:ident
-            $(where $($whr:tt)*)?
-        ),
-        @impl_generics($($impl_generics:tt)*),
-        @ty_generics($($ty_generics:tt)*),
-        @body({
-            $(
-                $(#[$($field_attr:tt)*])*
-                $field_vis:vis $field:ident : $field_ty:ty
-            ),* $(,)?
-        }),
-    ) => {
-        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
-        #[automatically_derived]
-        unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*>
-        where
-            $(
-                // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds`
-                // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>.
-                $field_ty: for<'__dummy> $crate::Zeroable,
-            )*
-            $($($whr)*)?
-        {}
-    };
-    (parse_input:
-        @sig(
-            $(#[$($struct_attr:tt)*])*
-            $vis:vis union $name:ident
-            $(where $($whr:tt)*)?
-        ),
-        @impl_generics($($impl_generics:tt)*),
-        @ty_generics($($ty_generics:tt)*),
-        @body({
-            $(
-                $(#[$($field_attr:tt)*])*
-                $field_vis:vis $field:ident : $field_ty:ty
-            ),* $(,)?
-        }),
-    ) => {
-        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
-        #[automatically_derived]
-        unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*>
-        where
-            $(
-                // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds`
-                // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>.
-                $field_ty: for<'__dummy> $crate::Zeroable,
-            )*
-            $($($whr)*)?
-        {}
-    };
-}
-- 
2.51.2
Re: [PATCH 04/12] rust: pin-init: rewrite `derive(Zeroable)` and `derive(MaybeZeroable)` using `syn`
Posted by Gary Guo 1 month ago
On Thu Jan 8, 2026 at 1:50 PM GMT, Benno Lossin wrote:
> Rewrite the two derive macros for `Zeroable` using `syn`. One positive
> side effect of this change is that tuple structs are now supported by
> them. Additionally, syntax errors and the error emitted when trying to
> use one of the derive macros on an `enum` are improved. Otherwise no
> functional changes intended.
>
> For example:
>
>     #[derive(Zeroable)]
>     enum Num {
>         A(u32),
>         B(i32),
>     }
>
> Produced this error before this commit:
>
>     error: no rules expected keyword `enum`
>      --> tests/ui/compile-fail/zeroable/enum.rs:5:1
>       |
>     5 | enum Num {
>       | ^^^^ no rules expected this token in macro call
>       |
>     note: while trying to match keyword `struct`
>      --> src/macros.rs
>       |
>       |             $vis:vis struct $name:ident
>       |                      ^^^^^^
>
> Now the error is:
>
>     error: cannot derive `Zeroable` for an enum
>      --> tests/ui/compile-fail/zeroable/enum.rs:5:1
>       |
>     5 | enum Num {
>       | ^^^^
>
>     error: cannot derive `Zeroable` for an enum
>
> Signed-off-by: Benno Lossin <lossin@kernel.org>
> ---
>  rust/pin-init/internal/src/lib.rs      |   5 +-
>  rust/pin-init/internal/src/zeroable.rs | 149 ++++++++++---------------
>  rust/pin-init/src/macros.rs            | 124 --------------------
>  3 files changed, 65 insertions(+), 213 deletions(-)

Yay! The updated code is much more readable.

>
> diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs
> index 4c4dc639ce82..ec593362c5ac 100644
> --- a/rust/pin-init/internal/src/lib.rs
> +++ b/rust/pin-init/internal/src/lib.rs
> @@ -11,6 +11,7 @@
>  #![allow(missing_docs)]
>  
>  use proc_macro::TokenStream;
> +use syn::parse_macro_input;
>  
>  mod helpers;
>  mod pin_data;
> @@ -29,10 +30,10 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream {
>  
>  #[proc_macro_derive(Zeroable)]
>  pub fn derive_zeroable(input: TokenStream) -> TokenStream {
> -    zeroable::derive(input.into()).into()
> +    zeroable::derive(parse_macro_input!(input as _)).into()

This can just be

    zeroable::derive(parse_macro_input!(input)).into()

same for the below.

>  }
>  
>  #[proc_macro_derive(MaybeZeroable)]
>  pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream {
> -    zeroable::maybe_derive(input.into()).into()
> +    zeroable::maybe_derive(parse_macro_input!(input as _)).into()
>  }
> diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs
> index d8a5ef3883f4..0328c3bdfceb 100644
> --- a/rust/pin-init/internal/src/zeroable.rs
> +++ b/rust/pin-init/internal/src/zeroable.rs
> @@ -1,99 +1,74 @@
>  // SPDX-License-Identifier: GPL-2.0
>  
> -use crate::helpers::{parse_generics, Generics};
> -use proc_macro2::{TokenStream, TokenTree};
> -use quote::quote;
> +use proc_macro2::TokenStream;
> +use quote::{quote, quote_spanned};
> +use syn::{parse_quote, Data, DeriveInput, Field, Fields};
>  
> -pub(crate) fn parse_zeroable_derive_input(
> -    input: TokenStream,
> -) -> (
> -    Vec<TokenTree>,
> -    Vec<TokenTree>,
> -    Vec<TokenTree>,
> -    Option<TokenTree>,
> -) {
> -    let (
> -        Generics {
> -            impl_generics,
> -            decl_generics: _,
> -            ty_generics,
> -        },
> -        mut rest,
> -    ) = parse_generics(input);
> -    // This should be the body of the struct `{...}`.
> -    let last = rest.pop();
> -    // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
> -    let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
> -    // Are we inside of a generic where we want to add `Zeroable`?
> -    let mut in_generic = !impl_generics.is_empty();
> -    // Have we already inserted `Zeroable`?
> -    let mut inserted = false;
> -    // Level of `<>` nestings.
> -    let mut nested = 0;
> -    for tt in impl_generics {
> -        match &tt {
> -            // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
> -            TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
> -                if in_generic && !inserted {
> -                    new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
> -                }
> -                in_generic = true;
> -                inserted = false;
> -                new_impl_generics.push(tt);
> -            }
> -            // If we find `'`, then we are entering a lifetime.
> -            TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
> -                in_generic = false;
> -                new_impl_generics.push(tt);
> -            }
> -            TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
> -                new_impl_generics.push(tt);
> -                if in_generic {
> -                    new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
> -                    inserted = true;
> -                }
> -            }
> -            TokenTree::Punct(p) if p.as_char() == '<' => {
> -                nested += 1;
> -                new_impl_generics.push(tt);
> -            }
> -            TokenTree::Punct(p) if p.as_char() == '>' => {
> -                assert!(nested > 0);
> -                nested -= 1;
> -                new_impl_generics.push(tt);
> -            }
> -            _ => new_impl_generics.push(tt),
> +pub(crate) fn derive(input: DeriveInput) -> TokenStream {
> +    let fields = match input.data {
> +        Data::Struct(data_struct) => data_struct.fields,
> +        Data::Union(data_union) => Fields::Named(data_union.fields),
> +        Data::Enum(data_enum) => {
> +            return quote_spanned! {data_enum.enum_token.span=>
> +                ::core::compile_error!("cannot derive `Zeroable` for an enum");
> +            };

Even if it's currently just one place that could error, I would still probably
use `syn`'s Error type as it's easier to extend in the future.

Best,
Gary

>          }
> +    };
> +    let name = input.ident;
> +    let mut generics = input.generics;
> +    for param in generics.type_params_mut() {
> +        param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
>      }
> -    assert_eq!(nested, 0);
> -    if in_generic && !inserted {
> -        new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
> -    }
> -    (rest, new_impl_generics, ty_generics, last)
> -}
> -