1use std::sync::Mutex;
4
5use binius_utils::checked_arithmetics::checked_log_2;
6
7use super::memory::{ComputeMemory, SizedSlice};
8use crate::cpu::CpuMemory;
9
/// An allocator that hands out mutable slices of `F` elements living in the memory model `Mem`.
pub trait ComputeAllocator<F, Mem: ComputeMemory<F>> {
    /// Allocates a mutable slice of `n` elements.
    ///
    /// Takes `&self`, so several allocations obtained from the same allocator can be alive at
    /// the same time.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the request cannot be satisfied.
    fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error>;

    /// Returns a size, in elements, describing this allocator.
    ///
    /// NOTE(review): the implementation in this file reports the number of elements that are
    /// still unallocated, not the total size of the backing buffer — confirm this is the intended
    /// contract for all implementors.
    fn capacity(&self) -> usize;

    /// Returns a nested allocator derived from this one.
    ///
    /// Takes `&mut self`; the returned allocator is expected to borrow from `self`, so this
    /// allocator cannot be used again until the returned one has gone out of scope.
    fn subscope_allocator(&mut self) -> impl ComputeAllocator<F, Mem>;
}
28
/// Allocator that carves consecutive pieces off the front of a single borrowed buffer.
///
/// Pieces are never handed back individually; the allocator only tracks the part of the buffer
/// that has not been given out yet.
pub struct BumpAllocator<'a, F, Mem: ComputeMemory<F>> {
    /// The not-yet-allocated tail of the buffer.
    ///
    /// The `Mutex` lets `alloc` shrink the tail through a shared reference. The `Option` is
    /// `Some` at all times except inside `alloc`, which briefly takes the slice out by value in
    /// order to split it.
    buffer: Mutex<Option<Mem::FSliceMut<'a>>>,
}
34
35impl<'a, F, Mem> BumpAllocator<'a, F, Mem>
36where
37 F: 'static,
38 Mem: ComputeMemory<F>,
39{
40 pub fn new(buffer: Mem::FSliceMut<'a>) -> Self {
41 Self {
42 buffer: Mutex::new(Some(buffer)),
43 }
44 }
45
46 pub fn from_ref<'b>(buffer: &'b mut Mem::FSliceMut<'a>) -> BumpAllocator<'b, F, Mem> {
47 let buffer = Mem::slice_mut(buffer, ..);
48 BumpAllocator {
49 buffer: Mutex::new(Some(buffer)),
50 }
51 }
52
53 fn remaining(&mut self) -> Mem::FSliceMut<'_> {
54 Mem::to_owned_mut(
55 self.buffer
56 .get_mut()
57 .expect("mutex is always available")
58 .as_mut()
59 .expect("buffer is always Some by invariant"),
60 )
61 }
62}
63
64impl<'a, F, Mem: ComputeMemory<F>> ComputeAllocator<F, Mem> for BumpAllocator<'a, F, Mem>
65where
66 F: 'static,
67{
68 fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error> {
69 let mut buffer_lock = self.buffer.lock().expect("mutex is always available");
70
71 let buffer = buffer_lock
72 .take()
73 .expect("buffer is always Some by invariant");
74 if buffer.len() < n {
76 *buffer_lock = Some(buffer);
77 Err(Error::OutOfMemory)
79 } else {
80 let (mut lhs, rhs) = Mem::split_at_mut(buffer, n.max(Mem::ALIGNMENT));
81 if n < Mem::ALIGNMENT {
82 assert!(n.is_power_of_two(), "n must be a power of two");
83 for _ in checked_log_2(n)..checked_log_2(Mem::ALIGNMENT) {
84 (lhs, _) = Mem::split_half_mut(lhs)
85 }
86 }
87 *buffer_lock = Some(rhs);
88 Ok(Mem::narrow_mut(lhs))
90 }
91 }
92
93 fn capacity(&self) -> usize {
94 self.buffer
95 .lock()
96 .expect("mutex is always available")
97 .as_ref()
98 .expect("buffer is always Some by invariant")
99 .len()
100 }
101
102 fn subscope_allocator(&mut self) -> impl ComputeAllocator<F, Mem> {
103 BumpAllocator::<F, Mem>::new(self.remaining())
104 }
105}
106
/// A [`BumpAllocator`] whose memory model is fixed to [`CpuMemory`].
pub type HostBumpAllocator<'a, F> = BumpAllocator<'a, F, CpuMemory>;
109
/// Errors reported by the allocators in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The allocator does not have enough unallocated elements left to satisfy the request.
    #[error("allocator is out of memory")]
    OutOfMemory,
}
115
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::cpu::memory::CpuMemory;

    /// Two 100-element allocations fit into 256 elements, a third one does not, and a fresh
    /// allocator over the same backing storage can allocate again.
    #[test]
    fn test_alloc() {
        let mut backing: Vec<u128> = (0..256u128).collect();

        {
            let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
            for _ in 0..2 {
                assert_eq!(allocator.alloc(100).unwrap().len(), 100);
            }
            // Only 56 of the 256 elements are left at this point.
            assert_matches!(allocator.alloc(100), Err(Error::OutOfMemory));
        }

        // The first allocator is gone, so the whole backing vector can be borrowed again.
        let allocator = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
        let slice = allocator.alloc(100).unwrap();
        assert_eq!(slice.len(), 100);
    }

    /// A nested allocator built from `remaining()` works inside the outer allocator's free
    /// region, and the outer allocator can hand that region out again afterwards.
    #[test]
    fn test_stack_alloc() {
        let mut backing: Vec<u128> = (0..256u128).collect();
        let mut outer = BumpAllocator::<u128, CpuMemory>::new(&mut backing);
        assert_eq!(outer.alloc(100).unwrap().len(), 100);
        assert_matches!(outer.alloc(200), Err(Error::OutOfMemory));

        {
            // The nested allocator sees the 156 elements the outer one has not handed out.
            let inner = BumpAllocator::<u128, CpuMemory>::new(outer.remaining());
            let _ = inner.alloc(100).unwrap();
            assert_matches!(inner.alloc(57), Err(Error::OutOfMemory));
            let _ = inner.alloc(56).unwrap();
        }

        // Nested allocations did not consume anything from the outer allocator.
        let _ = outer.alloc(100).unwrap();
    }
}