platform/thread_pool.rs
1// SPDX-FileCopyrightText: 2025 Jens Pitkänen <jens.pitkanen@helsinki.fi>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5//! Thread pool for running tasks on other threads.
6//!
//! [`ThreadPool`] implements a FIFO task queue where the tasks are executed on
//! other threads, the number of threads depending on how the [`ThreadPool`] is
//! constructed. As a fallback, single-threaded platforms are supported by
//! simply running the task in [`ThreadPool::join_task`].
11//!
//! This module doesn't do any allocation, and isn't very usable on its own:
//! it's intended to be used alongside platform-provided threading functions, by
//! the engine, to construct multithreading utilities.
15
16use core::{marker::PhantomData, mem::transmute};
17
18use crate::{
19 channel::{Receiver, Sender},
20 Box,
21};
22
/// Handle to a running or waiting task on a [`ThreadPool`].
///
/// These should be passed into [`ThreadPool::join_task`] in the same order as
/// they were created with [`ThreadPool::spawn_task`].
#[derive(Debug)]
pub struct TaskHandle<T: 'static> {
    /// Index of the pool thread (its [`ThreadState`]) the task was sent to.
    thread_index: usize,
    /// Value of that thread's sent-task counter at the time this task was
    /// spawned. [`ThreadPool::join_task`] compares this against the thread's
    /// received-task counter to enforce the per-thread join order.
    task_position: u64,
    /// Carries the task's data type `T` without storing any data, so that
    /// [`ThreadPool::join_task`] can hand back a `Box<T>` of the right type.
    _type_holder: PhantomData<&'static T>,
}
33
/// Packets sent between threads to coordinate a [`ThreadPool`].
///
/// Each packet holds one type-erased task: a pointer to the task's data, the
/// function to call on it, and a proxy function that knows the real types.
pub struct TaskInFlight {
    /// Whether or not [`TaskInFlight::run`] has been run for this task.
    finished: bool,
    /// Type-erased pointer to the task's data. Extracted from: `Box<T>`.
    data: *mut (),
    /// Type-erased task function. Cast from: `fn(&mut T)`.
    func: *const (),
    /// Pass `self.func` and `self.data` in here to call the function with the
    /// right types.
    func_proxy: fn(func: *const (), data: *mut ()),
    /// Can be set to true by the processing thread to signal that the thread
    /// panicked. This will cause the join function to panic with "a thread in
    /// the thread pool panicked" when joining this task.
    thread_panicked: bool,
}
50
51impl TaskInFlight {
52 /// Process the task in this container. Returns false if the task has
53 /// already been ran, in which case this function does nothing.
54 pub fn run(&mut self) -> bool {
55 if !self.finished {
56 (self.func_proxy)(self.func, self.data);
57 self.finished = true;
58 true
59 } else {
60 false
61 }
62 }
63
64 /// Signals the thread pool that the thread responsible for running this
65 /// task panicked. This can be used to propagate the panic to the main
66 /// thread.
67 pub fn signal_panic(&mut self) {
68 self.thread_panicked = true;
69 }
70
71 /// Panics if the thread running this task has panicked, runs the task if
72 /// the task hasn't been ran and it didn't panic, and finally, returns the
73 /// data operated on by this task. Called by [`ThreadPool::join_task`].
74 ///
75 /// ### Safety
76 /// The type parameter `T` must match the original type parameter `T` of
77 /// [`ThreadPool::spawn_task`] exactly.
78 unsafe fn join<T>(mut self, run_if_not_finished: bool) -> Box<T> {
79 if self.thread_panicked {
80 panic!("a thread in the thread pool panicked");
81 }
82
83 if !self.finished && run_if_not_finished {
84 self.run();
85 }
86
87 // Safety: the *mut c_void was originally casted from a *mut T which in
88 // turn was from a Box<T>, so this pointer has already been guaranteed
89 // to live long enough. It is also not shared anywhere outside of this
90 // struct, so this is definitely a unique reference to the memory.
91 unsafe { Box::from_ptr(self.data as *mut T) }
92 }
93}
94
// Safety: the fields that prevent the automatic impl are the two raw pointers.
// `func` is cast from a plain `fn(&mut T)`, and `data` points to the T of the
// Box<T> given to `ThreadPool::spawn_task`.
// NOTE(review): this justification relies on that T being Sync, but no
// `T: Sync` bound is visible on `ThreadPool::spawn_task` in this file. Confirm
// where that requirement is enforced.
unsafe impl Sync for TaskInFlight {}
98
/// The sending half of a [`TaskChannel`].
pub type TaskSender = Sender<TaskInFlight>;
/// The receiving half of a [`TaskChannel`].
pub type TaskReceiver = Receiver<TaskInFlight>;
/// Channel used by [`ThreadPool`] for communicating with the processing
/// threads: a sender/receiver pair that carries [`TaskInFlight`] packets.
///
/// Passed into
/// [`Platform::spawn_pool_thread`](crate::Platform::spawn_pool_thread).
pub type TaskChannel = (TaskSender, TaskReceiver);
109
/// State held by [`ThreadPool`] for sending and receiving [`TaskInFlight`]s
/// between it and a thread.
///
/// Returned from
/// [`Platform::spawn_pool_thread`](crate::Platform::spawn_pool_thread),
/// multiple of these are used to create a [`ThreadPool`].
pub struct ThreadState {
    /// For sending tasks to the thread.
    sender: TaskSender,
    /// For getting task results back from the thread.
    receiver: TaskReceiver,
    /// The amount of tasks sent via `sender`. (Used for picking
    /// [`TaskHandle::task_position`] for send).
    sent_count: u64,
    /// The amount of tasks received via `receiver`. (Used for checking
    /// [`TaskHandle::task_position`] on recv).
    recv_count: u64,
}
128
129impl ThreadState {
130 /// Creates a new [`ThreadState`] from the relevant channel endpoints.
131 ///
132 /// `sender_to_thread` is used to send tasks to the thread, while
133 /// `receiver_from_thread` is used to receive finished tasks, so there
134 /// should be two channels for each thread.
135 ///
136 /// To implement a simple single-threaded [`ThreadPool`], the sender and
137 /// receiver of just one channel could be passed here, in which case
138 /// [`ThreadPool`] will run the task when joining that task in
139 /// [`ThreadPool::join_task`].
140 pub fn new(sender_to_thread: TaskSender, receiver_from_thread: TaskReceiver) -> ThreadState {
141 ThreadState {
142 sender: sender_to_thread,
143 receiver: receiver_from_thread,
144 sent_count: 0,
145 recv_count: 0,
146 }
147 }
148}
149
/// Thread pool for running compute-intensive tasks in parallel.
///
/// Note that the tasks are run in submission order (on multiple threads, if
/// available), so a task that e.g. blocks on a file read will prevent other
/// tasks from running.
pub struct ThreadPool {
    /// Index into `threads` of the thread that the next
    /// [`ThreadPool::spawn_task`] call sends its task to.
    next_thread_index: usize,
    /// Channel endpoints and task counters for each thread in the pool.
    threads: Box<[ThreadState]>,
}
159
160impl ThreadPool {
161 /// Creates a new [`ThreadPool`], returning None if the channels don't have
162 /// matching capacities.
163 pub fn new(threads: Box<[ThreadState]>) -> Option<ThreadPool> {
164 // Check that each channel has the same capacity
165 let mut capacity = None;
166 for thread in threads.iter() {
167 if let Some(capacity) = capacity {
168 if thread.receiver.capacity() != capacity || thread.sender.capacity() != capacity {
169 return None;
170 }
171 } else if thread.receiver.capacity() != thread.sender.capacity() {
172 return None;
173 } else {
174 capacity = Some(thread.receiver.capacity());
175 }
176 }
177
178 Some(ThreadPool {
179 next_thread_index: 0,
180 threads,
181 })
182 }
183
    /// Returns the number of threads in this thread pool, i.e. the number of
    /// [`ThreadState`]s the pool was created with.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
188
189 /// Returns the length of a task queue.
190 ///
191 /// In total, tasks can be spawned without joining up to this amount times
192 /// the thread count.
193 pub fn queue_len(&self) -> usize {
194 if let Some(thread) = self.threads.first() {
195 thread.receiver.capacity() // Checked in new() to match all other channels too
196 } else {
197 0
198 }
199 }
200
201 /// Returns true if the thread pool has any pending tasks in the queues.
202 pub fn has_pending(&self) -> bool {
203 self.threads
204 .iter()
205 .any(|thread| thread.recv_count != thread.sent_count)
206 }
207
    /// Resets the counter used to assign tasks to different threads.
    ///
    /// After calling this, the next [`ThreadPool::spawn_task`] is sent off to
    /// the first thread, instead of whichever value the counter is on now.
    /// The per-thread task counters are not touched by this.
    pub fn reset_thread_counter(&mut self) {
        self.next_thread_index = 0;
    }
215
216 /// Schedules the function to be ran on a thread in this pool, passing in
217 /// the data as an argument, if they fit in the task queue.
218 ///
219 /// The function passed in is only ever ran once. In a single-threaded
220 /// environment, it is ran when `join_task` is called for this task,
221 /// otherwise it's ran whenever the thread gets to it.
222 ///
223 /// The threads are not load-balanced, the assigned thread is simply rotated
224 /// on each call of this function.
225 ///
226 /// Tasks should be joined ([`ThreadPool::join_task`]) in the same order as
227 /// they were spawned, as the results need to be received in sending order
228 /// for each thread. However, this ordering requirement only applies
229 /// per-thread, so [`ThreadPool::thread_count`] subsequent spawns can be
230 /// joined in any order amongst themselves — whether this is useful or not,
231 /// is up to the joiner.
232 pub fn spawn_task<T>(
233 &mut self,
234 data: Box<T>,
235 func: fn(&mut T),
236 ) -> Result<TaskHandle<T>, Box<T>> {
237 if self.threads.is_empty() {
238 return Err(data);
239 }
240
241 let thread_index = self.next_thread_index;
242 let task_position = self.threads[thread_index].sent_count;
243
244 let func = func as *const (); // type erase for TaskInFlight
245
246 let data: *mut T = data.into_ptr();
247 let data = data as *mut (); // type erase for TaskInFlight
248
249 fn proxy<T>(func: *const (), data: *mut ()) {
250 // Safety: this pointer is cast from the destination type `fn(&mut
251 // T)` above, and transmuting pointers to fn pointers is ok
252 // according to the [fn
253 // docs](https://doc.rust-lang.org/core/primitive.fn.html#casting-to-and-from-integers).
254 let func = unsafe { transmute::<*const (), fn(&mut T)>(func) };
255 // Safety: this pointer is the same one created above from a Box<T>
256 // (which had unique access to this memory), and it's safe to create
257 // a mutable borrow of it, as this is the only function that will do
258 // anything with the pointer, and this function is only ever called
259 // once for any particular task.
260 let data: &mut T = unsafe { &mut *(data as *mut T) };
261 func(data);
262 }
263
264 let task = TaskInFlight {
265 finished: false,
266 data,
267 func,
268 func_proxy: proxy::<T>,
269 thread_panicked: false,
270 };
271
272 (self.threads[thread_index].sender)
273 .send(task)
274 // Safety: T is definitely correct, we just created this task with
275 // the same type parameter.
276 .map_err(|task| unsafe { task.join::<T>(false) })?;
277
278 self.threads[thread_index].sent_count = task_position
279 .checked_add(1)
280 .expect("thread pool sent_count should not overflow a u64");
281 self.next_thread_index = (thread_index + 1) % self.thread_count();
282
283 Ok(TaskHandle {
284 thread_index,
285 task_position,
286 _type_holder: PhantomData,
287 })
288 }
289
290 /// Blocks on and returns the task passed into [`ThreadPool::spawn_task`],
291 /// if it's next in the queue for the thread it's running on.
292 ///
293 /// The `Err` variant signifies that there's some other task that should be
294 /// joined before this one. When spawning and joining tasks in FIFO order,
295 /// this never returns an `Err`.
296 ///
297 /// Depending on the [`ThreadState`]s passed into the constructor, this
298 /// could either call the function (if it's a one-channel state), or wait
299 /// until another thread has finished calling it (if it's a two-channel
300 /// state that actually has a corresponding parallel thread).
301 pub fn join_task<T>(&mut self, handle: TaskHandle<T>) -> Result<Box<T>, TaskHandle<T>> {
302 let current_recv_count = self.threads[handle.thread_index].recv_count;
303
304 if handle.task_position != current_recv_count {
305 return Err(handle);
306 }
307
308 let task = self.threads[handle.thread_index].receiver.recv();
309 // Safety: the TaskHandle returned from the spawn function
310 // has the correct T for this, and since we've already
311 // checked the thread index and task position, we know this
312 // matches the original spawn call (and thus its type
313 // parameter) for this data.
314 let data = unsafe { task.join::<T>(true) };
315 self.threads[handle.thread_index].recv_count += 1;
316 Ok(data)
317 }
318}
319
#[cfg(test)]
mod tests {
    extern crate alloc;

    use crate::channel::leak_channel;
    use alloc::boxed::Box;

    use super::{TaskInFlight, ThreadPool, ThreadState};

    /// Minimal task payload: a single number for the task function to change.
    #[derive(Debug)]
    struct ExampleData(u32);

    /// Spawns and joins a single task on a pool whose only "thread" is one
    /// channel looped back to the pool itself, and checks that the task
    /// function has been applied to the data by the time the join returns.
    #[test]
    fn single_threaded_pool_works() {
        // Generally you'd create two channels for thread<->thread
        // communication, but in a single-threaded situation, the channel works
        // as a simple work queue.
        // NOTE(review): the argument `1` is presumably the channel capacity —
        // confirm against `leak_channel`.
        let (tx, rx) = leak_channel::<TaskInFlight>(1);
        let thread_state = ThreadState::new(tx, rx);
        // The states are leaked to get a reference to hand over to the pool as
        // a `crate::Box`.
        let threads = Box::leak(Box::new([thread_state]));
        let mut thread_pool = ThreadPool::new(crate::Box::from_mut(threads)).unwrap();

        let mut data = ExampleData(0);
        {
            // Safety: `data` is dropped after this scope, and this Box does not
            // leave this scope, so `data` outlives this Box.
            let data_boxed: crate::Box<ExampleData> =
                unsafe { crate::Box::from_ptr(&raw mut data) };
            assert_eq!(0, data_boxed.0);

            // Nothing else processes the queue here, so the closure is run
            // inside join_task.
            let handle = thread_pool.spawn_task(data_boxed, |n| n.0 = 1).unwrap();
            let data_boxed = thread_pool.join_task(handle).unwrap();
            assert_eq!(1, data_boxed.0);
        }
        // `data` lives at least until here, at which point the unsafe box has
        // been dropped.
        #[allow(clippy::drop_non_drop)]
        drop(data);
    }
}
357}