binius_compute/
memory.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::ops::RangeBounds;
4
5use binius_field::TowerField;
6use binius_utils::checked_arithmetics::checked_int_div;
7
/// Types that know how many elements they hold, such as borrowed slices and the slice handles
/// of [`ComputeMemory`].
pub trait SizedSlice {
	/// Returns `true` if the slice holds no elements.
	fn is_empty(&self) -> bool {
		matches!(self.len(), 0)
	}

	/// Returns the number of elements in the slice.
	fn len(&self) -> usize;
}

/// Shared borrowed slices report their element count directly.
impl<'s, T> SizedSlice for &'s [T] {
	fn len(&self) -> usize {
		<[T]>::len(self)
	}
}

/// Mutable borrowed slices report their element count directly.
impl<'s, T> SizedSlice for &'s mut [T] {
	fn len(&self) -> usize {
		<[T]>::len(self)
	}
}
27
28/// A batch of slices of the same length.
29pub struct SlicesBatch<Slice: SizedSlice> {
30	rows: Vec<Slice>,
31	row_len: usize,
32}
33
34impl<Slice: SizedSlice> SlicesBatch<Slice> {
35	/// Creates a new batch of slices with the given length.
36	///
37	/// # Panics
38	/// If any of the slices in `rows` does not have the specified `row_len`.
39	pub fn new(rows: Vec<Slice>, row_len: usize) -> Self {
40		for row in &rows {
41			assert_eq!(row.len(), row_len);
42		}
43
44		Self { rows, row_len }
45	}
46
47	/// Number of memory slices
48	pub fn n_rows(&self) -> usize {
49		self.rows.len()
50	}
51
52	/// Length of each memory slice
53	pub fn row_len(&self) -> usize {
54		self.row_len
55	}
56
57	/// Returns a slice of the batch at the given index.
58	pub fn row(&self, index: usize) -> &Slice {
59		&self.rows[index]
60	}
61
62	/// Returns iterator over the slices in the batch.
63	pub fn iter(&self) -> impl Iterator<Item = &Slice> {
64		self.rows.iter()
65	}
66}
67
68/// Interface for manipulating handles to memory in a compute device.
69pub trait ComputeMemory<F> {
70	/// The required alignment of indices for the split methods. This must be a power of two.
71	const ALIGNMENT: usize;
72
73	/// An opaque handle to an immutable slice of elements stored in a compute memory.
74	type FSlice<'a>: Copy + SizedSlice + Sync;
75
76	/// An opaque handle to a mutable slice of elements stored in a compute memory.
77	type FSliceMut<'a>: SizedSlice;
78
79	/// Borrows an immutable memory slice, narrowing the lifetime.
80	fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a>;
81
82	/// Borrows a mutable memory slice, narrowing the lifetime.
83	fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a>;
84
85	// Converts a reference to an FSliceMut to an FSliceMut
86	fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a>;
87
88	/// Borrows a mutable memory slice as immutable.
89	///
90	/// This allows the immutable reference to be copied.
91	fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a>;
92
93	/// Borrows a subslice of an immutable memory slice.
94	///
95	/// ## Preconditions
96	///
97	/// - the range bounds must be multiples of [`Self::ALIGNMENT`]
98	fn slice(data: Self::FSlice<'_>, range: impl RangeBounds<usize>) -> Self::FSlice<'_>;
99
100	/// Borrows a subslice of a mutable memory slice.
101	///
102	/// ## Preconditions
103	///
104	/// - the range bounds must be multiples of [`Self::ALIGNMENT`]
105	fn slice_mut<'a>(
106		data: &'a mut Self::FSliceMut<'_>,
107		range: impl RangeBounds<usize>,
108	) -> Self::FSliceMut<'a>;
109
110	/// Splits an immutable slice into two disjoint subslices.
111	///
112	/// ## Preconditions
113	///
114	/// - `mid` must be a multiple of [`Self::ALIGNMENT`]
115	fn split_at(data: Self::FSlice<'_>, mid: usize) -> (Self::FSlice<'_>, Self::FSlice<'_>) {
116		let head = Self::slice(data, ..mid);
117		let tail = Self::slice(data, mid..);
118		(head, tail)
119	}
120
121	/// Splits a mutable slice into two disjoint subslices.
122	///
123	/// ## Preconditions
124	///
125	/// - `mid` must be a multiple of [`Self::ALIGNMENT`]
126	fn split_at_mut(
127		data: Self::FSliceMut<'_>,
128		mid: usize,
129	) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>);
130
131	fn split_at_mut_borrowed<'a>(
132		data: &'a mut Self::FSliceMut<'_>,
133		mid: usize,
134	) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
135		let borrowed = Self::slice_mut(data, ..);
136		Self::split_at_mut(borrowed, mid)
137	}
138
139	/// Splits slice into equal chunks.
140	///
141	/// ## Preconditions
142	///
143	/// - the length of the slice must be a multiple of `chunk_len`
144	/// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
145	fn slice_chunks<'a>(
146		data: Self::FSlice<'a>,
147		chunk_len: usize,
148	) -> impl Iterator<Item = Self::FSlice<'a>> {
149		let n_chunks = checked_int_div(data.len(), chunk_len);
150		(0..n_chunks).map(move |i| Self::slice(data, i * chunk_len..(i + 1) * chunk_len))
151	}
152
153	/// Splits a mutable slice into equal chunks.
154	///
155	/// ## Preconditions
156	///
157	/// - the length of the slice must be a multiple of `chunk_len`
158	/// - `chunk_len` must be a multiple of [`Self::ALIGNMENT`]
159	fn slice_chunks_mut<'a>(
160		data: Self::FSliceMut<'a>,
161		chunk_len: usize,
162	) -> impl Iterator<Item = Self::FSliceMut<'a>>;
163
164	/// Splits an immutable slice of power-two length into two equal halves.
165	///
166	/// Unlike all other splitting methods, this method does not require input or output slices
167	/// to be a multiple of [`Self::ALIGNMENT`].
168	fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
169		// This default implementation works only for slices with alignment of 1.
170		assert_eq!(Self::ALIGNMENT, 1);
171
172		assert!(
173			data.len().is_power_of_two() && data.len() > 1,
174			"data length must be a power of two greater than 1"
175		);
176		let mid = data.len() / 2;
177		Self::split_at(data, mid)
178	}
179
180	/// Splits a mutable slice of power-two length into two equal halves.
181	///
182	/// Unlike all other splitting methods, this method does not require input or output slices
183	/// to be a multiple of [`Self::ALIGNMENT`].
184	fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
185		// This default implementation works only for slices with alignment of 1.
186		assert_eq!(Self::ALIGNMENT, 1);
187
188		assert!(
189			data.len().is_power_of_two() && data.len() > 1,
190			"data length must be a power of two greater than 1"
191		);
192		let mid = data.len() / 2;
193		Self::split_at_mut(data, mid)
194	}
195}
196
197/// `SubfieldSlice` is a structure that represents a slice of elements stored in a compute memory,
198/// along with an associated tower level. This structure is used to handle subfield operations
199/// within a computational context, where the `slice` is an immutable reference to the data
200/// and `tower_level` indicates the level of the field tower to which the elements belong.
201///
202/// # Type Parameters
203/// - `'a`: The lifetime of the slice reference.
204/// - `F`: The type of the field elements stored in the slice.
205/// - `Mem`: A type that implements the `ComputeMemory` trait, which provides the necessary
206///   operations for handling memory slices.
207///
208/// # Fields
209/// - `slice`: An immutable slice of elements stored in compute memory, represented by
210///   `Mem::FSlice<'a>`.
211/// - `tower_level`: A `usize` value indicating the level of the field tower for the elements in the
212///   slice.
213///
214/// # Usage
215/// `SubfieldSlice` is typically used in scenarios where operations need to be performed on
216/// specific subfields of a larger field structure, allowing for efficient computation and
217/// manipulation of data within a hierarchical field system.
218pub struct SubfieldSlice<'a, F, Mem: ComputeMemory<F>> {
219	pub slice: Mem::FSlice<'a>,
220	pub tower_level: usize,
221}
222
223impl<'a, F, Mem: ComputeMemory<F>> SubfieldSlice<'a, F, Mem> {
224	pub fn new(slice: Mem::FSlice<'a>, tower_level: usize) -> Self {
225		Self { slice, tower_level }
226	}
227
228	/// Returns the length of the slice in terms of the number of subfield elements it contains.
229	pub fn len(&self) -> usize
230	where
231		F: TowerField,
232	{
233		self.slice.len() << (F::TOWER_LEVEL - self.tower_level)
234	}
235
236	pub fn is_empty(&self) -> bool
237	where
238		F: TowerField,
239	{
240		self.slice.is_empty()
241	}
242}
243
244/// `SubfieldSliceMut` represents a mutable slice of field elements with identical semantics to
245/// `SubfieldSlice`.
246pub struct SubfieldSliceMut<'a, F, Mem: ComputeMemory<F>> {
247	pub slice: Mem::FSliceMut<'a>,
248	pub tower_level: usize,
249}
250
251impl<'a, F, Mem: ComputeMemory<F>> SubfieldSliceMut<'a, F, Mem> {
252	pub fn new(slice: Mem::FSliceMut<'a>, tower_level: usize) -> Self {
253		Self { slice, tower_level }
254	}
255}