binius_compute/memory.rs
1// Copyright 2025 Irreducible Inc.
2
3use std::ops::RangeBounds;
4
5use binius_field::TowerField;
6use binius_utils::checked_arithmetics::checked_int_div;
7
/// A handle to a contiguous run of elements whose element count can be queried.
pub trait SizedSlice {
    /// Returns the number of elements referenced by this handle.
    fn len(&self) -> usize;

    /// Returns `true` when the handle references no elements.
    ///
    /// The default implementation is expressed in terms of [`SizedSlice::len`].
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
15
16impl<T> SizedSlice for &[T] {
17 fn len(&self) -> usize {
18 (**self).len()
19 }
20}
21
22impl<T> SizedSlice for &mut [T] {
23 fn len(&self) -> usize {
24 (**self).len()
25 }
26}
27
28/// A batch of slices of the same length.
29pub struct SlicesBatch<Slice: SizedSlice> {
30 rows: Vec<Slice>,
31 row_len: usize,
32}
33
34impl<Slice: SizedSlice> SlicesBatch<Slice> {
35 /// Creates a new batch of slices with the given length.
36 ///
37 /// # Panics
38 /// If any of the slices in `rows` does not have the specified `row_len`.
39 pub fn new(rows: Vec<Slice>, row_len: usize) -> Self {
40 for row in &rows {
41 assert_eq!(row.len(), row_len);
42 }
43
44 Self { rows, row_len }
45 }
46
47 /// Number of memory slices
48 pub fn n_rows(&self) -> usize {
49 self.rows.len()
50 }
51
52 /// Length of each memory slice
53 pub fn row_len(&self) -> usize {
54 self.row_len
55 }
56
57 /// Returns a slice of the batch at the given index.
58 pub fn row(&self, index: usize) -> &Slice {
59 &self.rows[index]
60 }
61
62 /// Returns iterator over the slices in the batch.
63 pub fn iter(&self) -> impl Iterator<Item = &Slice> {
64 self.rows.iter()
65 }
66}
67
68/// Interface for manipulating handles to memory in a compute device.
69pub trait ComputeMemory<F> {
70 /// The required alignment of indices for the split methods. This must be a power of two.
71 const ALIGNMENT: usize;
72
73 /// An opaque handle to an immutable slice of elements stored in a compute memory.
74 type FSlice<'a>: Copy + SizedSlice + Sync;
75
76 /// An opaque handle to a mutable slice of elements stored in a compute memory.
77 type FSliceMut<'a>: SizedSlice;
78
79 /// Borrows an immutable memory slice, narrowing the lifetime.
80 fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a>;
81
82 /// Borrows a mutable memory slice, narrowing the lifetime.
83 fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a>;
84
85 // Converts a reference to an FSliceMut to an FSliceMut
86 fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a>;
87
88 /// Borrows a mutable memory slice as immutable.
89 ///
90 /// This allows the immutable reference to be copied.
91 fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a>;
92
93 /// Borrows a subslice of an immutable memory slice.
94 ///
95 /// ## Preconditions
96 ///
97 /// - the range bounds must be multiples of [`Self::ALIGNMENT`]
98 fn slice(data: Self::FSlice<'_>, range: impl RangeBounds<usize>) -> Self::FSlice<'_>;
99
100 /// Borrows a subslice of a mutable memory slice.
101 ///
102 /// ## Preconditions
103 ///
104 /// - the range bounds must be multiples of [`Self::ALIGNMENT`]
105 fn slice_mut<'a>(
106 data: &'a mut Self::FSliceMut<'_>,
107 range: impl RangeBounds<usize>,
108 ) -> Self::FSliceMut<'a>;
109
110 /// Splits an immutable slice into two disjoint subslices.
111 ///
112 /// ## Preconditions
113 ///
114 /// - `mid` must be a multiple of [`Self::ALIGNMENT`]
115 fn split_at(data: Self::FSlice<'_>, mid: usize) -> (Self::FSlice<'_>, Self::FSlice<'_>) {
116 let head = Self::slice(data, ..mid);
117 let tail = Self::slice(data, mid..);
118 (head, tail)
119 }
120
121 /// Splits a mutable slice into two disjoint subslices.
122 ///
123 /// ## Preconditions
124 ///
125 /// - `mid` must be a multiple of [`Self::ALIGNMENT`]
126 fn split_at_mut(
127 data: Self::FSliceMut<'_>,
128 mid: usize,
129 ) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>);
130
131 fn split_at_mut_borrowed<'a>(
132 data: &'a mut Self::FSliceMut<'_>,
133 mid: usize,
134 ) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
135 let borrowed = Self::slice_mut(data, ..);
136 Self::split_at_mut(borrowed, mid)
137 }
138
139 /// Splits slice into equal chunks.
140 ///
141 /// ## Preconditions
142 ///
143 /// - the length of the slice must be a multiple of `chunk_len`
144 /// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
145 fn slice_chunks<'a>(
146 data: Self::FSlice<'a>,
147 chunk_len: usize,
148 ) -> impl Iterator<Item = Self::FSlice<'a>> {
149 let n_chunks = checked_int_div(data.len(), chunk_len);
150 (0..n_chunks).map(move |i| Self::slice(data, i * chunk_len..(i + 1) * chunk_len))
151 }
152
153 /// Splits a mutable slice into equal chunks.
154 ///
155 /// ## Preconditions
156 ///
157 /// - the length of the slice must be a multiple of `chunk_len`
158 /// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
159 fn slice_chunks_mut<'a>(
160 data: Self::FSliceMut<'a>,
161 chunk_len: usize,
162 ) -> impl Iterator<Item = Self::FSliceMut<'a>>;
163
164 /// Splits an immutable slice of power-two length into two equal halves.
165 ///
166 /// Unlike all other splitting methods, this method does not require input or output slices
167 /// to be a multiple of [`Self::ALIGNMENT`].
168 fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
169 // This default implementation works only for slices with alignment of 1.
170 assert_eq!(Self::ALIGNMENT, 1);
171
172 assert!(
173 data.len().is_power_of_two() && data.len() > 1,
174 "data length must be a power of two greater than 1"
175 );
176 let mid = data.len() / 2;
177 Self::split_at(data, mid)
178 }
179
180 /// Splits a mutable slice of power-two length into two equal halves.
181 ///
182 /// Unlike all other splitting methods, this method does not require input or output slices
183 /// to be a multiple of [`Self::ALIGNMENT`].
184 fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
185 // This default implementation works only for slices with alignment of 1.
186 assert_eq!(Self::ALIGNMENT, 1);
187
188 assert!(
189 data.len().is_power_of_two() && data.len() > 1,
190 "data length must be a power of two greater than 1"
191 );
192 let mid = data.len() / 2;
193 Self::split_at_mut(data, mid)
194 }
195}
196
197/// `SubfieldSlice` is a structure that represents a slice of elements stored in a compute memory,
198/// along with an associated tower level. This structure is used to handle subfield operations
199/// within a computational context, where the `slice` is an immutable reference to the data
200/// and `tower_level` indicates the level of the field tower to which the elements belong.
201///
202/// # Type Parameters
203/// - `'a`: The lifetime of the slice reference.
204/// - `F`: The type of the field elements stored in the slice.
205/// - `Mem`: A type that implements the `ComputeMemory` trait, which provides the necessary
206/// operations for handling memory slices.
207///
208/// # Fields
209/// - `slice`: An immutable slice of elements stored in compute memory, represented by
210/// `Mem::FSlice<'a>`.
211/// - `tower_level`: A `usize` value indicating the level of the field tower for the elements in the
212/// slice.
213///
214/// # Usage
215/// `SubfieldSlice` is typically used in scenarios where operations need to be performed on
216/// specific subfields of a larger field structure, allowing for efficient computation and
217/// manipulation of data within a hierarchical field system.
218pub struct SubfieldSlice<'a, F, Mem: ComputeMemory<F>> {
219 pub slice: Mem::FSlice<'a>,
220 pub tower_level: usize,
221}
222
223impl<'a, F, Mem: ComputeMemory<F>> SubfieldSlice<'a, F, Mem> {
224 pub fn new(slice: Mem::FSlice<'a>, tower_level: usize) -> Self {
225 Self { slice, tower_level }
226 }
227
228 /// Returns the length of the slice in terms of the number of subfield elements it contains.
229 pub fn len(&self) -> usize
230 where
231 F: TowerField,
232 {
233 self.slice.len() << (F::TOWER_LEVEL - self.tower_level)
234 }
235
236 pub fn is_empty(&self) -> bool
237 where
238 F: TowerField,
239 {
240 self.slice.is_empty()
241 }
242}
243
244/// `SubfieldSliceMut` represents a mutable slice of field elements with identical semantics to
245/// `SubfieldSlice`.
246pub struct SubfieldSliceMut<'a, F, Mem: ComputeMemory<F>> {
247 pub slice: Mem::FSliceMut<'a>,
248 pub tower_level: usize,
249}
250
251impl<'a, F, Mem: ComputeMemory<F>> SubfieldSliceMut<'a, F, Mem> {
252 pub fn new(slice: Mem::FSliceMut<'a>, tower_level: usize) -> Self {
253 Self { slice, tower_level }
254 }
255}