binius_compute/memory.rs
1// Copyright 2025 Irreducible Inc.
2
3use std::{fmt::Debug, ops::RangeBounds};
4
5use binius_field::TowerField;
6use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
7
/// Common interface for slice-like handles that know how many elements they hold.
pub trait SizedSlice {
    /// Returns the number of elements in the slice.
    fn len(&self) -> usize;

    /// Returns `true` when the slice holds no elements.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
15
16impl<T> SizedSlice for &[T] {
17 fn len(&self) -> usize {
18 (**self).len()
19 }
20}
21
22impl<T> SizedSlice for &mut [T] {
23 fn len(&self) -> usize {
24 (**self).len()
25 }
26}
27
28/// A batch of slices of the same length.
29pub struct SlicesBatch<Slice: SizedSlice> {
30 rows: Vec<Slice>,
31 row_len: usize,
32}
33
34impl<Slice: SizedSlice> SlicesBatch<Slice> {
35 /// Creates a new batch of slices with the given length.
36 ///
37 /// # Panics
38 /// If any of the slices in `rows` does not have the specified `row_len`.
39 pub fn new(rows: Vec<Slice>, row_len: usize) -> Self {
40 for row in &rows {
41 assert_eq!(row.len(), row_len);
42 }
43
44 Self { rows, row_len }
45 }
46
47 /// Number of memory slices
48 pub fn n_rows(&self) -> usize {
49 self.rows.len()
50 }
51
52 /// Length of each memory slice
53 pub fn row_len(&self) -> usize {
54 self.row_len
55 }
56
57 /// Returns a slice of the batch at the given index.
58 pub fn row(&self, index: usize) -> &Slice {
59 &self.rows[index]
60 }
61
62 /// Returns iterator over the slices in the batch.
63 pub fn iter(&self) -> impl Iterator<Item = &Slice> {
64 self.rows.iter()
65 }
66}
67
68/// Interface for manipulating handles to memory in a compute device.
69pub trait ComputeMemory<F> {
70 /// The required alignment of indices for the split methods. This must be a power of two.
71 const ALIGNMENT: usize;
72
73 /// An opaque handle to an immutable slice of elements stored in a compute memory.
74 type FSlice<'a>: Copy + SizedSlice + Send + Sync + Debug;
75
76 /// An opaque handle to a mutable slice of elements stored in a compute memory.
77 type FSliceMut<'a>: SizedSlice + Send + Debug;
78
79 /// Borrows an immutable memory slice, narrowing the lifetime.
80 fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a>;
81
82 /// Borrows a mutable memory slice, narrowing the lifetime.
83 fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a>;
84
85 // Converts a reference to an FSliceMut to an FSliceMut
86 fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a>;
87
88 /// Borrows a mutable memory slice as immutable.
89 ///
90 /// This allows the immutable reference to be copied.
91 fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a>;
92
93 /// Converts a mutable memory slice to an immutable slice.
94 fn to_const(data: Self::FSliceMut<'_>) -> Self::FSlice<'_>;
95
96 /// Borrows a subslice of an immutable memory slice.
97 ///
98 /// ## Preconditions
99 ///
100 /// - the range bounds must be multiples of [`Self::ALIGNMENT`]
101 fn slice(data: Self::FSlice<'_>, range: impl RangeBounds<usize>) -> Self::FSlice<'_>;
102
103 /// Borrows a subslice of a mutable memory slice.
104 ///
105 /// ## Preconditions
106 ///
107 /// - the range bounds must be multiples of [`Self::ALIGNMENT`]
108 fn slice_mut<'a>(
109 data: &'a mut Self::FSliceMut<'_>,
110 range: impl RangeBounds<usize>,
111 ) -> Self::FSliceMut<'a>;
112
113 /// Splits an immutable slice into two disjoint subslices.
114 ///
115 /// ## Preconditions
116 ///
117 /// - `mid` must be a multiple of [`Self::ALIGNMENT`]
118 fn split_at(data: Self::FSlice<'_>, mid: usize) -> (Self::FSlice<'_>, Self::FSlice<'_>) {
119 let head = Self::slice(data, ..mid);
120 let tail = Self::slice(data, mid..);
121 (head, tail)
122 }
123
124 /// Splits a mutable slice into two disjoint subslices.
125 ///
126 /// ## Preconditions
127 ///
128 /// - `mid` must be a multiple of [`Self::ALIGNMENT`]
129 fn split_at_mut(
130 data: Self::FSliceMut<'_>,
131 mid: usize,
132 ) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>);
133
134 fn split_at_mut_borrowed<'a>(
135 data: &'a mut Self::FSliceMut<'_>,
136 mid: usize,
137 ) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
138 let borrowed = Self::slice_mut(data, ..);
139 Self::split_at_mut(borrowed, mid)
140 }
141
142 /// Splits slice into equal chunks.
143 ///
144 /// ## Preconditions
145 ///
146 /// - the length of the slice must be a multiple of `chunk_len`
147 /// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
148 fn slice_chunks<'a>(
149 data: Self::FSlice<'a>,
150 chunk_len: usize,
151 ) -> impl Iterator<Item = Self::FSlice<'a>> {
152 let n_chunks = checked_int_div(data.len(), chunk_len);
153 (0..n_chunks).map(move |i| Self::slice(data, i * chunk_len..(i + 1) * chunk_len))
154 }
155
156 /// Splits a mutable slice into equal chunks.
157 ///
158 /// ## Preconditions
159 ///
160 /// - the length of the slice must be a multiple of `chunk_len`
161 /// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
162 fn slice_chunks_mut<'a>(
163 data: Self::FSliceMut<'a>,
164 chunk_len: usize,
165 ) -> impl Iterator<Item = Self::FSliceMut<'a>>;
166
167 /// Splits an immutable slice of power-two length into two equal halves.
168 ///
169 /// Unlike all other splitting methods, this method does not require input or output slices
170 /// to be a multiple of [`Self::ALIGNMENT`].
171 fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
172 // This default implementation works only for slices with alignment of 1.
173 assert_eq!(Self::ALIGNMENT, 1);
174
175 assert!(
176 data.len().is_power_of_two() && data.len() > 1,
177 "data length must be a power of two greater than 1"
178 );
179 let mid = data.len() / 2;
180 Self::split_at(data, mid)
181 }
182
183 /// Splits a mutable slice of power-two length into two equal halves.
184 ///
185 /// Unlike all other splitting methods, this method does not require input or output slices
186 /// to be a multiple of [`Self::ALIGNMENT`].
187 fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
188 // This default implementation works only for slices with alignment of 1.
189 assert_eq!(Self::ALIGNMENT, 1);
190
191 assert!(
192 data.len().is_power_of_two() && data.len() > 1,
193 "data length must be a power of two greater than 1"
194 );
195 let mid = data.len() / 2;
196 Self::split_at_mut(data, mid)
197 }
198
199 /// Slices off the first `n` elements, mutably, where `n` is a power of two.
200 ///
201 /// # Arguments
202 ///
203 /// * `input` - The input slice to be trimmed
204 /// * `n` - The length of the requested subslice
205 ///
206 /// # Pre-conditions
207 ///
208 /// * `n` is a power of two
209 ///
210 /// # Returns
211 ///
212 /// A mutable slice with the first `n` elements. If the input is smaller than `n`, this returns
213 /// the whole input.
214 fn slice_power_of_two_mut<'a>(
215 input: &'a mut Self::FSliceMut<'_>,
216 n: usize,
217 ) -> Self::FSliceMut<'a> {
218 if input.len() <= n {
219 return Self::to_owned_mut(input);
220 }
221
222 let mut result = if n > Self::ALIGNMENT {
223 Self::slice_mut(input, ..n)
224 } else {
225 Self::to_owned_mut(input)
226 };
227
228 // If n is smaller than the input len, narrow down the slice further
229 for _ in checked_log_2(n)..checked_log_2(result.len()) {
230 (result, _) = Self::split_half_mut(result);
231 }
232 result
233 }
234}
235
236/// `SubfieldSlice` is a structure that represents a slice of elements stored in a compute memory,
237/// along with an associated tower level. This structure is used to handle subfield operations
238/// within a computational context, where the `slice` is an immutable reference to the data
239/// and `tower_level` indicates the level of the field tower to which the elements belong.
240///
241/// # Type Parameters
242/// - `'a`: The lifetime of the slice reference.
243/// - `F`: The type of the field elements stored in the slice.
244/// - `Mem`: A type that implements the `ComputeMemory` trait, which provides the necessary
245/// operations for handling memory slices.
246///
247/// # Fields
248/// - `slice`: An immutable slice of elements stored in compute memory, represented by
249/// `Mem::FSlice<'a>`.
250/// - `tower_level`: A `usize` value indicating the level of the field tower for the elements in the
251/// slice.
252///
253/// # Usage
254/// `SubfieldSlice` is typically used in scenarios where operations need to be performed on
255/// specific subfields of a larger field structure, allowing for efficient computation and
256/// manipulation of data within a hierarchical field system.
257pub struct SubfieldSlice<'a, F, Mem: ComputeMemory<F>> {
258 pub slice: Mem::FSlice<'a>,
259 pub tower_level: usize,
260}
261
262impl<'a, F, Mem: ComputeMemory<F>> SubfieldSlice<'a, F, Mem> {
263 pub fn new(slice: Mem::FSlice<'a>, tower_level: usize) -> Self {
264 Self { slice, tower_level }
265 }
266
267 /// Returns the length of the slice in terms of the number of subfield elements it contains.
268 pub fn len(&self) -> usize
269 where
270 F: TowerField,
271 {
272 self.slice.len() << (F::TOWER_LEVEL - self.tower_level)
273 }
274
275 pub fn is_empty(&self) -> bool
276 where
277 F: TowerField,
278 {
279 self.slice.is_empty()
280 }
281}
282
283/// `SubfieldSliceMut` represents a mutable slice of field elements with identical semantics to
284/// `SubfieldSlice`.
285pub struct SubfieldSliceMut<'a, F, Mem: ComputeMemory<F>> {
286 pub slice: Mem::FSliceMut<'a>,
287 pub tower_level: usize,
288}
289
290impl<'a, F, Mem: ComputeMemory<F>> SubfieldSliceMut<'a, F, Mem> {
291 pub fn new(slice: Mem::FSliceMut<'a>, tower_level: usize) -> Self {
292 Self { slice, tower_level }
293 }
294}