[PATCH 2/3] rust: macros: add IoctlCommand derive macro

Josef Zoller posted 3 patches 1 month, 2 weeks ago
[PATCH 2/3] rust: macros: add IoctlCommand derive macro
Posted by Josef Zoller 1 month, 2 weeks ago
Provide a macro that derives the `IoctlCommand` trait for simple enums
by converting every variant into a unique command.

The macro can be instructed to use a specific letter or integer as the
code. Each variant is then assigned a consecutive number starting from
0 or a given value. The type of the command, i.e. if it is a read or
write command, is inferred from the variant's associated data: if it
has no data or only an integer, it is neither read nor write, if it has
a UserSliceReader it is a write command, if it has a UserSliceWriter it
is a read command, and if it just has a UserSlice it is a read-write
command. The code and the variant's number and type are then combined
to parse the command from the user-provided cmd and arg values.

Signed-off-by: Josef Zoller <josef@walterzollerpiano.com>
---
 rust/kernel/ioctl.rs     | 190 ++++++++++++++++++++++++++++++++++++++++++++
 rust/kernel/prelude.rs   |   2 +-
 rust/macros/ioctl_cmd.rs | 202 +++++++++++++++++++++++++++++++++++++++++++++++
 rust/macros/lib.rs       |  21 +++++
 4 files changed, 414 insertions(+), 1 deletion(-)

diff --git a/rust/kernel/ioctl.rs b/rust/kernel/ioctl.rs
index 03359ab28495b94d98d53db2115bbbcc520c18a3..f6af9c10c0b244b8d8183cf70b4ef5ce9233c935 100644
--- a/rust/kernel/ioctl.rs
+++ b/rust/kernel/ioctl.rs
@@ -73,6 +73,22 @@ pub const fn _IOC_SIZE(nr: u32) -> usize {
 }
 
 /// Types implementing this trait can be used to parse ioctl commands.
+///
+/// Normally, this trait is derived for a command enum.
+///
+/// # Example
+///
+/// ```
+/// #[derive(IoctlCommand)]
+/// #[ioctl(code = 0x18, start_num = 0)]
+/// enum Command {
+///     NoReadWrite,                 // No read or write access.
+///     NoReadWriteButTakesArg(u64), // No read or write access, but takes an argument.
+///     ReadOnly(UserSliceWriter),   // We write data for the user to read.
+///     WriteOnly(UserSliceReader),  // We read data that the user wrote.
+///     WriteAndRead(UserSlice),     // We read data from the user and then write data to the user.
+/// }
+/// ```
 #[vtable]
 pub trait IoctlCommand: Sized + Send + Sync + 'static {
     /// The error type returned by the parse functions.
@@ -114,3 +130,177 @@ fn parse(_cmd: ffi::c_uint, _arg: ffi::c_ulong) -> Result<Self> {
         Err(ENOTTY)
     }
 }
+
+/// Support macro for deriving the `IoctlCommand` trait.
+#[doc(hidden)]
+#[macro_export]
+macro_rules! __derive_ioctl_cmd {
+    (parse_input:
+        @enum_name($enum_name:ident),
+        @code($code:literal),
+        @variants(
+            $(
+                @variant($i:literal, $variant:ident, $arg_type:tt),
+            )*
+        )
+    ) => {
+        #[automatically_derived]
+        impl $crate::ioctl::IoctlCommand for $enum_name {
+            type Err = $crate::error::Error;
+
+            const USE_VTABLE_ATTR: () = ();
+
+            const HAS_PARSE: bool = true;
+
+            fn parse(
+                cmd: ::core::ffi::c_uint,
+                arg: ::core::ffi::c_ulong,
+            ) -> ::core::result::Result<Self, Self::Err> {
+                let ty = $crate::ioctl::_IOC_TYPE(cmd) as u8;
+
+                if ty != $code {
+                    return Err($crate::error::code::ENOTTY);
+                }
+
+                let nr = $crate::ioctl::_IOC_NR(cmd) as u8;
+                let dir = $crate::ioctl::_IOC_DIR(cmd);
+                let size = $crate::ioctl::_IOC_SIZE(cmd);
+
+                // Make sure we don't get unused parameter warnings
+                let _ = arg;
+
+                match (nr, dir, size) {
+                    $(
+                        ::kernel::__derive_ioctl_cmd!(
+                            match_pattern:
+                                @variant($i, $arg_type)
+                        ) => ::kernel::__derive_ioctl_cmd!(
+                            match_body:
+                                @dir(dir),
+                                @size(size),
+                                @arg(arg),
+                                @variant($variant, $arg_type)
+                        ),
+                    )*
+                    _ => Err($crate::error::code::ENOTTY),
+                }
+            }
+        }
+    };
+    (match_pattern:
+        @variant($i:literal, None)
+    ) => {
+        ($i, $crate::uapi::_IOC_NONE, 0)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, None)
+    ) => {
+        Ok(Self::$variant)
+    };
+    (match_pattern:
+        @variant($i:literal, u64)
+    ) => {
+        ($i, $crate::uapi::_IOC_NONE, 0)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, u64)
+    ) => {
+        Ok(Self::$variant($arg))
+    };
+    (match_pattern:
+        @variant($i:literal, UserSliceWriter)
+    ) => {
+        ($i, $crate::uapi::_IOC_READ, _)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, UserSliceWriter)
+    ) => {
+        {
+            let user_writer = $crate::uaccess::UserSlice::new(
+                $arg as $crate::uaccess::UserPtr,
+                $size
+            )
+            .writer();
+
+            Ok(Self::$variant(user_writer))
+        }
+    };
+    (match_pattern:
+        @variant($i:literal, UserSliceReader)
+    ) => {
+        ($i, $crate::uapi::_IOC_WRITE, _)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, UserSliceReader)
+    ) => {
+        {
+            let user_reader = $crate::uaccess::UserSlice::new(
+                $arg as $crate::uaccess::UserPtr,
+                $size
+            )
+            .reader();
+
+            Ok(Self::$variant(user_reader))
+        }
+    };
+    (match_pattern:
+        @variant($i:literal, UserSlice)
+    ) => {
+        ($i, _, _)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, UserSlice)
+    ) => {
+        // Unfortunately, we cannot just do a match guard
+        if $dir != $crate::uapi::_IOC_READ | $crate::uapi::_IOC_WRITE {
+            Err($crate::error::code::ENOTTY)
+        } else {
+            let user_slice = $crate::uaccess::UserSlice::new(
+                $arg as $crate::uaccess::UserPtr,
+                $size
+            );
+
+            Ok(Self::$variant(user_slice))
+        }
+    };
+    (match_pattern:
+        @variant($i:literal, $arg_type:tt)
+    ) => {
+        ($i, _, _)
+    };
+    (match_body:
+        @dir($dir:ident),
+        @size($size:ident),
+        @arg($arg:ident),
+        @variant($variant:ident, $arg_type:tt)
+    ) => {
+        {
+            // We have an unsupported argument type
+            const _: () = ::core::assert!(
+                false,
+                ::core::concat!(
+                    "Invalid argument type for ioctl command ",
+                    stringify!($variant),
+                    ": ",
+                    stringify!($arg_type),
+                )
+            );
+            ::core::unreachable!()
+        }
+    };
+}
diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs
index 4571daec0961bb34fb6956a4e9eda8445954b719..1277d1ec5a476d3e115f6b2ba432b0fbe28941a2 100644
--- a/rust/kernel/prelude.rs
+++ b/rust/kernel/prelude.rs
@@ -20,7 +20,7 @@
 pub use alloc::{boxed::Box, vec::Vec};
 
 #[doc(no_inline)]
-pub use macros::{module, pin_data, pinned_drop, vtable, Zeroable};
+pub use macros::{module, pin_data, pinned_drop, vtable, IoctlCommand, Zeroable};
 
 pub use super::build_assert;
 
diff --git a/rust/macros/ioctl_cmd.rs b/rust/macros/ioctl_cmd.rs
new file mode 100644
index 0000000000000000000000000000000000000000..366a9b1f7ba70ba764b0d78cb32d82125bc7b854
--- /dev/null
+++ b/rust/macros/ioctl_cmd.rs
@@ -0,0 +1,202 @@
+// SPDX-License-Identifier: GPL-2.0
+
+use proc_macro::{token_stream, Delimiter, Literal, TokenStream, TokenTree};
+
+fn expect_punct(input: &mut impl Iterator<Item = TokenTree>, expected: char, reason: &str) {
+    let Some(TokenTree::Punct(punct)) = input.next() else {
+        panic!("expected '{expected}' {reason}");
+    };
+
+    if punct.as_char() != expected {
+        panic!("expected '{expected}' {reason}");
+    }
+}
+
+fn expect_ident(input: &mut impl Iterator<Item = TokenTree>, expected: &str, reason: &str) {
+    let Some(TokenTree::Ident(ident)) = input.next() else {
+        panic!("expected '{expected}' {reason}");
+    };
+
+    if ident.to_string() != expected {
+        panic!("expected '{expected}' {reason}");
+    }
+}
+
+fn expect_group(
+    input: &mut impl Iterator<Item = TokenTree>,
+    expected: Delimiter,
+    reason: &str,
+) -> token_stream::IntoIter {
+    let Some(TokenTree::Group(group)) = input.next() else {
+        panic!("expected group {reason}");
+    };
+
+    if group.delimiter() != expected {
+        panic!("expected group {reason}");
+    }
+
+    group.stream().into_iter()
+}
+
+fn parse_attribute(input: &mut impl Iterator<Item = TokenTree>) -> (u8, u8) {
+    expect_punct(input, '#', "to start attribute");
+
+    let mut stream = expect_group(input, Delimiter::Bracket, "as attribute body");
+
+    expect_ident(&mut stream, "ioctl", "as attribute name");
+
+    let mut inner_stream = expect_group(
+        &mut stream,
+        Delimiter::Parenthesis,
+        "as attribute arguments",
+    );
+
+    expect_ident(&mut inner_stream, "code", "as ioctl attribute field");
+    expect_punct(&mut inner_stream, '=', "in ioctl attribute field");
+
+    let Some(TokenTree::Literal(lit)) = inner_stream.next() else {
+        panic!("expected ioctl attribute code value");
+    };
+
+    let lit_str = lit.to_string();
+    let code = if lit_str.starts_with("b'") {
+        lit_str
+            .chars()
+            .nth(2)
+            .expect("expected ioctl attribute code value") as u8
+    } else if let Some(hex) = lit_str.strip_prefix("0x") {
+        u8::from_str_radix(hex, 16).expect("expected ioctl attribute code value")
+    } else {
+        lit_str
+            .parse()
+            .expect("expected ioctl attribute code value")
+    };
+
+    let start_num = if let Some(tree) = inner_stream.next() {
+        if !matches!(tree, TokenTree::Punct(punct) if punct.as_char() == ',') {
+            panic!("expected ioctl attribute comma");
+        }
+
+        expect_ident(&mut inner_stream, "start_num", "as ioctl attribute field");
+        expect_punct(&mut inner_stream, '=', "in ioctl attribute field");
+
+        let Some(TokenTree::Literal(lit)) = inner_stream.next() else {
+            panic!("expected ioctl attribute start number value");
+        };
+
+        lit.to_string()
+            .parse()
+            .expect("expected ioctl attribute start number value")
+    } else {
+        0
+    };
+
+    assert!(
+        inner_stream.next().is_none(),
+        "unexpected token in ioctl attribute"
+    );
+    assert!(
+        stream.next().is_none(),
+        "unexpected token in ioctl attribute"
+    );
+
+    (code, start_num)
+}
+
+fn parse_enum_def(input: &mut impl Iterator<Item = TokenTree>) -> TokenTree {
+    expect_ident(input, "enum", "to start enum definition");
+
+    let Some(ident @ TokenTree::Ident(_)) = input.next() else {
+        panic!("expected enum name");
+    };
+
+    ident
+}
+
+fn parse_enum_body(
+    input: &mut impl Iterator<Item = TokenTree>,
+) -> Vec<(TokenTree, Option<TokenTree>)> {
+    let mut stream = expect_group(input, Delimiter::Brace, "as enum body").peekable();
+
+    let mut variants = Vec::new();
+
+    while let Some(variant) = stream.next_if(|t| matches!(t, TokenTree::Ident(_))) {
+        let arg_type = if let Some(TokenTree::Group(group)) =
+            stream.next_if(|t| matches!(t, TokenTree::Group(_)))
+        {
+            if group.delimiter() != Delimiter::Parenthesis {
+                panic!("expected group");
+            }
+
+            let mut inner_stream = group.stream().into_iter();
+
+            let arg_type = if let Some(ident @ TokenTree::Ident(_)) = inner_stream.next() {
+                ident
+            } else {
+                panic!("expected argument type")
+            };
+
+            assert!(
+                inner_stream.next().is_none(),
+                "unexpected token in enum variant"
+            );
+
+            Some(arg_type)
+        } else {
+            None
+        };
+
+        variants.push((variant, arg_type));
+
+        if stream
+            .next_if(|t| matches!(t, TokenTree::Punct(punct) if punct.as_char() == ','))
+            .is_none()
+        {
+            break;
+        }
+    }
+
+    assert!(stream.next().is_none(), "unexpected token in enum body");
+
+    variants
+}
+
+pub(crate) fn derive(input: TokenStream) -> TokenStream {
+    let mut input = input.into_iter();
+
+    let (code, start_num) = parse_attribute(&mut input);
+    let enum_name = parse_enum_def(&mut input);
+    let variants = parse_enum_body(&mut input);
+
+    assert!(input.next().is_none(), "unexpected token in ioctl_cmd");
+
+    let code = TokenTree::from(Literal::u8_suffixed(code));
+
+    let variants = variants
+        .into_iter()
+        .enumerate()
+        .map(|(i, (variant, arg_type))| {
+            let i = i as u8 + start_num;
+
+            let i = TokenTree::from(Literal::u8_suffixed(i));
+
+            if let Some(arg_type) = arg_type {
+                quote! {
+                    @variant(#i, #variant, #arg_type),
+                }
+            } else {
+                quote! {
+                    @variant(#i, #variant, None),
+                }
+            }
+        });
+
+    quote! {
+        ::kernel::__derive_ioctl_cmd!(
+            parse_input:
+                @enum_name(#enum_name),
+                @code(#code),
+                @variants(#(#variants)*)
+        );
+    }
+}
diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs
index a626b1145e5c4ff00692e9d4e11fdb93500db1a8..5a33ed69b5b0b64f6720fb54e18056af9b2f7a00 100644
--- a/rust/macros/lib.rs
+++ b/rust/macros/lib.rs
@@ -10,6 +10,7 @@
 mod quote;
 mod concat_idents;
 mod helpers;
+mod ioctl_cmd;
 mod module;
 mod paste;
 mod pin_data;
@@ -412,6 +413,26 @@ pub fn paste(input: TokenStream) -> TokenStream {
     tokens.into_iter().collect()
 }
 
+/// Derives the [`IoctlCommand`] trait for the given enum.
+///
+/// # Example
+///
+/// ```
+/// #[derive(IoctlCommand)]
+/// #[ioctl(code = 0x18, start_num = 0)]
+/// enum Command {
+///     NoReadWrite,                 // No read or write access.
+///     NoReadWriteButTakesArg(u64), // No read or write access, but takes an argument.
+///     ReadOnly(UserSliceWriter),   // We write data for the user to read.
+///     WriteOnly(UserSliceReader),  // We read data that the user wrote.
+///     WriteAndRead(UserSlice),     // We read data from the user and then write data to the user.
+/// }
+/// ```
+#[proc_macro_derive(IoctlCommand, attributes(ioctl))]
+pub fn derive_ioctl_cmd(input: TokenStream) -> TokenStream {
+    ioctl_cmd::derive(input)
+}
+
 /// Derives the [`Zeroable`] trait for the given struct.
 ///
 /// This can only be used for structs where every field implements the [`Zeroable`] trait.

-- 
2.47.0