engine/multithreading.rs
1// SPDX-FileCopyrightText: 2025 Jens Pitkänen <jens.pitkanen@helsinki.fi>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5use core::{mem::MaybeUninit, slice};
6
7use arrayvec::ArrayVec;
8use platform::{
9 thread_pool::{TaskHandle, ThreadPool},
10 Platform, TaskChannel, ThreadState,
11};
12
13use crate::{
14 allocators::LinearAllocator,
15 collections::{channel, Queue, RingAllocationMetadata, RingBox, RingBuffer},
16};
17
/// Upper bound on the number of threads [`parallelize`] will split work across.
///
/// [`create_thread_pool`] clamps the number of threads it spawns to this
/// value as well.
pub const MAX_THREADS: usize = 128;
21
22/// Creates a thread pool, reserving space for buffering `task_queue_length`
23/// tasks per thread.
24///
25/// The task queue lengths are relevant in that they limit how many
26/// [`ThreadPool::spawn_task`] calls can be made before
27/// [`ThreadPool::join_task`] needs to be called to free up space in the queue.
28/// [`parallelize`] only requires 1, as it only allocates one task per
29/// thread, and requires the thread pool to be passed in empty.
30pub fn create_thread_pool(
31 allocator: &'static LinearAllocator,
32 platform: &dyn Platform,
33 task_queue_length: usize,
34) -> Option<ThreadPool> {
35 profiling::function_scope!();
36
37 let thread_count = platform.available_parallelism().min(MAX_THREADS);
38 if thread_count > 1 {
39 let init_thread_state = || {
40 let task_channel: TaskChannel = channel(platform, allocator, task_queue_length)?;
41 let result_channel: TaskChannel = channel(platform, allocator, task_queue_length)?;
42 Some(platform.spawn_pool_thread([task_channel, result_channel]))
43 };
44 let threads = allocator.try_alloc_boxed_slice_with(init_thread_state, thread_count)?;
45 Some(ThreadPool::new(threads).unwrap())
46 } else {
47 let init_thread_state = || {
48 let (tx, rx): TaskChannel = channel(platform, allocator, task_queue_length)?;
49 Some(ThreadState::new(tx, rx))
50 };
51 let threads = allocator.try_alloc_boxed_slice_with(init_thread_state, 1)?;
52 Some(ThreadPool::new(threads).unwrap())
53 }
54}
55
56/// Runs the function on multiple threads, splitting the data into one part for
57/// each thread.
58///
59/// The function also gets the offset of the specific subslice it got, relative
60/// to the start of `data`.
61///
62/// The return value is the size of the chunks `data` was split into. The same
63/// slices can be acquired by calling `chunks` or `chunks_mut` on `data` and
64/// passing it in as the chunk size. If the input slice is empty, 0 is returned.
65///
66/// ### Panics
67///
68/// If the thread pool already has pending tasks. This shouldn't ever be the
69/// case when using the threadpool with just this function, as this function
70/// always consumes all tasks it spawns.
71#[track_caller]
72pub fn parallelize<T, F>(thread_pool: &mut ThreadPool, data: &mut [T], func: F) -> usize
73where
74 T: Sync,
75 F: Sync + Fn(&mut [T], usize),
76{
77 profiling::function_scope!();
78
79 struct Task {
80 data_ptr: *mut (),
81 data_len: usize,
82 func: *const (),
83 data_offset: usize,
84 }
85
86 struct TaskProxy {
87 handle: TaskHandle<Task>,
88 metadata: RingAllocationMetadata,
89 }
90
91 if thread_pool.has_pending() {
92 panic!("thread pool has pending tasks but was used in a parallellize() call");
93 }
94
95 if data.is_empty() {
96 return 0;
97 }
98
99 let max_tasks = thread_pool.thread_count().min(MAX_THREADS);
100
101 let mut backing_task_buffer = ArrayVec::<MaybeUninit<Task>, MAX_THREADS>::new();
102 let mut backing_task_proxies = ArrayVec::<MaybeUninit<TaskProxy>, MAX_THREADS>::new();
103 for _ in 0..max_tasks {
104 backing_task_buffer.push(MaybeUninit::uninit());
105 backing_task_proxies.push(MaybeUninit::uninit());
106 }
107
108 // Safety: all allocations from this buffer are passed into the thread pool,
109 // from which all tasks are joined, and those buffers are freed right after.
110 // So there are no leaked allocations.
111 let mut task_buffer = unsafe { RingBuffer::from_mut(&mut backing_task_buffer) };
112 let mut task_proxies = Queue::from_mut(&mut backing_task_proxies).unwrap();
113
114 thread_pool.reset_thread_counter();
115
116 // Shadow `func` to ensure that the value doesn't get dropped until the end
117 // of this function, since this borrow is shared with the threads.
118 let func: *const F = &func;
119
120 let chunk_size = data.len().div_ceil(max_tasks);
121 for (i, data_part) in data.chunks_mut(chunk_size).enumerate() {
122 profiling::scope!("send task");
123 // Shouldn't ever trip, but if it does, we'd much rather crash here than
124 // having half-spawned a task, which could be unsound.
125 assert!(i < max_tasks);
126
127 // Allocate the thread pool task.
128 let data_ptr: *mut T = data_part.as_mut_ptr();
129 let data_len: usize = data_part.len();
130 let (task, metadata) = task_buffer
131 .allocate_box(Task {
132 data_ptr: data_ptr as *mut (),
133 data_len,
134 func: func as *const (),
135 data_offset: i * chunk_size,
136 })
137 .ok()
138 .unwrap() // does not panic: task_buffer is guaranteed to have capacity via the assert at the start of this loop body
139 .into_parts();
140
141 // Send off the task, using the proxy function from it to call the
142 // user-provided one.
143 let handle = thread_pool
144 .spawn_task(task, |task| {
145 let data_ptr = task.data_ptr as *mut T;
146 let data_len = task.data_len;
147 // Safety:
148 // - Type, pointer and length validity-wise, this slice is ok to
149 // create as it was created from a slice of T in the first
150 // place.
151 // - Lifetime-wise, creating this slice is valid because the
152 // slice's lifetime spans this function, and this function is
153 // run within the lifetime of the `parallelize` function call
154 // due to all tasks being joined before the end, and the
155 // original slice is valid for the entirety of `parallellize`.
156 // - Exclusive-access-wise, it's valid since the backing slice
157 // is only used to split it with chunks_mut, and those chunks
158 // are simply sent off to worker threads. Since this all
159 // happens during parallelize() (see lifetime point), there's
160 // definitely no others creating any kind of borrow of this
161 // particular chunk.
162 let data: &mut [T] = unsafe { slice::from_raw_parts_mut(data_ptr, data_len) };
163 let func = task.func as *const F;
164 // Safety: same logic as for the data, except that this
165 // reference is shared, which is valid because it's a
166 // const-pointer and we borrow it immutably.
167 unsafe { (*func)(data, task.data_offset) };
168 })
169 .ok()
170 .unwrap(); // does not panic: thread_pool is guaranteed to have capacity, it's empty and we're only spawning thread_count tasks
171
172 // Add the task handle to the queue to be joined before returning.
173 task_proxies
174 .push_back(TaskProxy { handle, metadata })
175 .ok()
176 .unwrap(); // does not panic: task_proxies is guaranteed to have capacity via the assert at the start of this loop body
177 }
178
179 // Join tasks and free the buffers (doesn't free up space for anything, but
180 // makes sure we're not leaking anything, which would violate the safety
181 // requirements of the non-static RingBuffer).
182 while let Some(proxy) = task_proxies.pop_front() {
183 profiling::scope!("receive result");
184 let task = thread_pool.join_task(proxy.handle).ok().unwrap(); // does not panic: we're joining tasks in FIFO order
185
186 // Safety: the `Task` was allocated in the previous loop, with the
187 // actual boxed task being sent onto a thread, and the metadata stored
188 // in the proxy, alongside the handle for said task. Since `task` here
189 // is the result of that task, it must be the same boxed task allocated
190 // alongside this metadata.
191 let boxed = unsafe { RingBox::from_parts(task, proxy.metadata) };
192 task_buffer.free_box(boxed).ok().unwrap();
193 }
194
195 chunk_size
196}
197
#[cfg(test)]
mod tests {
    use super::{create_thread_pool, parallelize};
    use crate::{
        allocators::{static_allocator, LinearAllocator},
        test_platform::TestPlatform,
    };

    #[test]
    fn parallelize_works_singlethreaded() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(false);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        let mut data = [1, 2, 3, 4];
        parallelize(&mut thread_pool, &mut data, |data, _| {
            for n in data {
                *n *= *n;
            }
        });
        assert_eq!([1, 4, 9, 16], data);
    }

    #[test]
    #[cfg(not(target_os = "emscripten"))]
    fn parallelize_works_multithreaded() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(true);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        let mut data = [1, 2, 3, 4];
        parallelize(&mut thread_pool, &mut data, |data, _| {
            for n in data {
                *n *= *n;
            }
        });
        assert_eq!([1, 4, 9, 16], data);
    }

    /// Every element must end up seeing its own index in the original slice,
    /// which only holds if each chunk's offset is reported correctly.
    #[test]
    #[cfg(not(target_os = "emscripten"))]
    fn parallelize_passes_correct_offsets() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(true);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        // 13 is prime, so unless there's exactly 1 or 13+ threads the last
        // chunk is shorter than the others.
        let mut data = [0usize; 13];
        let chunk_size = parallelize(&mut thread_pool, &mut data, |chunk, offset| {
            for (i, n) in chunk.iter_mut().enumerate() {
                *n = offset + i;
            }
        });
        assert!(chunk_size > 0);
        assert!(chunk_size <= data.len());
        for (i, n) in data.iter().enumerate() {
            assert_eq!(i, *n);
        }
    }

    /// Empty input must short-circuit: return 0 and never invoke `func`.
    #[test]
    fn parallelize_empty_input_returns_zero() {
        static ARENA: &LinearAllocator = static_allocator!(10_000);
        let platform = TestPlatform::new(false);
        let mut thread_pool = create_thread_pool(ARENA, &platform, 1).unwrap();

        let mut data: [u32; 0] = [];
        let chunk_size = parallelize(&mut thread_pool, &mut data, |_, _| {
            panic!("func must not be called for empty input");
        });
        assert_eq!(0, chunk_size);
    }

    #[test]
    #[ignore = "the emscripten target doesn't support multithreading"]
    #[cfg(target_os = "emscripten")]
    fn parallelize_works_multithreaded() {}
}