// SPDX-FileCopyrightText: 2025 Jens Pitkänen <jens.pitkanen@helsinki.fi>
//
// SPDX-License-Identifier: GPL-3.0-or-later

use core::{mem::MaybeUninit, slice};

use arrayvec::ArrayVec;
use platform::{
    thread_pool::{TaskHandle, ThreadPool},
    Platform, TaskChannel, ThreadState,
};

use crate::{
    allocators::LinearAllocator,
    collections::{channel, Queue, RingAllocationMetadata, RingBox, RingBuffer},
};

/// The maximum number of threads that can be used by [`parallelize`].
/// [`create_thread_pool`] also caps the number of threads it creates at this.
pub const MAX_THREADS: usize = 128;

/// Creates a thread pool, reserving space for buffering `task_queue_length`
/// tasks per thread.
///
/// The task queue length matters because it limits how many
/// [`ThreadPool::spawn_task`] calls can be made before
/// [`ThreadPool::join_task`] must be called to free up space in the queue.
/// [`parallelize`] only requires a length of 1, as it spawns at most one task
/// per thread and requires the thread pool to be passed in empty.
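///
/// A usage sketch, based on the tests at the bottom of this file
/// (`TestPlatform` is a test-only [`Platform`] implementation):
///
/// ```ignore
/// static ARENA: &LinearAllocator = static_allocator!(10_000);
/// let platform = TestPlatform::new(false);
/// // A queue length of 1 is enough for `parallelize`.
/// let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();
/// ```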
pub fn create_thread_pool(
    allocator: &'static LinearAllocator,
    platform: &dyn Platform,
    task_queue_length: usize,
) -> Option<ThreadPool> {
    profiling::function_scope!();

    let thread_count = platform.available_parallelism().min(MAX_THREADS);
    if thread_count > 1 {
        let init_thread_state = || {
            let task_channel: TaskChannel = channel(platform, allocator, task_queue_length)?;
            let result_channel: TaskChannel = channel(platform, allocator, task_queue_length)?;
            Some(platform.spawn_pool_thread([task_channel, result_channel]))
        };
        let threads = allocator.try_alloc_boxed_slice_with(init_thread_state, thread_count)?;
        Some(ThreadPool::new(threads).unwrap())
    } else {
        let init_thread_state = || {
            let (tx, rx): TaskChannel = channel(platform, allocator, task_queue_length)?;
            Some(ThreadState::new(tx, rx))
        };
        let threads = allocator.try_alloc_boxed_slice_with(init_thread_state, 1)?;
        Some(ThreadPool::new(threads).unwrap())
    }
}

/// Runs the function on multiple threads, splitting the data into one part
/// for each thread.
///
/// Each invocation of the function also receives the offset of its subslice
/// relative to the start of `data`.
///
/// The return value is the size of the chunks `data` was split into. The same
/// slices can be acquired by calling `chunks` or `chunks_mut` on `data` and
/// passing it in as the chunk size. If the input slice is empty, 0 is
/// returned.
///
/// ### Panics
///
/// If the thread pool already has pending tasks. This should never be the
/// case when the thread pool is used with just this function, as this
/// function always consumes all the tasks it spawns.
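///
/// A usage sketch, mirroring the tests at the bottom of this file:
///
/// ```ignore
/// let mut data = [1, 2, 3, 4];
/// let chunk_size = parallelize(&mut thread_pool, &mut data, |chunk, offset| {
///     // `chunk` is the subslice of `data` starting at index `offset`.
///     for n in chunk {
///         *n *= *n;
///     }
/// });
/// assert_eq!([1, 4, 9, 16], data);
/// // The same split can be revisited with `data.chunks_mut(chunk_size)`.
/// ```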
#[track_caller]
pub fn parallelize<T, F>(thread_pool: &mut ThreadPool, data: &mut [T], func: F) -> usize
where
    // The chunks of `data` are handed off to worker threads as `&mut [T]`,
    // which requires `T: Send`. `F: Sync` is required because every worker
    // shares the same `&F`.
    T: Send,
    F: Sync + Fn(&mut [T], usize),
{
    profiling::function_scope!();

    // A type-erased task: a pointer and length for one subslice of `data`, a
    // pointer to the user-provided function, and the subslice's offset from
    // the start of `data`.
    struct Task {
        data_ptr: *mut (),
        data_len: usize,
        func: *const (),
        data_offset: usize,
    }

    // Pairs a spawned task's handle with the metadata needed to free its ring
    // buffer allocation after the task has been joined.
    struct TaskProxy {
        handle: TaskHandle<Task>,
        metadata: RingAllocationMetadata,
    }

    if thread_pool.has_pending() {
        panic!("thread pool has pending tasks but was used in a parallelize() call");
    }

    if data.is_empty() {
        return 0;
    }

    let max_tasks = thread_pool.thread_count().min(MAX_THREADS);

    let mut backing_task_buffer = ArrayVec::<MaybeUninit<Task>, MAX_THREADS>::new();
    let mut backing_task_proxies = ArrayVec::<MaybeUninit<TaskProxy>, MAX_THREADS>::new();
    for _ in 0..max_tasks {
        backing_task_buffer.push(MaybeUninit::uninit());
        backing_task_proxies.push(MaybeUninit::uninit());
    }

    // Safety: all allocations from this buffer are passed into the thread
    // pool, from which all tasks are joined, and those buffers are freed
    // right after. So there are no leaked allocations.
    let mut task_buffer = unsafe { RingBuffer::from_mut(&mut backing_task_buffer) };
    let mut task_proxies = Queue::from_mut(&mut backing_task_proxies).unwrap();

    thread_pool.reset_thread_counter();

    // Shadow `func` to ensure that the value doesn't get dropped until the
    // end of this function, since this borrow is shared with the threads.
    let func: *const F = &func;

    // Divide rounding up so that all of `data` is covered: e.g. 10 elements
    // across 4 threads gives chunks of 3, 3, 3, and 1.
    let chunk_size = data.len().div_ceil(max_tasks);
    for (i, data_part) in data.chunks_mut(chunk_size).enumerate() {
        profiling::scope!("send task");
        // Shouldn't ever trip, but if it does, we'd much rather crash here
        // than have half-spawned a task, which could be unsound.
        assert!(i < max_tasks);

        // Allocate the thread pool task.
        let data_ptr: *mut T = data_part.as_mut_ptr();
        let data_len: usize = data_part.len();
        let (task, metadata) = task_buffer
            .allocate_box(Task {
                data_ptr: data_ptr as *mut (),
                data_len,
                func: func as *const (),
                data_offset: i * chunk_size,
            })
            .ok()
            // Does not panic: task_buffer is guaranteed to have capacity via
            // the assert at the start of this loop body.
            .unwrap()
            .into_parts();

        // Send off the task, using a proxy function to call the
        // user-provided one.
        let handle = thread_pool
            .spawn_task(task, |task| {
                let data_ptr = task.data_ptr as *mut T;
                let data_len = task.data_len;
                // Safety:
                // - Type, pointer and length validity-wise, this slice is ok
                //   to create as it was created from a slice of T in the
                //   first place.
                // - Lifetime-wise, creating this slice is valid because the
                //   slice's lifetime spans this function, and this function
                //   is run within the lifetime of the `parallelize` function
                //   call due to all tasks being joined before the end, and
                //   the original slice is valid for the entirety of
                //   `parallelize`.
                // - Exclusive-access-wise, it's valid since the backing slice
                //   is only used to split it with chunks_mut, and those
                //   chunks are simply sent off to worker threads. Since this
                //   all happens during parallelize() (see the lifetime point
                //   above), nothing else can create any kind of borrow of
                //   this particular chunk.
                let data: &mut [T] = unsafe { slice::from_raw_parts_mut(data_ptr, data_len) };
                let func = task.func as *const F;
                // Safety: same logic as for the data, except that this
                // reference is shared, which is valid because it's a
                // const-pointer and we borrow it immutably.
                unsafe { (*func)(data, task.data_offset) };
            })
            .ok()
            // Does not panic: the thread pool is guaranteed to have capacity,
            // since it starts out empty and at most thread_count tasks are
            // spawned.
            .unwrap();

        // Add the task handle to the queue to be joined before returning.
        task_proxies
            .push_back(TaskProxy { handle, metadata })
            .ok()
            // Does not panic: task_proxies is guaranteed to have capacity via
            // the assert at the start of this loop body.
            .unwrap();
    }

    // Join the tasks and free their buffers. Freeing doesn't make room for
    // anything new here, but it ensures nothing is leaked, which would
    // violate the safety requirements of the non-static RingBuffer.
    while let Some(proxy) = task_proxies.pop_front() {
        profiling::scope!("receive result");
        // Does not panic: tasks are joined in FIFO order.
        let task = thread_pool.join_task(proxy.handle).ok().unwrap();

        // Safety: the `Task` was allocated in the previous loop, with the
        // actual boxed task being sent onto a thread, and the metadata stored
        // in the proxy, alongside the handle for said task. Since `task` here
        // is the result of that task, it must be the same boxed task
        // allocated alongside this metadata.
        let boxed = unsafe { RingBox::from_parts(task, proxy.metadata) };
        task_buffer.free_box(boxed).ok().unwrap();
    }

    chunk_size
}

#[cfg(test)]
mod tests {
    use super::{create_thread_pool, parallelize};
    use crate::{
        allocators::{static_allocator, LinearAllocator},
        test_platform::TestPlatform,
    };

    #[test]
    fn parallelize_works_singlethreaded() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(false);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        let mut data = [1, 2, 3, 4];
        parallelize(&mut thread_pool, &mut data, |data, _| {
            for n in data {
                *n *= *n;
            }
        });
        assert_eq!([1, 4, 9, 16], data);
    }

    #[test]
    #[cfg(not(target_os = "emscripten"))]
    fn parallelize_works_multithreaded() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(true);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        let mut data = [1, 2, 3, 4];
        parallelize(&mut thread_pool, &mut data, |data, _| {
            for n in data {
                *n *= *n;
            }
        });
        assert_eq!([1, 4, 9, 16], data);
    }

    #[test]
    #[ignore = "the emscripten target doesn't support multithreading"]
    #[cfg(target_os = "emscripten")]
    fn parallelize_works_multithreaded() {}
}
241}