Next Generation WASM Microkernel Operating System

feat(kasync): arbitrary task metadata. Less safe but less messy (#527)

authored by

Jonas Kruckenberg and committed by
GitHub
6417ba07 844247a5

+157 -53
+26 -2
libs/kasync/src/executor.rs
··· 162 162 /// # Errors 163 163 /// 164 164 /// Returns [`AllocError`] when allocation of the task fails. 165 - pub fn try_spawn<F>(&'static self, future: F) -> Result<JoinHandle<F::Output>, SpawnError> 165 + #[track_caller] 166 + pub fn try_spawn<F>(&'static self, future: F) -> Result<JoinHandle<F::Output, ()>, SpawnError> 166 167 where 167 168 F: Future + Send + 'static, 168 169 F::Output: Send + 'static, 169 170 { 170 171 self.build_task().try_spawn(future) 172 + } 173 + 174 + /// Attempt spawn this [`Future`] with the provided metadata onto this executor. 175 + /// 176 + /// This method returns a [`TaskRef`] which can be used to spawn it onto an [`crate::executor::Executor`] 177 + /// and a [`JoinHandle`] which can be used to await the futures output as well as control some aspects 178 + /// of its runtime behaviour (such as cancelling it). 179 + /// 180 + /// # Errors 181 + /// 182 + /// Returns [`AllocError`] when allocation of the task fails. 183 + #[track_caller] 184 + pub fn try_spawn_with_metadata<F, M>( 185 + &'static self, 186 + future: F, 187 + metadata: M, 188 + ) -> Result<JoinHandle<F::Output, M>, SpawnError> 189 + where 190 + F: Future + Send, 191 + F::Output: Send, 192 + M: Send, 193 + { 194 + self.build_task().try_spawn_with_metadata(future, metadata) 171 195 } 172 196 } 173 197 ··· 564 588 static ref EXEC: Executor = Executor::new().unwrap(); 565 589 } 566 590 567 - let (tx, rx) = loom::sync::mpsc::channel::<JoinHandle<u32>>(); 591 + let (tx, rx) = loom::sync::mpsc::channel::<JoinHandle<u32, ()>>(); 568 592 569 593 let h0 = loom::thread::spawn(move || { 570 594 let tid = loom::thread::current().id();
+56 -16
libs/kasync/src/task.rs
··· 81 81 pub struct TaskRef(NonNull<Header>); 82 82 83 83 #[repr(C)] 84 - struct Task<F: Future>(CachePadded<TaskInner<F>>); 84 + struct Task<F: Future, M>(CachePadded<TaskInner<F, M>>); 85 85 86 86 #[repr(C)] 87 - struct TaskInner<F: Future> { 87 + struct TaskInner<F: Future, M> { 88 88 /// This must be the first field of the `Task` struct! 89 - header: Header, 89 + header_and_metadata: HeaderAndMetadata<M>, 90 90 91 91 /// The future that the task is running. 92 92 /// ··· 143 143 join_waker: UnsafeCell<Option<Waker>>, 144 144 } 145 145 146 + #[repr(C)] 147 + struct HeaderAndMetadata<M> { 148 + /// This must be the first field of the `HeaderAndMetadata` struct! 149 + header: Header, 150 + metadata: M, 151 + } 152 + 146 153 /// The current lifecycle stage of the future. Either the future itself or its output. 147 154 #[repr(C)] // https://github.com/rust-lang/miri/issues/3780 148 155 enum Stage<F: Future> { ··· 208 215 self.header().id 209 216 } 210 217 218 + /// # Safety 219 + /// 220 + /// The caller must ensure the generic argument matches the metadata type this task got created with. 221 + pub unsafe fn metadata<M>(&self) -> &M { 222 + // Safety: ensured by caller 223 + unsafe { &self.0.cast::<HeaderAndMetadata<M>>().as_ref().metadata } 224 + } 225 + 211 226 /// Returns `true` when this task has run to completion. 212 227 pub fn is_complete(&self) -> bool { 213 228 self.state() ··· 347 362 // ===== private methods ===== 348 363 349 364 #[track_caller] 350 - fn new_allocated<F>(task: Box<Task<F>>) -> (Self, JoinHandle<F::Output>) 365 + fn new_allocated<F, M>(task: Box<Task<F, M>>) -> (Self, JoinHandle<F::Output, M>) 351 366 where 352 367 F: Future, 353 368 { ··· 547 562 548 563 // === impl Task === 549 564 550 - impl<F: Future> Task<F> { 565 + impl<F: Future, M> Task<F, M> { 551 566 const TASK_VTABLE: VTable = VTable { 552 567 poll: Self::poll, 553 568 poll_join: Self::poll_join, ··· 555 570 }; 556 571 557 572 loom_const_fn! { 558 - pub const fn new(future: F, task_id: Id, span: tracing::Span) -> Self { 573 + pub const fn new(future: F, task_id: Id, span: tracing::Span, metadata: M) -> Self { 559 574 let inner = TaskInner { 560 - header: Header { 561 - state: State::new(), 562 - vtable: &Self::TASK_VTABLE, 563 - id: task_id, 564 - run_queue_links: mpsc_queue::Links::new(), 565 - span, 566 - scheduler: UnsafeCell::new(None) 575 + header_and_metadata: HeaderAndMetadata { 576 + header: Header { 577 + state: State::new(), 578 + vtable: &Self::TASK_VTABLE, 579 + id: task_id, 580 + run_queue_links: mpsc_queue::Links::new(), 581 + span, 582 + scheduler: UnsafeCell::new(None) 583 + }, 584 + metadata 567 585 }, 568 586 stage: UnsafeCell::new(Stage::Pending(future)), 569 587 join_waker: UnsafeCell::new(None), ··· 802 820 } 803 821 804 822 fn id(&self) -> &Id { 805 - &self.0.0.header.id 823 + &self.0.0.header_and_metadata.header.id 806 824 } 807 825 fn state(&self) -> &State { 808 - &self.0.0.header.state 826 + &self.0.0.header_and_metadata.header.state 809 827 } 810 828 #[inline] 811 829 fn span(&self) -> &tracing::Span { 812 - &self.0.0.header.span 830 + &self.0.0.header_and_metadata.header.span 813 831 } 814 832 } 815 833 ··· 905 923 .cast() 906 924 } 907 925 } 926 + 927 + #[cfg(test)] 928 + mod tests { 929 + use alloc::boxed::Box; 930 + 931 + use crate::loom; 932 + use crate::task::{Id, Task, TaskRef}; 933 + 934 + #[test] 935 + fn metadata() { 936 + loom::model(|| { 937 + let (t1, _) = TaskRef::new_allocated(Box::new(Task::new( 938 + async {}, 939 + Id::next(), 940 + tracing::Span::none(), 941 + 42usize, 942 + ))); 943 + 944 + assert_eq!(unsafe { *t1.metadata::<usize>() }, 42); 945 + }); 946 + } 947 + }
+35 -4
libs/kasync/src/task/builder.rs
··· 61 61 62 62 #[inline] 63 63 #[track_caller] 64 - fn build<F>(&self, future: F) -> Task<F> 64 + fn build<F, M>(&self, future: F, metadata: M) -> Task<F, M> 65 65 where 66 66 F: Future + Send, 67 67 F::Output: Send, 68 + M: Send, 68 69 { 69 70 let id = Id::next(); 70 71 ··· 80 81 loc.col = loc.column(), 81 82 ); 82 83 83 - Task::new(future, id, span) 84 + Task::new(future, id, span, metadata) 84 85 } 85 86 86 87 /// Attempt spawn this [`Future`] onto the executor. ··· 94 95 /// Returns [`AllocError`] when allocation of the task fails. 95 96 #[inline] 96 97 #[track_caller] 97 - pub fn try_spawn<F>(&self, future: F) -> Result<JoinHandle<F::Output>, SpawnError> 98 + pub fn try_spawn<F>(&self, future: F) -> Result<JoinHandle<F::Output, ()>, SpawnError> 99 + where 100 + F: Future + Send, 101 + F::Output: Send, 102 + { 103 + let task = self.build(future, ()); 104 + let task = Box::try_new(task)?; 105 + let (task, join) = TaskRef::new_allocated(task); 106 + 107 + (self.schedule)(task)?; 108 + 109 + Ok(join) 110 + } 111 + 112 + /// Attempt spawn this [`Future`] with the provided metadata onto the executor. 113 + /// 114 + /// This method returns a [`TaskRef`] which can be used to spawn it onto an [`crate::executor::Executor`] 115 + /// and a [`JoinHandle`] which can be used to await the futures output as well as control some aspects 116 + /// of its runtime behaviour (such as cancelling it). 117 + /// 118 + /// # Errors 119 + /// 120 + /// Returns [`AllocError`] when allocation of the task fails. 121 + #[inline] 122 + #[track_caller] 123 + pub fn try_spawn_with_metadata<F, M>( 124 + &self, 125 + future: F, 126 + metadata: M, 127 + ) -> Result<JoinHandle<F::Output, M>, SpawnError> 98 128 where 99 129 F: Future + Send, 100 130 F::Output: Send, 131 + M: Send, 101 132 { 102 - let task = self.build(future); 133 + let task = self.build(future, metadata); 103 134 let task = Box::try_new(task)?; 104 135 let (task, join) = TaskRef::new_allocated(task); 105 136
+40 -31
libs/kasync/src/task/join_handle.rs
··· 18 18 19 19 use crate::task::{Id, TaskRef}; 20 20 21 - pub struct JoinHandle<T> { 21 + pub struct JoinHandle<T, M> { 22 22 state: JoinHandleState, 23 23 id: Id, 24 - _p: PhantomData<T>, 24 + _p: PhantomData<(T, M)>, 25 25 } 26 - static_assertions::assert_impl_all!(JoinHandle<()>: Send); 26 + static_assertions::assert_impl_all!(JoinHandle<(), ()>: Send); 27 27 28 28 #[derive(Debug)] 29 29 enum JoinHandleState { ··· 46 46 47 47 // === impl JoinHandle === 48 48 49 - impl<T> UnwindSafe for JoinHandle<T> {} 49 + impl<T, M> UnwindSafe for JoinHandle<T, M> {} 50 50 51 - impl<T> RefUnwindSafe for JoinHandle<T> {} 51 + impl<T, M> RefUnwindSafe for JoinHandle<T, M> {} 52 52 53 - impl<T> Unpin for JoinHandle<T> {} 53 + impl<T, M> Unpin for JoinHandle<T, M> {} 54 54 55 - impl<T> Drop for JoinHandle<T> { 55 + impl<T, M> Drop for JoinHandle<T, M> { 56 56 fn drop(&mut self) { 57 57 // if the JoinHandle has not already been consumed, clear the join 58 58 // handle flag on the task. ··· 75 75 } 76 76 } 77 77 78 - impl<T> fmt::Debug for JoinHandle<T> 78 + impl<T, M> fmt::Debug for JoinHandle<T, M> 79 79 where 80 80 T: fmt::Debug, 81 81 { ··· 88 88 } 89 89 } 90 90 91 - impl<T> Future for JoinHandle<T> { 91 + impl<T, M> Future for JoinHandle<T, M> { 92 92 type Output = Result<T, JoinError<T>>; 93 93 94 94 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { ··· 121 121 122 122 // ==== PartialEq impls for JoinHandle/TaskRef ==== 123 123 124 - impl<T> PartialEq<TaskRef> for JoinHandle<T> { 124 + impl<T, M> PartialEq<TaskRef> for JoinHandle<T, M> { 125 125 fn eq(&self, other: &TaskRef) -> bool { 126 126 match self.state { 127 127 JoinHandleState::Task(ref task) => task == other, ··· 130 130 } 131 131 } 132 132 133 - impl<T> PartialEq<&'_ TaskRef> for JoinHandle<T> { 133 + impl<T, M> PartialEq<&'_ TaskRef> for JoinHandle<T, M> { 134 134 fn eq(&self, other: &&TaskRef) -> bool { 135 135 match self.state { 136 136 JoinHandleState::Task(ref task) => task == *other, ··· 139 139 } 140 140 } 141 141 142 - impl<T> PartialEq<JoinHandle<T>> for TaskRef { 143 - fn eq(&self, other: &JoinHandle<T>) -> bool { 142 + impl<T, M> PartialEq<JoinHandle<T, M>> for TaskRef { 143 + fn eq(&self, other: &JoinHandle<T, M>) -> bool { 144 144 match other.state { 145 145 JoinHandleState::Task(ref task) => self == task, 146 146 _ => false, ··· 148 148 } 149 149 } 150 150 151 - impl<T> PartialEq<&'_ JoinHandle<T>> for TaskRef { 152 - fn eq(&self, other: &&JoinHandle<T>) -> bool { 151 + impl<T, M> PartialEq<&'_ JoinHandle<T, M>> for TaskRef { 152 + fn eq(&self, other: &&JoinHandle<T, M>) -> bool { 153 153 match other.state { 154 154 JoinHandleState::Task(ref task) => self == task, 155 155 _ => false, ··· 159 159 160 160 // ==== PartialEq impls for JoinHandle/Id ==== 161 161 162 - impl<T> PartialEq<Id> for JoinHandle<T> { 162 + impl<T, M> PartialEq<Id> for JoinHandle<T, M> { 163 163 #[inline] 164 164 fn eq(&self, other: &Id) -> bool { 165 165 self.id == *other 166 166 } 167 167 } 168 168 169 - impl<T> PartialEq<JoinHandle<T>> for Id { 169 + impl<T, M> PartialEq<JoinHandle<T, M>> for Id { 170 170 #[inline] 171 - fn eq(&self, other: &JoinHandle<T>) -> bool { 171 + fn eq(&self, other: &JoinHandle<T, M>) -> bool { 172 172 *self == other.id 173 173 } 174 174 } 175 175 176 - impl<T> PartialEq<&'_ JoinHandle<T>> for Id { 176 + impl<T, M> PartialEq<&'_ JoinHandle<T, M>> for Id { 177 177 #[inline] 178 - fn eq(&self, other: &&JoinHandle<T>) -> bool { 178 + fn eq(&self, other: &&JoinHandle<T, M>) -> bool { 179 179 *self == other.id 180 180 } 181 181 } 182 182 183 - impl<T> JoinHandle<T> { 184 - pub(crate) fn new(task: TaskRef) -> Self { 185 - task.state().create_join_handle(); 186 - 187 - Self { 188 - id: task.id(), 189 - state: JoinHandleState::Task(task), 190 - _p: PhantomData, 191 - } 192 - } 193 - 183 + impl<T, M> JoinHandle<T, M> { 194 184 /// Cancels the task associated with the handle. 195 185 /// 196 186 /// Awaiting a cancelled task might complete as usual if the task was already completed at ··· 216 206 // `Future` impl for `JoinHandle` completed, and the task has 217 207 // _definitely_ completed. 218 208 _ => true, 209 + } 210 + } 211 + 212 + pub fn metadata(&self) -> Option<&M> { 213 + if let JoinHandleState::Task(ref task) = self.state { 214 + // Safety: we know this generic is correct through construction 215 + Some(unsafe { task.metadata::<M>() }) 216 + } else { 217 + None 218 + } 219 + } 220 + 221 + pub(crate) fn new(task: TaskRef) -> Self { 222 + task.state().create_join_handle(); 223 + 224 + Self { 225 + id: task.id(), 226 + state: JoinHandleState::Task(task), 227 + _p: PhantomData, 219 228 } 220 229 } 221 230 }