platform/
thread_pool.rs

1// SPDX-FileCopyrightText: 2025 Jens Pitkänen <jens.pitkanen@helsinki.fi>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5//! Thread pool for running tasks on other threads.
6//!
7//! [`ThreadPool`] implements a FIFO task queue where the tasks are executed on
8//! other threads, the amount of threads depending on how the [`ThreadPool`] is
9//! constructed. As a fallback, single-threaded platforms are supported by
10//! simply running the task in [`ThreadPool::join_task`].
11//!
12//! This module doesn't do any allocation, and isn't very usable on its own,
13//! it's intended to be used alongside platform-provided threading functions, by
14//! the engine, to construct multithreading utilities.
15
16use core::{marker::PhantomData, mem::transmute};
17
18use crate::{
19    channel::{Receiver, Sender},
20    Box,
21};
22
/// Ticket identifying a single task submitted to a [`ThreadPool`], whether the
/// task is still waiting in the queue or already running.
///
/// Redeem handles with [`ThreadPool::join_task`], in the same order in which
/// [`ThreadPool::spawn_task`] handed them out.
#[derive(Debug)]
pub struct TaskHandle<T: 'static> {
    /// Index of the pool thread whose queue the task was sent to.
    thread_index: usize,
    /// Value of that thread's sent-task counter at the time the task was
    /// spawned, i.e. the task's position in that thread's queue.
    task_position: u64,
    /// Carries the task's data type in the handle without storing a `T`.
    _type_holder: PhantomData<&'static T>,
}
33
/// Packets sent between threads to coordinate a [`ThreadPool`].
///
/// Each packet is one type-erased task: a pointer to its data, the function
/// to call on that data, and the bookkeeping flags describing how far along
/// the task is.
pub struct TaskInFlight {
    /// Whether or not [`TaskInFlight::run`] has been run for this task.
    finished: bool,
    /// Type-erased pointer to the task's data. Extracted from: `Box<T>`.
    data: *mut (),
    /// Type-erased task function. Cast from: `fn(&mut T)`.
    func: *const (),
    /// Pass `self.func` and `self.data` in here to call the function with the
    /// right types.
    func_proxy: fn(func: *const (), data: *mut ()),
    /// Can be set to true by the processing thread to signal that the thread
    /// panicked. This will cause the join function to panic with "a thread in
    /// the thread pool panicked" when joining this task.
    thread_panicked: bool,
}
50
51impl TaskInFlight {
52    /// Process the task in this container. Returns false if the task has
53    /// already been ran, in which case this function does nothing.
54    pub fn run(&mut self) -> bool {
55        if !self.finished {
56            (self.func_proxy)(self.func, self.data);
57            self.finished = true;
58            true
59        } else {
60            false
61        }
62    }
63
64    /// Signals the thread pool that the thread responsible for running this
65    /// task panicked. This can be used to propagate the panic to the main
66    /// thread.
67    pub fn signal_panic(&mut self) {
68        self.thread_panicked = true;
69    }
70
71    /// Panics if the thread running this task has panicked, runs the task if
72    /// the task hasn't been ran and it didn't panic, and finally, returns the
73    /// data operated on by this task. Called by [`ThreadPool::join_task`].
74    ///
75    /// ### Safety
76    /// The type parameter `T` must match the original type parameter `T` of
77    /// [`ThreadPool::spawn_task`] exactly.
78    unsafe fn join<T>(mut self, run_if_not_finished: bool) -> Box<T> {
79        if self.thread_panicked {
80            panic!("a thread in the thread pool panicked");
81        }
82
83        if !self.finished && run_if_not_finished {
84            self.run();
85        }
86
87        // Safety: the *mut c_void was originally casted from a *mut T which in
88        // turn was from a Box<T>, so this pointer has already been guaranteed
89        // to live long enough. It is also not shared anywhere outside of this
90        // struct, so this is definitely a unique reference to the memory.
91        unsafe { Box::from_ptr(self.data as *mut T) }
92    }
93}
94
95// Safety: the only non-Sync field, the data pointer, points to the T of a
96// Box<T: Sync>.
97unsafe impl Sync for TaskInFlight {}
98
99/// The sending half of a [`TaskChannel`].
100pub type TaskSender = Sender<TaskInFlight>;
101/// The receiving half of a [`TaskChannel`].
102pub type TaskReceiver = Receiver<TaskInFlight>;
103/// Channel used by [`ThreadPool`] for communicating with the processing
104/// threads.
105///
106/// Passed into
107/// [`Platform::spawn_pool_thread`](crate::Platform::spawn_pool_thread).
108pub type TaskChannel = (TaskSender, TaskReceiver);
109
110/// State held by [`ThreadPool`] for sending and receiving [`TaskInFlight`]s
111/// between it and a thread.
112///
113/// Returned from
114/// [`Platform::spawn_pool_thread`](crate::Platform::spawn_pool_thread),
115/// multiple of these are used to create a [`ThreadPool`].
116pub struct ThreadState {
117    /// For sending tasks to the thread.
118    sender: TaskSender,
119    /// For getting tasks results back from the thread.
120    receiver: TaskReceiver,
121    /// The amount of tasks sent via `sender`. (Used for picking
122    /// [`TaskHandle::task_position`] for send).
123    sent_count: u64,
124    /// The amount of tasks received via `receiver`. (Used for checking
125    /// [`TaskHandle::task_position`] on recv).
126    recv_count: u64,
127}
128
129impl ThreadState {
130    /// Creates a new [`ThreadState`] from the relevant channel endpoints.
131    ///
132    /// `sender_to_thread` is used to send tasks to the thread, while
133    /// `receiver_from_thread` is used to receive finished tasks, so there
134    /// should be two channels for each thread.
135    ///
136    /// To implement a simple single-threaded [`ThreadPool`], the sender and
137    /// receiver of just one channel could be passed here, in which case
138    /// [`ThreadPool`] will run the task when joining that task in
139    /// [`ThreadPool::join_task`].
140    pub fn new(sender_to_thread: TaskSender, receiver_from_thread: TaskReceiver) -> ThreadState {
141        ThreadState {
142            sender: sender_to_thread,
143            receiver: receiver_from_thread,
144            sent_count: 0,
145            recv_count: 0,
146        }
147    }
148}
149
150/// Thread pool for running compute-intensive tasks in parallel.
151///
152/// Note that the tasks are run in submission order (on multiple threads, if
153/// available), so a task that e.g. blocks on a file read will prevent other
154/// tasks from running.
155pub struct ThreadPool {
156    next_thread_index: usize,
157    threads: Box<[ThreadState]>,
158}
159
160impl ThreadPool {
161    /// Creates a new [`ThreadPool`], returning None if the channels don't have
162    /// matching capacities.
163    pub fn new(threads: Box<[ThreadState]>) -> Option<ThreadPool> {
164        // Check that each channel has the same capacity
165        let mut capacity = None;
166        for thread in threads.iter() {
167            if let Some(capacity) = capacity {
168                if thread.receiver.capacity() != capacity || thread.sender.capacity() != capacity {
169                    return None;
170                }
171            } else if thread.receiver.capacity() != thread.sender.capacity() {
172                return None;
173            } else {
174                capacity = Some(thread.receiver.capacity());
175            }
176        }
177
178        Some(ThreadPool {
179            next_thread_index: 0,
180            threads,
181        })
182    }
183
184    /// Returns the amount of threads in this thread pool.
185    pub fn thread_count(&self) -> usize {
186        self.threads.len()
187    }
188
189    /// Returns the length of a task queue.
190    ///
191    /// In total, tasks can be spawned without joining up to this amount times
192    /// the thread count.
193    pub fn queue_len(&self) -> usize {
194        if let Some(thread) = self.threads.first() {
195            thread.receiver.capacity() // Checked in new() to match all other channels too
196        } else {
197            0
198        }
199    }
200
201    /// Returns true if the thread pool has any pending tasks in the queues.
202    pub fn has_pending(&self) -> bool {
203        self.threads
204            .iter()
205            .any(|thread| thread.recv_count != thread.sent_count)
206    }
207
208    /// Resets the counter used to assign tasks to different threads.
209    ///
210    /// After calling this, the next [`ThreadPool::spawn_task`] is sent off to
211    /// the first thread, instead of whichever value the counter is on now.
212    pub fn reset_thread_counter(&mut self) {
213        self.next_thread_index = 0;
214    }
215
216    /// Schedules the function to be ran on a thread in this pool, passing in
217    /// the data as an argument, if they fit in the task queue.
218    ///
219    /// The function passed in is only ever ran once. In a single-threaded
220    /// environment, it is ran when `join_task` is called for this task,
221    /// otherwise it's ran whenever the thread gets to it.
222    ///
223    /// The threads are not load-balanced, the assigned thread is simply rotated
224    /// on each call of this function.
225    ///
226    /// Tasks should be joined ([`ThreadPool::join_task`]) in the same order as
227    /// they were spawned, as the results need to be received in sending order
228    /// for each thread. However, this ordering requirement only applies
229    /// per-thread, so [`ThreadPool::thread_count`] subsequent spawns can be
230    /// joined in any order amongst themselves — whether this is useful or not,
231    /// is up to the joiner.
232    pub fn spawn_task<T>(
233        &mut self,
234        data: Box<T>,
235        func: fn(&mut T),
236    ) -> Result<TaskHandle<T>, Box<T>> {
237        if self.threads.is_empty() {
238            return Err(data);
239        }
240
241        let thread_index = self.next_thread_index;
242        let task_position = self.threads[thread_index].sent_count;
243
244        let func = func as *const (); // type erase for TaskInFlight
245
246        let data: *mut T = data.into_ptr();
247        let data = data as *mut (); // type erase for TaskInFlight
248
249        fn proxy<T>(func: *const (), data: *mut ()) {
250            // Safety: this pointer is cast from the destination type `fn(&mut
251            // T)` above, and transmuting pointers to fn pointers is ok
252            // according to the [fn
253            // docs](https://doc.rust-lang.org/core/primitive.fn.html#casting-to-and-from-integers).
254            let func = unsafe { transmute::<*const (), fn(&mut T)>(func) };
255            // Safety: this pointer is the same one created above from a Box<T>
256            // (which had unique access to this memory), and it's safe to create
257            // a mutable borrow of it, as this is the only function that will do
258            // anything with the pointer, and this function is only ever called
259            // once for any particular task.
260            let data: &mut T = unsafe { &mut *(data as *mut T) };
261            func(data);
262        }
263
264        let task = TaskInFlight {
265            finished: false,
266            data,
267            func,
268            func_proxy: proxy::<T>,
269            thread_panicked: false,
270        };
271
272        (self.threads[thread_index].sender)
273            .send(task)
274            // Safety: T is definitely correct, we just created this task with
275            // the same type parameter.
276            .map_err(|task| unsafe { task.join::<T>(false) })?;
277
278        self.threads[thread_index].sent_count = task_position
279            .checked_add(1)
280            .expect("thread pool sent_count should not overflow a u64");
281        self.next_thread_index = (thread_index + 1) % self.thread_count();
282
283        Ok(TaskHandle {
284            thread_index,
285            task_position,
286            _type_holder: PhantomData,
287        })
288    }
289
290    /// Blocks on and returns the task passed into [`ThreadPool::spawn_task`],
291    /// if it's next in the queue for the thread it's running on.
292    ///
293    /// The `Err` variant signifies that there's some other task that should be
294    /// joined before this one. When spawning and joining tasks in FIFO order,
295    /// this never returns an `Err`.
296    ///
297    /// Depending on the [`ThreadState`]s passed into the constructor, this
298    /// could either call the function (if it's a one-channel state), or wait
299    /// until another thread has finished calling it (if it's a two-channel
300    /// state that actually has a corresponding parallel thread).
301    pub fn join_task<T>(&mut self, handle: TaskHandle<T>) -> Result<Box<T>, TaskHandle<T>> {
302        let current_recv_count = self.threads[handle.thread_index].recv_count;
303
304        if handle.task_position != current_recv_count {
305            return Err(handle);
306        }
307
308        let task = self.threads[handle.thread_index].receiver.recv();
309        // Safety: the TaskHandle returned from the spawn function
310        // has the correct T for this, and since we've already
311        // checked the thread index and task position, we know this
312        // matches the original spawn call (and thus its type
313        // parameter) for this data.
314        let data = unsafe { task.join::<T>(true) };
315        self.threads[handle.thread_index].recv_count += 1;
316        Ok(data)
317    }
318}
319
#[cfg(test)]
mod tests {
    extern crate alloc;

    use crate::channel::leak_channel;
    use alloc::boxed::Box;

    use super::{TaskInFlight, ThreadPool, ThreadState};

    /// Minimal payload for the task under test.
    #[derive(Debug)]
    struct ExampleData(u32);

    /// The task function: overwrites the payload with 1.
    fn set_to_one(example: &mut ExampleData) {
        example.0 = 1;
    }

    #[test]
    fn single_threaded_pool_works() {
        // Thread<->thread communication would normally use two channels, but
        // with a single thread, both ends of one channel act as a plain work
        // queue.
        let (sender, receiver) = leak_channel::<TaskInFlight>(1);
        let states = Box::leak(Box::new([ThreadState::new(sender, receiver)]));
        let mut pool = ThreadPool::new(crate::Box::from_mut(states)).unwrap();

        let mut data = ExampleData(0);
        {
            // Safety: `data` is dropped after this scope, and this Box does not
            // leave this scope, so `data` outlives this Box.
            let before: crate::Box<ExampleData> = unsafe { crate::Box::from_ptr(&raw mut data) };
            assert_eq!(0, before.0);

            let handle = pool.spawn_task(before, set_to_one).unwrap();
            let after = pool.join_task(handle).unwrap();
            assert_eq!(1, after.0);
        }
        // Explicit drop: `data` lives at least until here, at which point the
        // unsafe box has been dropped.
        #[allow(clippy::drop_non_drop)]
        drop(data);
    }
}
357}