binius_compute/
alloc.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::cell::Cell;
4
5use binius_utils::checked_arithmetics::checked_log_2;
6
7use super::memory::{ComputeMemory, SizedSlice};
8use crate::cpu::CpuMemory;
9
10pub trait ComputeAllocator<F, Mem: ComputeMemory<F>> {
11	/// Allocates a slice of elements.
12	///
13	/// This method operates on an immutable self reference so that multiple allocator references
14	/// can co-exist. This follows how the `bumpalo` crate's `Bump` interface works. It may not be
15	/// necessary actually (since this partitions a borrowed slice, whereas `Bump` owns its memory).
16	///
17	/// ## Pre-conditions
18	///
19	/// - `n` must be a multiple of `Mem::ALIGNMENT`
20	fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error>;
21
22	/// Borrow the remaining unallocated capacity.
23	///
24	/// This allows another allocator to have unique mutable access to the rest of the elements in
25	/// this allocator until it gets dropped, at which point this allocator can be used again
26	fn remaining(&mut self) -> Mem::FSliceMut<'_>;
27
28	/// Returns the remaining number of elements that can be allocated.
29	fn capacity(&self) -> usize;
30}
31
32/// Basic bump allocator that allocates slices from an underlying memory buffer provided at
33/// construction.
34pub struct BumpAllocator<'a, F, Mem: ComputeMemory<F>> {
35	buffer: Cell<Option<Mem::FSliceMut<'a>>>,
36}
37
38impl<'a, F, Mem> BumpAllocator<'a, F, Mem>
39where
40	F: 'static,
41	Mem: ComputeMemory<F> + 'a,
42{
43	pub fn new(buffer: Mem::FSliceMut<'a>) -> Self {
44		Self {
45			buffer: Cell::new(Some(buffer)),
46		}
47	}
48
49	pub fn from_ref<'b>(buffer: &'b mut Mem::FSliceMut<'a>) -> BumpAllocator<'b, F, Mem> {
50		let buffer = Mem::slice_mut(buffer, ..);
51		BumpAllocator {
52			buffer: Cell::new(Some(buffer)),
53		}
54	}
55}
56
57impl<'a, F, Mem: ComputeMemory<F>> ComputeAllocator<F, Mem> for BumpAllocator<'a, F, Mem> {
58	fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error> {
59		let buffer = self
60			.buffer
61			.take()
62			.expect("buffer is always Some by invariant");
63		// buffer temporarily contains None
64		if buffer.len() < n {
65			self.buffer.set(Some(buffer));
66			// buffer contains Some, invariant restored
67			Err(Error::OutOfMemory)
68		} else {
69			let (mut lhs, rhs) = Mem::split_at_mut(buffer, n.max(Mem::ALIGNMENT));
70			if n < Mem::ALIGNMENT {
71				assert!(n.is_power_of_two(), "n must be a power of two");
72				for _ in checked_log_2(n)..checked_log_2(Mem::ALIGNMENT) {
73					(lhs, _) = Mem::split_half_mut(lhs)
74				}
75			}
76			self.buffer.set(Some(rhs));
77			// buffer contains Some, invariant restored
78			Ok(Mem::narrow_mut(lhs))
79		}
80	}
81
82	fn remaining(&mut self) -> Mem::FSliceMut<'_> {
83		Mem::to_owned_mut(
84			self.buffer
85				.get_mut()
86				.as_mut()
87				.expect("buffer is always Some by invariant"),
88		)
89	}
90
91	fn capacity(&self) -> usize {
92		let buffer = self
93			.buffer
94			.take()
95			.expect("buffer is always Some by invariant");
96		let ret = buffer.len();
97		self.buffer.set(Some(buffer));
98		ret
99	}
100}
101
102/// Alias for a bump allocator over CPU host memory.
103pub type HostBumpAllocator<'a, F> = BumpAllocator<'a, F, CpuMemory>;
104
105#[derive(Debug, thiserror::Error)]
106pub enum Error {
107	#[error("allocator is out of memory")]
108	OutOfMemory,
109}
110
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;

	use super::*;
	use crate::cpu::memory::CpuMemory;

	/// Allocation fails once the buffer is exhausted, and the memory can be reused after the
	/// allocator is dropped.
	#[test]
	fn test_alloc() {
		let mut backing = (0..256u128).collect::<Vec<_>>();

		{
			let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
			for _ in 0..2 {
				assert_eq!(allocator.alloc(100).unwrap().len(), 100);
			}
			assert_matches!(allocator.alloc(100), Err(Error::OutOfMemory));
			// Dropping `allocator` at the end of this scope releases everything at once.
		}

		// The same backing memory can be handed to a fresh allocator.
		let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
		let slice = allocator.alloc(100).unwrap();
		assert_eq!(slice.len(), 100);
	}

	/// A nested allocator can borrow the parent's remaining capacity, and the parent is usable
	/// again once the nested allocator is gone.
	#[test]
	fn test_stack_alloc() {
		let mut backing = (0..256u128).collect::<Vec<_>>();
		let mut outer = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
		assert_eq!(outer.alloc(100).unwrap().len(), 100);
		assert_matches!(outer.alloc(200), Err(Error::OutOfMemory));

		{
			let inner = BumpAllocator::<u128, CpuMemory>::new(outer.remaining());
			let _ = inner.alloc(100).unwrap();
			assert_matches!(inner.alloc(57), Err(Error::OutOfMemory));
			let _ = inner.alloc(56).unwrap();
		}

		let _ = outer.alloc(100).unwrap();
	}
}