1use std::cell::Cell;
4
5use binius_utils::checked_arithmetics::checked_log_2;
6
7use super::memory::{ComputeMemory, SizedSlice};
8use crate::cpu::CpuMemory;
9
10pub trait ComputeAllocator<F, Mem: ComputeMemory<F>> {
11 fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error>;
21
22 fn remaining(&mut self) -> Mem::FSliceMut<'_>;
27
28 fn capacity(&self) -> usize;
30}
31
32pub struct BumpAllocator<'a, F, Mem: ComputeMemory<F>> {
35 buffer: Cell<Option<Mem::FSliceMut<'a>>>,
36}
37
38impl<'a, F, Mem> BumpAllocator<'a, F, Mem>
39where
40 F: 'static,
41 Mem: ComputeMemory<F> + 'a,
42{
43 pub fn new(buffer: Mem::FSliceMut<'a>) -> Self {
44 Self {
45 buffer: Cell::new(Some(buffer)),
46 }
47 }
48
49 pub fn from_ref<'b>(buffer: &'b mut Mem::FSliceMut<'a>) -> BumpAllocator<'b, F, Mem> {
50 let buffer = Mem::slice_mut(buffer, ..);
51 BumpAllocator {
52 buffer: Cell::new(Some(buffer)),
53 }
54 }
55}
56
57impl<'a, F, Mem: ComputeMemory<F>> ComputeAllocator<F, Mem> for BumpAllocator<'a, F, Mem> {
58 fn alloc(&self, n: usize) -> Result<Mem::FSliceMut<'_>, Error> {
59 let buffer = self
60 .buffer
61 .take()
62 .expect("buffer is always Some by invariant");
63 if buffer.len() < n {
65 self.buffer.set(Some(buffer));
66 Err(Error::OutOfMemory)
68 } else {
69 let (mut lhs, rhs) = Mem::split_at_mut(buffer, n.max(Mem::ALIGNMENT));
70 if n < Mem::ALIGNMENT {
71 assert!(n.is_power_of_two(), "n must be a power of two");
72 for _ in checked_log_2(n)..checked_log_2(Mem::ALIGNMENT) {
73 (lhs, _) = Mem::split_half_mut(lhs)
74 }
75 }
76 self.buffer.set(Some(rhs));
77 Ok(Mem::narrow_mut(lhs))
79 }
80 }
81
82 fn remaining(&mut self) -> Mem::FSliceMut<'_> {
83 Mem::to_owned_mut(
84 self.buffer
85 .get_mut()
86 .as_mut()
87 .expect("buffer is always Some by invariant"),
88 )
89 }
90
91 fn capacity(&self) -> usize {
92 let buffer = self
93 .buffer
94 .take()
95 .expect("buffer is always Some by invariant");
96 let ret = buffer.len();
97 self.buffer.set(Some(buffer));
98 ret
99 }
100}
101
/// A [`BumpAllocator`] over host memory, i.e. with `Mem = CpuMemory`.
pub type HostBumpAllocator<'a, F> = BumpAllocator<'a, F, CpuMemory>;
104
105#[derive(Debug, thiserror::Error)]
106pub enum Error {
107 #[error("allocator is out of memory")]
108 OutOfMemory,
109}
110
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::cpu::memory::CpuMemory;

    #[test]
    fn test_alloc() {
        let mut data = (0..256u128).collect::<Vec<_>>();

        {
            let bump = BumpAllocator::<u128, CpuMemory>::new(&mut data);
            assert_eq!(bump.alloc(100).unwrap().len(), 100);
            assert_eq!(bump.alloc(100).unwrap().len(), 100);
            assert_matches!(bump.alloc(100), Err(Error::OutOfMemory));
        }

        // Once the first allocator is dropped, the whole buffer can be reused.
        let bump = BumpAllocator::<u128, CpuMemory>::new(&mut data);
        let data = bump.alloc(100).unwrap();
        assert_eq!(data.len(), 100);
    }

    #[test]
    fn test_stack_alloc() {
        let mut data = (0..256u128).collect::<Vec<_>>();
        let mut bump = BumpAllocator::<u128, CpuMemory>::new(&mut data);
        assert_eq!(bump.alloc(100).unwrap().len(), 100);
        assert_matches!(bump.alloc(200), Err(Error::OutOfMemory));

        {
            let next_layer_memory = bump.remaining();
            let bump2 = BumpAllocator::<u128, CpuMemory>::new(next_layer_memory);
            let _ = bump2.alloc(100).unwrap();
            assert_matches!(bump2.alloc(57), Err(Error::OutOfMemory));
            let _ = bump2.alloc(56).unwrap();
        }

        // Memory used by the nested allocator is available to the parent again.
        let _ = bump.alloc(100).unwrap();
    }

    #[test]
    fn test_capacity() {
        let mut data = (0..256u128).collect::<Vec<_>>();
        let bump = BumpAllocator::<u128, CpuMemory>::new(&mut data);
        assert_eq!(bump.capacity(), 256);

        let _ = bump.alloc(64).unwrap();
        assert_eq!(bump.capacity(), 192);

        // A failed allocation must not consume any memory.
        assert_matches!(bump.alloc(193), Err(Error::OutOfMemory));
        assert_eq!(bump.capacity(), 192);
    }

    #[test]
    fn test_alloc_is_front_to_back() {
        let mut data = (0..16u128).collect::<Vec<_>>();
        let bump = BumpAllocator::<u128, CpuMemory>::new(&mut data);
        let first = bump.alloc(4).unwrap();
        let second = bump.alloc(4).unwrap();
        assert_eq!(first.to_vec(), vec![0, 1, 2, 3]);
        assert_eq!(second.to_vec(), vec![4, 5, 6, 7]);
    }
}