[PATCH v6 6/9] rust: io: use generic read/write accessors for primitive accesses

Alexandre Courbot posted 9 patches 1 month, 2 weeks ago
There is a newer version of this series
[PATCH v6 6/9] rust: io: use generic read/write accessors for primitive accesses
Posted by Alexandre Courbot 1 month, 2 weeks ago
By providing the required `IoRef` implementations for `usize`, the
primitive accessors can be expressed through the generic ones, which
reduces the number of unsafe blocks in the module.

Signed-off-by: Alexandre Courbot <acourbot@nvidia.com>
---
 rust/kernel/io.rs | 103 +++++++++++++++++++-----------------------------------
 1 file changed, 35 insertions(+), 68 deletions(-)

diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs
index 6da8593f7858..053c6385842a 100644
--- a/rust/kernel/io.rs
+++ b/rust/kernel/io.rs
@@ -277,6 +277,25 @@ fn try_init_default<F, E>(self, f: F) -> Result<IoWrite<T, Self>, E>
     }
 }
 
+/// Implements [`IoRef`] for [`usize`] once per listed primitive type, allowing a `usize` offset
+/// to be used with [`Io::read`] and [`Io::write`].
+macro_rules! impl_usize_ioref {
+    ($($ty:ty),*) => {
+        $(
+            impl IoRef<$ty> for usize {
+                type IoType = $ty;
+
+                fn offset(self) -> usize {
+                    self
+                }
+            }
+        )*
+    }
+}
+
+// Allow `u8`, `u16`, `u32` and `u64` to be read and written using a [`usize`] offset.
+impl_usize_ioref!(u8, u16, u32, u64);
+
 /// A pending I/O write operation, bundling a value with the [`IoRef`] it should be written to.
 ///
 /// Created by [`IoRef::set`], [`IoRef::zeroed`], [`IoRef::default`], [`IoRef::init`], or
@@ -371,10 +390,7 @@ fn try_read8(&self, offset: usize) -> Result<u8>
     where
         Self: IoCapable<u8>,
     {
-        let address = self.io_addr::<u8>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        Ok(unsafe { self.io_read(address) })
+        self.try_read(offset)
     }
 
     /// Fallible 16-bit read with runtime bounds check.
@@ -383,10 +399,7 @@ fn try_read16(&self, offset: usize) -> Result<u16>
     where
         Self: IoCapable<u16>,
     {
-        let address = self.io_addr::<u16>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        Ok(unsafe { self.io_read(address) })
+        self.try_read(offset)
     }
 
     /// Fallible 32-bit read with runtime bounds check.
@@ -395,10 +408,7 @@ fn try_read32(&self, offset: usize) -> Result<u32>
     where
         Self: IoCapable<u32>,
     {
-        let address = self.io_addr::<u32>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        Ok(unsafe { self.io_read(address) })
+        self.try_read(offset)
     }
 
     /// Fallible 64-bit read with runtime bounds check.
@@ -407,10 +417,7 @@ fn try_read64(&self, offset: usize) -> Result<u64>
     where
         Self: IoCapable<u64>,
     {
-        let address = self.io_addr::<u64>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        Ok(unsafe { self.io_read(address) })
+        self.try_read(offset)
     }
 
     /// Fallible 8-bit write with runtime bounds check.
@@ -419,11 +426,7 @@ fn try_write8(&self, value: u8, offset: usize) -> Result
     where
         Self: IoCapable<u8>,
     {
-        let address = self.io_addr::<u8>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        unsafe { self.io_write(value, address) };
-        Ok(())
+        self.try_write(offset.set(value))
     }
 
     /// Fallible 16-bit write with runtime bounds check.
@@ -432,11 +435,7 @@ fn try_write16(&self, value: u16, offset: usize) -> Result
     where
         Self: IoCapable<u16>,
     {
-        let address = self.io_addr::<u16>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        unsafe { self.io_write(value, address) };
-        Ok(())
+        self.try_write(offset.set(value))
     }
 
     /// Fallible 32-bit write with runtime bounds check.
@@ -445,11 +444,7 @@ fn try_write32(&self, value: u32, offset: usize) -> Result
     where
         Self: IoCapable<u32>,
     {
-        let address = self.io_addr::<u32>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        unsafe { self.io_write(value, address) };
-        Ok(())
+        self.try_write(offset.set(value))
     }
 
     /// Fallible 64-bit write with runtime bounds check.
@@ -458,11 +453,7 @@ fn try_write64(&self, value: u64, offset: usize) -> Result
     where
         Self: IoCapable<u64>,
     {
-        let address = self.io_addr::<u64>(offset)?;
-
-        // SAFETY: `address` has been validated by `io_addr`.
-        unsafe { self.io_write(value, address) };
-        Ok(())
+        self.try_write(offset.set(value))
     }
 
     /// Infallible 8-bit read with compile-time bounds check.
@@ -471,10 +462,7 @@ fn read8(&self, offset: usize) -> u8
     where
         Self: IoKnownSize + IoCapable<u8>,
     {
-        let address = self.io_addr_assert::<u8>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_read(address) }
+        self.read(offset)
     }
 
     /// Infallible 16-bit read with compile-time bounds check.
@@ -483,10 +471,7 @@ fn read16(&self, offset: usize) -> u16
     where
         Self: IoKnownSize + IoCapable<u16>,
     {
-        let address = self.io_addr_assert::<u16>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_read(address) }
+        self.read(offset)
     }
 
     /// Infallible 32-bit read with compile-time bounds check.
@@ -495,10 +480,7 @@ fn read32(&self, offset: usize) -> u32
     where
         Self: IoKnownSize + IoCapable<u32>,
     {
-        let address = self.io_addr_assert::<u32>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_read(address) }
+        self.read(offset)
     }
 
     /// Infallible 64-bit read with compile-time bounds check.
@@ -507,10 +489,7 @@ fn read64(&self, offset: usize) -> u64
     where
         Self: IoKnownSize + IoCapable<u64>,
     {
-        let address = self.io_addr_assert::<u64>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_read(address) }
+        self.read(offset)
     }
 
     /// Infallible 8-bit write with compile-time bounds check.
@@ -519,10 +498,7 @@ fn write8(&self, value: u8, offset: usize)
     where
         Self: IoKnownSize + IoCapable<u8>,
     {
-        let address = self.io_addr_assert::<u8>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_write(value, address) }
+        self.write(offset.set(value))
     }
 
     /// Infallible 16-bit write with compile-time bounds check.
@@ -531,10 +507,7 @@ fn write16(&self, value: u16, offset: usize)
     where
         Self: IoKnownSize + IoCapable<u16>,
     {
-        let address = self.io_addr_assert::<u16>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_write(value, address) }
+        self.write(offset.set(value))
     }
 
     /// Infallible 32-bit write with compile-time bounds check.
@@ -543,10 +516,7 @@ fn write32(&self, value: u32, offset: usize)
     where
         Self: IoKnownSize + IoCapable<u32>,
     {
-        let address = self.io_addr_assert::<u32>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_write(value, address) }
+        self.write(offset.set(value))
     }
 
     /// Infallible 64-bit write with compile-time bounds check.
@@ -555,10 +525,7 @@ fn write64(&self, value: u64, offset: usize)
     where
         Self: IoKnownSize + IoCapable<u64>,
     {
-        let address = self.io_addr_assert::<u64>(offset);
-
-        // SAFETY: `address` has been validated by `io_addr_assert`.
-        unsafe { self.io_write(value, address) }
+        self.write(offset.set(value))
     }
 
     /// Generic fallible read with runtime bounds check.

-- 
2.53.0