binius_compute/
alloc.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::sync::Mutex;
4
5use binius_utils::checked_arithmetics::checked_log_2;
6
7use super::memory::{ComputeMemory, SizedSlice};
8use crate::cpu::CpuMemory;
9
10pub trait ComputeAllocator<F, Mem: ComputeMemory<F>> {
11	/// Allocates a slice of elements.
12	///
13	/// This method operates on an immutable self reference so that multiple allocator references
14	/// can co-exist. This follows how the `bumpalo` crate's `Bump` interface works. It may not be
15	/// necessary actually (since this partitions a borrowed slice, whereas `Bump` owns its memory).
16	///
17	/// ## Pre-conditions
18	///
19	/// - `n` must be a multiple of `Mem::ALIGNMENT`
20	fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error>;
21
22	/// Returns the remaining number of elements that can be allocated.
23	fn capacity(&self) -> usize;
24
25	/// Returns the remaining unallocated capacity as a new allocator with a limited scope.
26	fn subscope_allocator(&mut self) -> impl ComputeAllocator<F, Mem>;
27}
28
29/// Basic bump allocator that allocates slices from an underlying memory buffer provided at
30/// construction.
31pub struct BumpAllocator<'a, F, Mem: ComputeMemory<F>> {
32	buffer: Mutex<Option<Mem::FSliceMut<'a>>>,
33}
34
35impl<'a, F, Mem> BumpAllocator<'a, F, Mem>
36where
37	F: 'static,
38	Mem: ComputeMemory<F>,
39{
40	pub fn new(buffer: Mem::FSliceMut<'a>) -> Self {
41		Self {
42			buffer: Mutex::new(Some(buffer)),
43		}
44	}
45
46	pub fn from_ref<'b>(buffer: &'b mut Mem::FSliceMut<'a>) -> BumpAllocator<'b, F, Mem> {
47		let buffer = Mem::slice_mut(buffer, ..);
48		BumpAllocator {
49			buffer: Mutex::new(Some(buffer)),
50		}
51	}
52
53	fn remaining(&mut self) -> Mem::FSliceMut<'_> {
54		Mem::to_owned_mut(
55			self.buffer
56				.get_mut()
57				.expect("mutex is always available")
58				.as_mut()
59				.expect("buffer is always Some by invariant"),
60		)
61	}
62}
63
64impl<'a, F, Mem: ComputeMemory<F>> ComputeAllocator<F, Mem> for BumpAllocator<'a, F, Mem>
65where
66	F: 'static,
67{
68	fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error> {
69		let mut buffer_lock = self.buffer.lock().expect("mutex is always available");
70
71		let buffer = buffer_lock
72			.take()
73			.expect("buffer is always Some by invariant");
74		// buffer temporarily contains None
75		if buffer.len() < n {
76			*buffer_lock = Some(buffer);
77			// buffer contains Some, invariant restored
78			Err(Error::OutOfMemory)
79		} else {
80			let (mut lhs, rhs) = Mem::split_at_mut(buffer, n.max(Mem::ALIGNMENT));
81			if n < Mem::ALIGNMENT {
82				assert!(n.is_power_of_two(), "n must be a power of two");
83				for _ in checked_log_2(n)..checked_log_2(Mem::ALIGNMENT) {
84					(lhs, _) = Mem::split_half_mut(lhs)
85				}
86			}
87			*buffer_lock = Some(rhs);
88			// buffer contains Some, invariant restored
89			Ok(Mem::narrow_mut(lhs))
90		}
91	}
92
93	fn capacity(&self) -> usize {
94		self.buffer
95			.lock()
96			.expect("mutex is always available")
97			.as_ref()
98			.expect("buffer is always Some by invariant")
99			.len()
100	}
101
102	fn subscope_allocator(&mut self) -> impl ComputeAllocator<F, Mem> {
103		BumpAllocator::<F, Mem>::new(self.remaining())
104	}
105}
106
/// Alias for a [`BumpAllocator`] whose memory type is [`CpuMemory`], i.e. a bump allocator over
/// CPU host memory.
pub type HostBumpAllocator<'a, F> = BumpAllocator<'a, F, CpuMemory>;
109
/// Errors returned by [`ComputeAllocator`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
	/// The allocator cannot satisfy the request from the memory it has left.
	#[error("allocator is out of memory")]
	OutOfMemory,
}
115
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;

	use super::*;
	use crate::cpu::memory::CpuMemory;

	/// Allocations succeed until the buffer runs out, and the backing memory can be handed to
	/// a new allocator once the previous one has been dropped.
	#[test]
	fn test_alloc() {
		let mut backing: Vec<u128> = (0..256).collect();

		{
			let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
			// Two allocations of 100 fit into 256 elements, a third does not.
			for _ in 0..2 {
				assert_eq!(allocator.alloc(100).unwrap().len(), 100);
			}
			assert_matches!(allocator.alloc(100), Err(Error::OutOfMemory));
			// Dropping the allocator here releases all of the memory at once.
		}

		// The same backing storage serves a fresh allocator.
		let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
		let slice = allocator.alloc(100).unwrap();
		assert_eq!(slice.len(), 100);
	}

	/// A nested allocator built over the remainder of an outer allocator works within that
	/// remainder, and the outer allocator can allocate again once the nested one is gone.
	#[test]
	fn test_stack_alloc() {
		let mut backing: Vec<u128> = (0..256).collect();
		let mut outer = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
		assert_eq!(outer.alloc(100).unwrap().len(), 100);
		assert_matches!(outer.alloc(200), Err(Error::OutOfMemory));

		{
			// 156 elements are left for the nested allocator: 100 + 56 fits, 100 + 57 does not.
			let inner = BumpAllocator::<u128, CpuMemory>::new(outer.remaining());
			let _ = inner.alloc(100).unwrap();
			assert_matches!(inner.alloc(57), Err(Error::OutOfMemory));
			let _ = inner.alloc(56).unwrap();
		}

		let _ = outer.alloc(100).unwrap();
	}
}