[PATCH 2/6] rust: hrtimer: Add HrTimerCallbackContext and ::forward()

Lyude Paul posted 6 patches 1 day, 14 hours ago
[PATCH 2/6] rust: hrtimer: Add HrTimerCallbackContext and ::forward()
Posted by Lyude Paul 1 day, 14 hours ago
With Linux's hrtimer API, certain functions require we either acquire
proper locking to call specific methods - or that we call said methods from
the context of the timer callback. hrtimer_forward() is one of these
functions, so we start by adding a new HrTimerCallbackContext type which
provides a way of calling these methods that is inaccessible outside of
hrtimer callbacks.

Signed-off-by: Lyude Paul <lyude@redhat.com>
---
 rust/kernel/time/hrtimer.rs         | 50 +++++++++++++++++++++++++++--
 rust/kernel/time/hrtimer/arc.rs     |  7 +++-
 rust/kernel/time/hrtimer/pin.rs     |  7 +++-
 rust/kernel/time/hrtimer/pin_mut.rs |  9 ++++--
 rust/kernel/time/hrtimer/tbox.rs    |  7 +++-
 5 files changed, 73 insertions(+), 7 deletions(-)

diff --git a/rust/kernel/time/hrtimer.rs b/rust/kernel/time/hrtimer.rs
index 4fc49f1931259..c92b10524f892 100644
--- a/rust/kernel/time/hrtimer.rs
+++ b/rust/kernel/time/hrtimer.rs
@@ -69,7 +69,7 @@
 
 use super::ClockId;
 use crate::{init::PinInit, prelude::*, time::Ktime, types::Opaque};
-use core::marker::PhantomData;
+use core::{marker::PhantomData, ptr::NonNull};
 
 /// A timer backed by a C `struct hrtimer`.
 ///
@@ -279,7 +279,10 @@ pub trait HrTimerCallback {
     type Pointer<'a>: RawHrTimerCallback;
 
     /// Called by the timer logic when the timer fires.
-    fn run(this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>) -> HrTimerRestart
+    fn run<T>(
+        this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>,
+        ctx: HrTimerCallbackContext<'_, T>
+    ) -> HrTimerRestart
     where
         Self: Sized;
 }
@@ -470,6 +473,49 @@ fn into_c(self) -> bindings::hrtimer_mode {
     }
 }
 
+/// Privileged smart-pointer for a [`HrTimer`] callback context.
+///
+/// This provides access to various methods for a [`HrTimer`] which can only be safely called within
+/// its callback.
+///
+/// # Invariants
+///
+/// * The existence of this type means the caller is currently within the callback for a
+///   [`HrTimer`].
+/// * `self.0` always points to a live instance of [`HrTimer<T>`].
+pub struct HrTimerCallbackContext<'a, T>(NonNull<HrTimer<T>>, PhantomData<&'a ()>);
+
+impl<'a, T> HrTimerCallbackContext<'a, T> {
+    /// Create a new [`HrTimerCallbackContext`].
+    ///
+    /// # Safety
+    ///
+    /// This function relies on the caller being within the context of a timer callback, so it must
+    /// not be used anywhere except for within implementations of [`RawHrTimerCallback::run`]. The
+    /// caller promises that `timer` points to a valid initialized instance of
+    /// [`HrTimer<T>`].
+    pub(crate) unsafe fn from_raw(timer: *mut HrTimer<T>) -> Self {
+        // SAFETY: The caller guarantees `timer` points to a valid, initialized timer, so it is
+        // non-null.
+        Self(unsafe { NonNull::new_unchecked(timer) }, PhantomData)
+    }
+
+    /// Get the raw `bindings::hrtimer` pointer for this [`HrTimerCallbackContext`].
+    pub(crate) fn raw_get_timer(&self) -> *mut bindings::hrtimer {
+        // SAFETY: By our type invariants, `self.0` always points to a valid [`HrTimer<T>`].
+        unsafe { HrTimer::raw_get(self.0.as_ptr()) }
+    }
+
+    /// Forward the timer expiry so it will expire in the future.
+    ///
+    /// Note that this does not requeue the timer, it simply updates its expiry value. It returns
+    /// the number of overruns that have occurred as a result of the expiry change.
+    pub fn forward(&self, now: Ktime, interval: Ktime) -> u64 {
+        // SAFETY: By our type invariants, we are in the timer's callback and the pointer is valid.
+        unsafe { bindings::hrtimer_forward(self.raw_get_timer(), now.to_ns(), interval.to_ns()) }
+    }
+}
+
 /// Use to implement the [`HasHrTimer<T>`] trait.
 ///
 /// See [`module`] documentation for an example.
diff --git a/rust/kernel/time/hrtimer/arc.rs b/rust/kernel/time/hrtimer/arc.rs
index 4a984d85b4a10..7dd9f46a0720d 100644
--- a/rust/kernel/time/hrtimer/arc.rs
+++ b/rust/kernel/time/hrtimer/arc.rs
@@ -3,6 +3,7 @@
 use super::HasHrTimer;
 use super::HrTimer;
 use super::HrTimerCallback;
+use super::HrTimerCallbackContext;
 use super::HrTimerHandle;
 use super::HrTimerPointer;
 use super::RawHrTimerCallback;
@@ -95,6 +96,10 @@ impl<T> RawHrTimerCallback for Arc<T>
         //    allocation from other `Arc` clones.
         let receiver = unsafe { ArcBorrow::from_raw(data_ptr) };
 
-        T::run(receiver).into_c()
+        // SAFETY: By C API contract `ptr` is the pointer we passed when queuing the timer, so it is
+        // a `HrTimer<T>` embedded in a `T`.
+        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };
+
+        T::run(receiver, context).into_c()
     }
 }
diff --git a/rust/kernel/time/hrtimer/pin.rs b/rust/kernel/time/hrtimer/pin.rs
index f760db265c7b5..a8e1b76bf0736 100644
--- a/rust/kernel/time/hrtimer/pin.rs
+++ b/rust/kernel/time/hrtimer/pin.rs
@@ -3,6 +3,7 @@
 use super::HasHrTimer;
 use super::HrTimer;
 use super::HrTimerCallback;
+use super::HrTimerCallbackContext;
 use super::HrTimerHandle;
 use super::RawHrTimerCallback;
 use super::UnsafeHrTimerPointer;
@@ -99,6 +100,10 @@ impl<'a, T> RawHrTimerCallback for Pin<&'a T>
         // here.
         let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) };
 
-        T::run(receiver_pin).into_c()
+        // SAFETY: By C API contract `ptr` is the pointer we passed when queuing the timer, so it is
+        // a `HrTimer<T>` embedded in a `T`.
+        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };
+
+        T::run(receiver_pin, context).into_c()
     }
 }
diff --git a/rust/kernel/time/hrtimer/pin_mut.rs b/rust/kernel/time/hrtimer/pin_mut.rs
index 90c0351d62e4b..2dd2ebfd7efaf 100644
--- a/rust/kernel/time/hrtimer/pin_mut.rs
+++ b/rust/kernel/time/hrtimer/pin_mut.rs
@@ -1,7 +1,8 @@
 // SPDX-License-Identifier: GPL-2.0
 
 use super::{
-    HasHrTimer, HrTimer, HrTimerCallback, HrTimerHandle, RawHrTimerCallback, UnsafeHrTimerPointer,
+    HasHrTimer, HrTimer, HrTimerCallback, HrTimerCallbackContext, HrTimerHandle, RawHrTimerCallback,
+    UnsafeHrTimerPointer,
 };
 use crate::time::Ktime;
 use core::{marker::PhantomData, pin::Pin, ptr::NonNull};
@@ -103,6 +104,10 @@ impl<'a, T> RawHrTimerCallback for Pin<&'a mut T>
         // here.
         let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) };
 
-        T::run(receiver_pin).into_c()
+        // SAFETY: By C API contract `ptr` is the pointer we passed when queuing the timer, so it is
+        // a `HrTimer<T>` embedded in a `T`.
+        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };
+
+        T::run(receiver_pin, context).into_c()
     }
 }
diff --git a/rust/kernel/time/hrtimer/tbox.rs b/rust/kernel/time/hrtimer/tbox.rs
index 2071cae072342..e3214f7173beb 100644
--- a/rust/kernel/time/hrtimer/tbox.rs
+++ b/rust/kernel/time/hrtimer/tbox.rs
@@ -3,6 +3,7 @@
 use super::HasHrTimer;
 use super::HrTimer;
 use super::HrTimerCallback;
+use super::HrTimerCallbackContext;
 use super::HrTimerHandle;
 use super::HrTimerPointer;
 use super::RawHrTimerCallback;
@@ -115,6 +116,10 @@ impl<T, A> RawHrTimerCallback for Pin<Box<T, A>>
         //   `data_ptr` exist.
         let data_mut_ref = unsafe { Pin::new_unchecked(&mut *data_ptr) };
 
-        T::run(data_mut_ref).into_c()
+        // SAFETY: By C API contract `ptr` is the pointer we passed when queuing the timer, so it is
+        // a `HrTimer<T>` embedded in a `T`.
+        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };
+
+        T::run(data_mut_ref, context).into_c()
     }
 }
-- 
2.48.1
Re: [PATCH 2/6] rust: hrtimer: Add HrTimerCallbackContext and ::forward()
Posted by Thomas Gleixner 1 day, 1 hour ago
On Wed, Apr 02 2025 at 17:40, Lyude Paul wrote:

> With Linux's hrtimer API, certain functions require we either acquire
> proper locking to call specific methods - or that we call said methods from
> the context of the timer callback. hrtimer_forward() is one of these
> functions, so we start by adding a new HrTimerCallbackContext type which
> provides a way of calling these methods that is inaccessible outside of
> hrtimer callbacks.

Just for completeness:

When hrtimer_forward() is invoked from non-callback context, then there
is not necessarily a locking requirement. The caller has to make sure
that the timer is neither armed, nor running the callback.

     hrtimer_cancel();
     hrtimer_forward();

is a legitimate sequence, if there is no way that the timer is re-armed
concurrently. That works just fine without locks.

That said, I really like that callback context concept you are doing!

Thanks,

        tglx