[PATCH v2 05/14] rust: hrtimer: implement `TimerPointer` for `Arc`

Andreas Hindborg posted 14 patches 2 months, 1 week ago
There is a newer version of this series
[PATCH v2 05/14] rust: hrtimer: implement `TimerPointer` for `Arc`
Posted by Andreas Hindborg 2 months, 1 week ago
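Allow the use of intrusive `hrtimer` fields in structs that are managed by an
`Arc`. To support this, add `Timer::raw_cancel()` and `HasTimer::schedule()`,
give `TimerHandle` a `cancel()` method, and add a module-level example.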

Signed-off-by: Andreas Hindborg <a.hindborg@kernel.org>
---
 rust/kernel/hrtimer.rs     | 102 ++++++++++++++++++++++++++++++++++++-
 rust/kernel/hrtimer/arc.rs |  87 +++++++++++++++++++++++++++++++
 2 files changed, 188 insertions(+), 1 deletion(-)
 create mode 100644 rust/kernel/hrtimer/arc.rs

diff --git a/rust/kernel/hrtimer.rs b/rust/kernel/hrtimer.rs
index 5c92afd8eb2c..fd1520ba9fba 100644
--- a/rust/kernel/hrtimer.rs
+++ b/rust/kernel/hrtimer.rs
@@ -4,6 +4,64 @@
 //!
 //! Allows scheduling timer callbacks without doing allocations at the time of
 //! scheduling. For now, only one timer per type is allowed.
+//!
+//! # Examples
+//!
+//! ```
+//! use kernel::{
+//!     hrtimer::{Timer, TimerCallback, TimerPointer},
+//!     impl_has_timer, new_condvar, new_mutex,
+//!     prelude::*,
+//!     sync::{Arc, CondVar, Mutex},
+//!     time::Ktime,
+//! };
+//!
+//! #[pin_data]
+//! struct ArcIntrusiveTimer {
+//!     #[pin]
+//!     timer: Timer<Self>,
+//!     #[pin]
+//!     flag: Mutex<bool>,
+//!     #[pin]
+//!     cond: CondVar,
+//! }
+//!
+//! impl ArcIntrusiveTimer {
+//!     fn new() -> impl PinInit<Self, kernel::error::Error> {
+//!         try_pin_init!(Self {
+//!             timer <- Timer::new(),
+//!             flag <- new_mutex!(false),
+//!             cond <- new_condvar!(),
+//!         })
+//!     }
+//! }
+//!
+//! impl TimerCallback for ArcIntrusiveTimer {
+//!     type CallbackTarget<'a> = Arc<Self>;
+//!
+//!     fn run(this: Self::CallbackTarget<'_>) {
+//!         pr_info!("Timer called\n");
+//!         *this.flag.lock() = true;
+//!         this.cond.notify_all();
+//!     }
+//! }
+//!
+//! impl_has_timer! {
+//!     impl HasTimer<Self> for ArcIntrusiveTimer { self.timer }
+//! }
+//!
+//! // Allocate the timer and schedule it through a clone of the `Arc`.
+//! let has_timer = Arc::pin_init(ArcIntrusiveTimer::new(), GFP_KERNEL)?;
+//! let _handle = has_timer.clone().schedule(Ktime::from_ns(200_000_000));
+//! let mut guard = has_timer.flag.lock();
+//!
+//! while !*guard {
+//!     has_timer.cond.wait(&mut guard);
+//! }
+//!
+//! pr_info!("Flag raised\n");
+//! # Ok::<(), kernel::error::Error>(())
+//! ```
 
 use crate::{init::PinInit, prelude::*, time::Ktime, types::Opaque};
 use core::marker::PhantomData;
@@ -72,6 +130,25 @@ unsafe fn raw_get(ptr: *const Self) -> *mut bindings::hrtimer {
         // allocation of at least the size of `Self`.
         unsafe { Opaque::raw_get(core::ptr::addr_of!((*ptr).timer)) }
     }
+
+    /// Cancel an initialized and potentially armed timer.
+    ///
+    /// Returns `true` if the timer was active when it was cancelled. If the
+    /// timer handler is running, this will block until the handler has finished.
+    ///
+    /// # Safety
+    ///
+    /// `self_ptr` must point to a valid `Self`.
+    unsafe fn raw_cancel(self_ptr: *const Self) -> bool {
+        // SAFETY: By the safety requirements of this function, `self_ptr`
+        // points to a valid `Self`.
+        let c_timer_ptr = unsafe { Timer::raw_get(self_ptr) };
+
+        // If the handler is running, `hrtimer_cancel` waits for it to finish.
+        // SAFETY: `c_timer_ptr` points to the initialized `hrtimer` of a valid
+        // `Self`. Synchronization is handled on the C side.
+        unsafe { bindings::hrtimer_cancel(c_timer_ptr) != 0 }
+    }
 }
 
 /// Implemented by pointer types that point to structs that embed a [`Timer`].
@@ -139,7 +216,11 @@ fn run(this: Self::CallbackTarget<'_>)
 /// When dropped, the timer represented by this handle must be cancelled, if it
 /// is armed. If the timer handler is running when the handle is dropped, the
 /// drop method must wait for the handler to finish before returning.
-pub unsafe trait TimerHandle {}
+pub unsafe trait TimerHandle {
+    /// Cancel the timer if it is armed, blocking until a running handler has
+    /// finished. Returns `true` if the timer was active when cancelled.
+    fn cancel(&mut self) -> bool;
+}
 
 /// Implemented by structs that contain timer nodes.
 ///
@@ -196,6 +277,30 @@ unsafe fn c_timer_ptr(self_ptr: *const Self) -> *const bindings::hrtimer {
         // SAFETY: timer_ptr points to an allocation of at least `Timer` size.
         unsafe { Timer::raw_get(timer_ptr) }
     }
+
+    /// Schedule the timer contained in the `Self` pointed to by `self_ptr` to
+    /// expire `expires` after the current time. If the timer is already
+    /// scheduled, it is removed and re-inserted with the new expiry.
+    ///
+    /// # Safety
+    ///
+    /// `self_ptr` must point to a valid `Self`.
+    unsafe fn schedule(self_ptr: *const Self, expires: Ktime) {
+        // NOTE(review): `HRTIMER_MODE_REL` is not a `_SOFT` mode, so on
+        // non-PREEMPT_RT kernels the callback presumably runs in hard IRQ
+        // context and must not sleep; the module example takes a `Mutex` there.
+        // SAFETY: By the safety requirements of this function, `self_ptr`
+        // points to a valid `Self`, so `c_timer_ptr` returns a pointer to its
+        // initialized `hrtimer`. Synchronization is handled on the C side.
+        unsafe {
+            bindings::hrtimer_start_range_ns(
+                Self::c_timer_ptr(self_ptr).cast_mut(),
+                expires.to_ns(),
+                0,
+                bindings::hrtimer_mode_HRTIMER_MODE_REL,
+            );
+        }
+    }
 }
 
 /// Use to implement the [`HasTimer<T>`] trait.
@@ -229,3 +327,5 @@ unsafe fn raw_get_timer(ptr: *const Self) ->
         }
     }
 }
+
+mod arc;
diff --git a/rust/kernel/hrtimer/arc.rs b/rust/kernel/hrtimer/arc.rs
new file mode 100644
index 000000000000..80f6c20f95a9
--- /dev/null
+++ b/rust/kernel/hrtimer/arc.rs
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: GPL-2.0
+
+use super::HasTimer;
+use super::RawTimerCallback;
+use super::Timer;
+use super::TimerCallback;
+use super::TimerHandle;
+use super::TimerPointer;
+use crate::sync::Arc;
+use crate::time::Ktime;
+
+/// A handle for an `Arc<U>`, where `U: HasTimer<U>`, returned by a call to
+/// [`TimerPointer::schedule`]. Dropping the handle cancels the timer.
+pub struct ArcTimerHandle<U>
+where
+    U: HasTimer<U>,
+{
+    pub(crate) inner: Arc<U>,
+}
+
+// SAFETY: The timer is cancelled when the handle is dropped, in the `Drop`
+// implementation below.
+unsafe impl<U> TimerHandle for ArcTimerHandle<U>
+where
+    U: HasTimer<U>,
+{
+    fn cancel(&mut self) -> bool {
+        // `self.inner` owns a reference to the `U`, keeping it alive here.
+        let container = self.inner.as_ptr();
+        // SAFETY: `container` was derived from a live `Arc<U>`, so it points
+        // to a valid `U`.
+        let timer = unsafe { U::raw_get_timer(container) };
+
+        // SAFETY: `timer` points into the valid `U` behind `container`, so it
+        // refers to a valid `Timer<U>`.
+        unsafe { Timer::<U>::raw_cancel(timer) }
+    }
+}
+
+impl<U: HasTimer<U>> Drop for ArcTimerHandle<U> {
+    fn drop(&mut self) {
+        // Required by the `TimerHandle` safety contract: cancel the timer,
+        // blocking until a running handler has finished. Whether the timer
+        // was still armed does not matter here.
+        let _ = self.cancel();
+    }
+}
+
+impl<U> TimerPointer for Arc<U>
+where
+    U: Send + Sync + HasTimer<U>,
+    U: for<'a> TimerCallback<CallbackTarget<'a> = Self>,
+{
+    type TimerHandle = ArcTimerHandle<U>;
+
+    fn schedule(self, expires: Ktime) -> ArcTimerHandle<U> {
+        let container = self.as_ptr();
+        // SAFETY: `container` comes from the live `Arc<U>` in `self`, so it
+        // points to a valid `U`. The `Arc` is then moved into the returned
+        // handle, which keeps the `U` alive.
+        unsafe { U::schedule(container, expires) };
+        ArcTimerHandle { inner: self }
+    }
+}
+
+impl<U> RawTimerCallback for Arc<U>
+where
+    U: HasTimer<U>,
+    U: for<'a> TimerCallback<CallbackTarget<'a> = Self>,
+{
+    unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart {
+        // `Timer<U>` is `repr(transparent)`, so `ptr` also points to the `Timer<U>`.
+        let timer_ptr = ptr.cast::<Timer<U>>();
+
+        // SAFETY: By C API contract `ptr` is the pointer we passed when
+        // queuing the timer, so it is a `Timer<U>` embedded in a `U`.
+        let data_ptr = unsafe { U::timer_container_of(timer_ptr) };
+
+        // SAFETY: `data_ptr` points to the `U` that was used to queue the
+        // timer. This `U` is contained in an `Arc` that is kept alive by the
+        // `ArcTimerHandle` returned from `schedule`.
+        let receiver = unsafe { Arc::clone_from_raw(data_ptr) };
+        U::run(receiver);
+
+        bindings::hrtimer_restart_HRTIMER_NORESTART
+    }
+}
-- 
2.46.0