// binius_compute/memory.rs

// Copyright 2025 Irreducible Inc.

3use std::{fmt::Debug, ops::RangeBounds};
4
5use binius_field::TowerField;
6use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
7
/// A handle to a contiguous run of elements whose length can be queried.
pub trait SizedSlice {
	/// Number of elements referenced by this handle.
	fn len(&self) -> usize;

	/// Returns `true` when the handle references no elements.
	///
	/// The default implementation compares [`SizedSlice::len`] against zero.
	fn is_empty(&self) -> bool {
		let element_count = self.len();
		element_count == 0
	}
}

16impl<T> SizedSlice for &[T] {
17	fn len(&self) -> usize {
18		(**self).len()
19	}
20}
21
22impl<T> SizedSlice for &mut [T] {
23	fn len(&self) -> usize {
24		(**self).len()
25	}
26}
27
28/// A batch of slices of the same length.
29pub struct SlicesBatch<Slice: SizedSlice> {
30	rows: Vec<Slice>,
31	row_len: usize,
32}
33
34impl<Slice: SizedSlice> SlicesBatch<Slice> {
35	/// Creates a new batch of slices with the given length.
36	///
37	/// # Panics
38	/// If any of the slices in `rows` does not have the specified `row_len`.
39	pub fn new(rows: Vec<Slice>, row_len: usize) -> Self {
40		for row in &rows {
41			assert_eq!(row.len(), row_len);
42		}
43
44		Self { rows, row_len }
45	}
46
47	/// Number of memory slices
48	pub fn n_rows(&self) -> usize {
49		self.rows.len()
50	}
51
52	/// Length of each memory slice
53	pub fn row_len(&self) -> usize {
54		self.row_len
55	}
56
57	/// Returns a slice of the batch at the given index.
58	pub fn row(&self, index: usize) -> &Slice {
59		&self.rows[index]
60	}
61
62	/// Returns iterator over the slices in the batch.
63	pub fn iter(&self) -> impl Iterator<Item = &Slice> {
64		self.rows.iter()
65	}
66}
67
68/// Interface for manipulating handles to memory in a compute device.
69pub trait ComputeMemory<F> {
70	/// The required alignment of indices for the split methods. This must be a power of two.
71	const ALIGNMENT: usize;
72
73	/// An opaque handle to an immutable slice of elements stored in a compute memory.
74	type FSlice<'a>: Copy + SizedSlice + Send + Sync + Debug;
75
76	/// An opaque handle to a mutable slice of elements stored in a compute memory.
77	type FSliceMut<'a>: SizedSlice + Send + Debug;
78
79	/// Borrows an immutable memory slice, narrowing the lifetime.
80	fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a>;
81
82	/// Borrows a mutable memory slice, narrowing the lifetime.
83	fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a>;
84
85	// Converts a reference to an FSliceMut to an FSliceMut
86	fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a>;
87
88	/// Borrows a mutable memory slice as immutable.
89	///
90	/// This allows the immutable reference to be copied.
91	fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a>;
92
93	/// Converts a mutable memory slice to an immutable slice.
94	fn to_const(data: Self::FSliceMut<'_>) -> Self::FSlice<'_>;
95
96	/// Borrows a subslice of an immutable memory slice.
97	///
98	/// ## Preconditions
99	///
100	/// - the range bounds must be multiples of [`Self::ALIGNMENT`]
101	fn slice(data: Self::FSlice<'_>, range: impl RangeBounds<usize>) -> Self::FSlice<'_>;
102
103	/// Borrows a subslice of a mutable memory slice.
104	///
105	/// ## Preconditions
106	///
107	/// - the range bounds must be multiples of [`Self::ALIGNMENT`]
108	fn slice_mut<'a>(
109		data: &'a mut Self::FSliceMut<'_>,
110		range: impl RangeBounds<usize>,
111	) -> Self::FSliceMut<'a>;
112
113	/// Splits an immutable slice into two disjoint subslices.
114	///
115	/// ## Preconditions
116	///
117	/// - `mid` must be a multiple of [`Self::ALIGNMENT`]
118	fn split_at(data: Self::FSlice<'_>, mid: usize) -> (Self::FSlice<'_>, Self::FSlice<'_>) {
119		let head = Self::slice(data, ..mid);
120		let tail = Self::slice(data, mid..);
121		(head, tail)
122	}
123
124	/// Splits a mutable slice into two disjoint subslices.
125	///
126	/// ## Preconditions
127	///
128	/// - `mid` must be a multiple of [`Self::ALIGNMENT`]
129	fn split_at_mut(
130		data: Self::FSliceMut<'_>,
131		mid: usize,
132	) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>);
133
134	fn split_at_mut_borrowed<'a>(
135		data: &'a mut Self::FSliceMut<'_>,
136		mid: usize,
137	) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
138		let borrowed = Self::slice_mut(data, ..);
139		Self::split_at_mut(borrowed, mid)
140	}
141
142	/// Splits slice into equal chunks.
143	///
144	/// ## Preconditions
145	///
146	/// - the length of the slice must be a multiple of `chunk_len`
147	/// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
148	fn slice_chunks<'a>(
149		data: Self::FSlice<'a>,
150		chunk_len: usize,
151	) -> impl Iterator<Item = Self::FSlice<'a>> {
152		let n_chunks = checked_int_div(data.len(), chunk_len);
153		(0..n_chunks).map(move |i| Self::slice(data, i * chunk_len..(i + 1) * chunk_len))
154	}
155
156	/// Splits a mutable slice into equal chunks.
157	///
158	/// ## Preconditions
159	///
160	/// - the length of the slice must be a multiple of `chunk_len`
161	/// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
162	fn slice_chunks_mut<'a>(
163		data: Self::FSliceMut<'a>,
164		chunk_len: usize,
165	) -> impl Iterator<Item = Self::FSliceMut<'a>>;
166
167	/// Splits an immutable slice of power-two length into two equal halves.
168	///
169	/// Unlike all other splitting methods, this method does not require input or output slices
170	/// to be a multiple of [`Self::ALIGNMENT`].
171	fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
172		// This default implementation works only for slices with alignment of 1.
173		assert_eq!(Self::ALIGNMENT, 1);
174
175		assert!(
176			data.len().is_power_of_two() && data.len() > 1,
177			"data length must be a power of two greater than 1"
178		);
179		let mid = data.len() / 2;
180		Self::split_at(data, mid)
181	}
182
183	/// Splits a mutable slice of power-two length into two equal halves.
184	///
185	/// Unlike all other splitting methods, this method does not require input or output slices
186	/// to be a multiple of [`Self::ALIGNMENT`].
187	fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
188		// This default implementation works only for slices with alignment of 1.
189		assert_eq!(Self::ALIGNMENT, 1);
190
191		assert!(
192			data.len().is_power_of_two() && data.len() > 1,
193			"data length must be a power of two greater than 1"
194		);
195		let mid = data.len() / 2;
196		Self::split_at_mut(data, mid)
197	}
198
199	/// Slices off the first `n` elements, mutably, where `n` is a power of two.
200	///
201	/// # Arguments
202	///
203	/// * `input` - The input slice to be trimmed
204	/// * `n` - The length of the requested subslice
205	///
206	/// # Pre-conditions
207	///
208	/// * `n` is a power of two
209	///
210	/// # Returns
211	///
212	/// A mutable slice with the first `n` elements. If the input is smaller than `n`, this returns
213	/// the whole input.
214	fn slice_power_of_two_mut<'a>(
215		input: &'a mut Self::FSliceMut<'_>,
216		n: usize,
217	) -> Self::FSliceMut<'a> {
218		if input.len() <= n {
219			return Self::to_owned_mut(input);
220		}
221
222		let mut result = if n > Self::ALIGNMENT {
223			Self::slice_mut(input, ..n)
224		} else {
225			Self::to_owned_mut(input)
226		};
227
228		// If n is smaller than the input len, narrow down the slice further
229		for _ in checked_log_2(n)..checked_log_2(result.len()) {
230			(result, _) = Self::split_half_mut(result);
231		}
232		result
233	}
234}
235
236/// `SubfieldSlice` is a structure that represents a slice of elements stored in a compute memory,
237/// along with an associated tower level. This structure is used to handle subfield operations
238/// within a computational context, where the `slice` is an immutable reference to the data
239/// and `tower_level` indicates the level of the field tower to which the elements belong.
240///
241/// # Type Parameters
242/// - `'a`: The lifetime of the slice reference.
243/// - `F`: The type of the field elements stored in the slice.
244/// - `Mem`: A type that implements the `ComputeMemory` trait, which provides the necessary
245///   operations for handling memory slices.
246///
247/// # Fields
248/// - `slice`: An immutable slice of elements stored in compute memory, represented by
249///   `Mem::FSlice<'a>`.
250/// - `tower_level`: A `usize` value indicating the level of the field tower for the elements in the
251///   slice.
252///
253/// # Usage
254/// `SubfieldSlice` is typically used in scenarios where operations need to be performed on
255/// specific subfields of a larger field structure, allowing for efficient computation and
256/// manipulation of data within a hierarchical field system.
257pub struct SubfieldSlice<'a, F, Mem: ComputeMemory<F>> {
258	pub slice: Mem::FSlice<'a>,
259	pub tower_level: usize,
260}
261
262impl<'a, F, Mem: ComputeMemory<F>> SubfieldSlice<'a, F, Mem> {
263	pub fn new(slice: Mem::FSlice<'a>, tower_level: usize) -> Self {
264		Self { slice, tower_level }
265	}
266
267	/// Returns the length of the slice in terms of the number of subfield elements it contains.
268	pub fn len(&self) -> usize
269	where
270		F: TowerField,
271	{
272		self.slice.len() << (F::TOWER_LEVEL - self.tower_level)
273	}
274
275	pub fn is_empty(&self) -> bool
276	where
277		F: TowerField,
278	{
279		self.slice.is_empty()
280	}
281}
282
283/// `SubfieldSliceMut` represents a mutable slice of field elements with identical semantics to
284/// `SubfieldSlice`.
285pub struct SubfieldSliceMut<'a, F, Mem: ComputeMemory<F>> {
286	pub slice: Mem::FSliceMut<'a>,
287	pub tower_level: usize,
288}
289
290impl<'a, F, Mem: ComputeMemory<F>> SubfieldSliceMut<'a, F, Mem> {
291	pub fn new(slice: Mem::FSliceMut<'a>, tower_level: usize) -> Self {
292		Self { slice, tower_level }
293	}
294}