// binius_compute/layer.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::ops::Range;
4
5use binius_field::{BinaryField, ExtensionField, Field};
6use binius_math::ArithCircuit;
7use binius_ntt::AdditiveNTT;
8use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
9use itertools::Either;
10
11use super::{
12	alloc::Error as AllocError,
13	memory::{ComputeMemory, SubfieldSlice},
14};
15use crate::memory::{SizedSlice, SlicesBatch};
16
/// A hardware abstraction layer (HAL) for compute operations.
pub trait ComputeLayer<F: Field>: 'static + Sync {
	/// The device memory.
	type DevMem: ComputeMemory<F>;

	/// The executor that can execute operations on the device.
	type Exec;

	/// The executor that can execute operations on a kernel-level granularity (i.e., a single
	/// core).
	type KernelExec: KernelExecutor<F, Mem = Self::DevMem, ExprEval = Self::ExprEval>;

	/// The operation (scalar) value type.
	type OpValue;

	/// The evaluator for arithmetic expressions (polynomials).
	type ExprEval: Sync;

	/// Allocates a slice of memory on the host that is prepared for transfers to/from the device.
	///
	/// Depending on the compute layer, this may perform steps beyond just allocating memory. For
	/// example, it may allocate huge pages or map the allocated memory to the IOMMU.
	///
	/// The returned buffer is lifetime bound to the compute layer, allowing return types to have
	/// drop methods referencing data in the compute layer.
	fn host_alloc(&self, n: usize) -> impl AsMut<[F]> + '_;

	/// Copy data from the host to the device.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	/// * `src` must be a slice of a buffer returned by [`Self::host_alloc`].
	fn copy_h2d(&self, src: &[F], dst: &mut FSliceMut<'_, F, Self>) -> Result<(), Error>;

	/// Copy data from the device to the host.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	/// * `dst` must be a slice of a buffer returned by [`Self::host_alloc`].
	fn copy_d2h(&self, src: FSlice<'_, F, Self>, dst: &mut [F]) -> Result<(), Error>;

	/// Copy data between disjoint device buffers.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	fn copy_d2d(
		&self,
		src: FSlice<'_, F, Self>,
		dst: &mut FSliceMut<'_, F, Self>,
	) -> Result<(), Error>;

	/// Executes an operation.
	///
	/// A HAL operation is an abstract function that runs with an executor reference.
	fn execute(
		&self,
		f: impl FnOnce(&mut Self::Exec) -> Result<Vec<Self::OpValue>, Error>,
	) -> Result<Vec<F>, Error>;

	/// Creates an operation that depends on the concurrent execution of two inner operations.
	fn join<Out1, Out2>(
		&self,
		exec: &mut Self::Exec,
		op1: impl FnOnce(&mut Self::Exec) -> Result<Out1, Error>,
		op2: impl FnOnce(&mut Self::Exec) -> Result<Out2, Error>,
	) -> Result<(Out1, Out2), Error> {
		// Default implementation runs the two operations sequentially; implementations may
		// override this to schedule them concurrently.
		let out1 = op1(exec)?;
		let out2 = op2(exec)?;
		Ok((out1, out2))
	}

	/// Creates an operation that depends on the concurrent execution of a sequence of operations.
	fn map<Out, I: ExactSizeIterator>(
		&self,
		exec: &mut Self::Exec,
		iter: I,
		map: impl Fn(&mut Self::Exec, I::Item) -> Result<Out, Error>,
	) -> Result<Vec<Out>, Error> {
		// Default implementation maps sequentially, short-circuiting on the first error.
		iter.map(|item| map(exec, item)).collect()
	}

	/// Compiles an arithmetic expression to the evaluator.
	fn compile_expr(&self, expr: &ArithCircuit<F>) -> Result<Self::ExprEval, Error>;

	/// Launch many kernels in parallel and accumulate the scalar results with field addition.
	///
	/// This method provides low-level access to schedule parallel kernel executions on the compute
	/// platform. A _kernel_ is a program that executes synchronously in one thread, with access to
	/// local memory buffers. When the environment launches a kernel, it sets up the kernel's local
	/// memory according to the memory mapping specifications provided by the `mem_maps` parameter.
	/// The mapped buffers have type [`KernelBuffer`], and they may be read-write or read-only.
	/// When the kernel exits, it returns a small number of computed values as field elements. The
	/// vector of returned scalars is accumulated via binary field addition across all kernels and
	/// returned from the call.
	///
	/// This method is fairly general but also designed to fit the specific needs of the sumcheck
	/// protocol. That motivates the choice that returned values are small-length vectors that are
	/// accumulated with field addition.
	///
	/// ## Buffer chunking
	///
	/// The kernel local memory buffers are thought of as slices of a larger buffer, which may or
	/// may not exist somewhere else. Each kernel operates on a chunk of the larger buffers. For
	/// example, the [`KernelMemMap::Chunked`] mapping specifies that each kernel operates on a
	/// read-only chunk of a buffer in global device memory. The [`KernelMemMap::Local`] mapping
	/// specifies that a kernel operates on a local scratchpad initialized with zeros and discarded
	/// at the end of kernel execution (sort of like /dev/null).
	///
	/// This [`ComputeLayer`] object can decide how many kernels to launch and thus how large
	/// each kernel's buffer chunks are. The number of chunks must be a power of two. This
	/// information is provided to the kernel specification closure as an argument.
	///
	/// ## Kernel specification
	///
	/// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
	/// has three parameters:
	///
	/// * `kernel_exec` - the kernel execution environment.
	/// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
	/// * `buffers` - a vector of kernel-local buffers.
	///
	/// The closure must respect certain assumptions:
	///
	/// * The kernel closure control flow is identical on each invocation when `log_chunks` is
	///   unchanged.
	///
	/// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
	/// for example with different values for `log_chunks`.
	///
	/// ## Arguments
	///
	/// * `map` - the kernel specification closure. See the "Kernel specification" section above.
	/// * `mem_maps` - the memory mappings for the kernel-local buffers.
	fn accumulate_kernels(
		&self,
		exec: &mut Self::Exec,
		map: impl Sync
		+ for<'a> Fn(
			&'a mut Self::KernelExec,
			usize,
			Vec<KernelBuffer<'a, F, Self::DevMem>>,
		) -> Result<Vec<<Self::KernelExec as KernelExecutor<F>>::Value>, Error>,
		mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
	) -> Result<Vec<Self::OpValue>, Error>;

	/// Returns the inner product of a vector of subfield elements with big field elements.
	///
	/// ## Arguments
	///
	/// * `a_in` - the first input slice of subfield elements.
	/// * `b_in` - the second input slice of `F` elements.
	///
	/// ## Throws
	///
	/// * if the tower level of the `a_in` elements is greater than `F::TOWER_LEVEL`
	/// * unless `a_in` and `b_in` contain the same number of elements, and the number is a power of
	///   two
	///
	/// ## Returns
	///
	/// Returns the inner product of `a_in` and `b_in`.
	fn inner_product(
		&self,
		exec: &mut Self::Exec,
		a_in: SubfieldSlice<'_, F, Self::DevMem>,
		b_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
	) -> Result<Self::OpValue, Error>;

	/// Computes the iterative tensor product of the input with the given coordinates.
	///
	/// This operation modifies the data buffer in place.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts parameters
	///
	/// * $n \in \mathbb{N}$ (`log_n`),
	/// * $k \in \mathbb{N}$ (`coordinates.len()`),
	/// * $v \in L^{2^n}$ (`data[..1 << log_n]`),
	/// * $r \in L^k$ (`coordinates`),
	///
	/// and computes the vector
	///
	/// $$
	/// v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})
	/// $$
	///
	/// ## Throws
	///
	/// * unless `2**(log_n + coordinates.len())` equals `data.len()`
	fn tensor_expand(
		&self,
		exec: &mut Self::Exec,
		log_n: usize,
		coordinates: &[F],
		data: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// Computes left matrix-vector multiplication of a subfield matrix with a big field vector.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts
	///
	/// * $n \in \mathbb{N}$ (`out.len()`),
	/// * $m \in \mathbb{N}$ (`vec.len()`),
	/// * $M \in K^{n \times m}$ (`mat`),
	/// * $v \in K^m$ (`vec`),
	///
	/// and computes the vector $Mv$.
	///
	/// ## Args
	///
	/// * `mat` - a slice of elements from a subfield of `F`.
	/// * `vec` - a slice of `F` elements.
	/// * `out` - a buffer for the output vector of `F` elements.
	///
	/// ## Throws
	///
	/// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
	/// * Returns an error if `mat` is not a subfield of `F`.
	fn fold_left<'a>(
		&'a self,
		exec: &'a mut Self::Exec,
		mat: SubfieldSlice<'_, F, Self::DevMem>,
		vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// Computes right matrix-vector multiplication of a subfield matrix with a big field vector.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts
	///
	/// * $n \in \mathbb{N}$ (`vec.len()`),
	/// * $m \in \mathbb{N}$ (`out.len()`),
	/// * $M \in K^{n \times m}$ (`mat`),
	/// * $v \in K^m$ (`vec`),
	///
	/// and computes the vector $((v')M)'$. The prime denotes a transpose.
	///
	/// ## Args
	///
	/// * `mat` - a slice of elements from a subfield of `F`.
	/// * `vec` - a slice of `F` elements.
	/// * `out` - a buffer for the output vector of `F` elements.
	///
	/// ## Throws
	///
	/// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
	/// * Returns an error if `mat` is not a subfield of `F`.
	fn fold_right<'a>(
		&'a self,
		exec: &'a mut Self::Exec,
		mat: SubfieldSlice<'_, F, Self::DevMem>,
		vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// FRI-fold the interleaved codeword using the given challenges.
	///
	/// The FRI-fold operation folds a length $2^{n+b+\eta}$ vector of field elements into a length
	/// $2^n$ vector of field elements. $n$ is the log block length of the code, $b$ is the log
	/// batch size, and $b + \eta$ is the number of challenge elements. The operation has the
	/// following mathematical structure:
	///
	/// 1. Split the challenge vector into two parts: $c_0$ with length $b$ and $c_1$ with length
	///    $\eta$.
	/// 2. Low fold the input data with the tensor expansion of $c_0$.
	/// 3. Apply $\eta$ layers of the inverse additive NTT to the data.
	/// 4. Low fold the input data with the tensor expansion of $c_1$.
	///
	/// The algorithm to perform steps 3 and 4 can be combined into a linear amount of work,
	/// whereas step 3 on its own would require $\eta$ independent linear passes.
	///
	/// See [DP24], Section 4.2 for more details.
	///
	/// This operation writes the result out-of-place into an output buffer.
	///
	/// ## Arguments
	///
	/// * `ntt` - the NTT instance, used to look up the twiddle values.
	/// * `log_len` - $n + \eta$, the binary logarithm of the code length.
	/// * `log_batch_size` - $b$, the binary logarithm of the interleaved code batch size.
	/// * `challenges` - the folding challenges, with length $b + \eta$.
	/// * `data_in` - an input vector, with length $2^{n + b + \eta}$.
	/// * `data_out` - an output buffer, with length $2^n$.
	///
	/// [DP24]: <https://eprint.iacr.org/2024/504>
	#[allow(clippy::too_many_arguments)]
	fn fri_fold<FSub>(
		&self,
		exec: &mut Self::Exec,
		ntt: &(impl AdditiveNTT<FSub> + Sync),
		log_len: usize,
		log_batch_size: usize,
		challenges: &[F],
		data_in: FSlice<F, Self>,
		data_out: &mut FSliceMut<F, Self>,
	) -> Result<(), Error>
	where
		FSub: BinaryField,
		F: ExtensionField<FSub>;

	/// Extrapolates a line between a vector of evaluations at 0 and evaluations at 1.
	///
	/// Given two values $y_0, y_1$, this operation computes the value $y_z = y_0 + (y_1 - y_0) z$,
	/// which is the value of the line that interpolates $(0, y_0), (1, y_1)$ at $z$. This computes
	/// this operation in parallel over two vectors of big field elements of equal sizes.
	///
	/// The operation writes the result back in-place into the `evals_0` buffer.
	///
	/// ## Args
	///
	/// * `evals_0` - this is both an input and output buffer. As in input, it is populated with the
	///   values $y_0$, which are the line's values at 0.
	/// * `evals_1` - an input buffer with the values $y_1$, which are the line's values at 1.
	/// * `z` - the scalar evaluation point.
	///
	/// ## Throws
	///
	/// * if `evals_0` and `evals_1` are not equal sizes.
	/// * if the sizes of `evals_0` and `evals_1` are not powers of two.
	fn extrapolate_line(
		&self,
		exec: &mut Self::Exec,
		evals_0: &mut FSliceMut<F, Self>,
		evals_1: FSlice<F, Self>,
		z: F,
	) -> Result<(), Error>;

	/// Computes the elementwise application of a compiled arithmetic expression to multiple input
	/// slices.
	///
	/// This operation applies the composition expression to each row of input values, where a row
	/// consists of one element from each input slice at the same index position. The results are
	/// stored in the output slice.
	///
	/// ## Mathematical Definition
	///
	/// Given:
	/// - Multiple input slices $P_0, \ldots, P_{m-1}$, each of length $2^n$ elements
	/// - A composition function $C(X_0, \ldots, X_{m-1})$
	/// - An output slice $P_{\text{out}}$ of length $2^n$ elements
	///
	/// This operation computes:
	///
	/// $$
	/// P_{\text{out}}\[i\] = C(P_0\[i\], \ldots, P_{m-1}\[i\])
	/// \quad \forall i \in \{0, \ldots, 2^n- 1\}
	/// $$
	///
	/// ## Arguments
	///
	/// * `exec` - The execution environment.
	/// * `inputs` - A slice of input slices, where each slice contains field elements.
	/// * `output` - A mutable output slice where the results will be stored.
	/// * `composition` - The compiled arithmetic expression to apply.
	///
	/// ## Throws
	///
	/// * Returns an error if any input or output slice has a length that is not a power of two.
	/// * Returns an error if the input and output slices do not all have the same length.
	fn compute_composite(
		&self,
		exec: &mut Self::Exec,
		inputs: &SlicesBatch<FSlice<'_, F, Self>>,
		output: &mut FSliceMut<'_, F, Self>,
		composition: &Self::ExprEval,
	) -> Result<(), Error>;
}
392
/// An interface for defining execution kernels.
///
/// A _kernel_ is a program that executes synchronously in one thread, with access to
/// local memory buffers.
///
/// See [`ComputeLayer::accumulate_kernels`] for more information.
pub trait KernelExecutor<F> {
	/// The type for kernel-local memory buffers.
	type Mem: ComputeMemory<F>;

	/// The kernel(core)-level operation (scalar) type. This is a promise for a returned value.
	type Value;

	/// The evaluator for arithmetic expressions (polynomials).
	type ExprEval: Sync;

	/// Declares a kernel-level value.
	fn decl_value(&mut self, init: F) -> Result<Self::Value, Error>;

	/// A kernel-local operation that evaluates a composition polynomial over several buffers,
	/// row-wise, and returns the sum of the evaluations, scaled by a batching coefficient.
	///
	/// Mathematically, let there be $m$ input buffers, $P_0, \ldots, P_{m-1}$, each of length
	/// $2^n$ elements. Let $c$ be the scaling coefficient (`batch_coeff`) and
	/// $C(X_0, \ldots, X_{m-1})$ be the composition polynomial. The operation computes
	///
	/// $$
	/// \sum_{i=0}^{2^n - 1} c C(P_0\[i\], \ldots, P_{m-1}\[i\]).
	/// $$
	///
	/// The result is added back to an accumulator value.
	///
	/// ## Arguments
	///
	/// * `inputs` - the input buffers, each of length $2^n$ elements. Each row contains the values
	///   for a single variable.
	/// * `composition` - the compiled composition polynomial expression. This is an output of
	///   [`ComputeLayer::compile_expr`].
	/// * `batch_coeff` - the scaling coefficient.
	/// * `accumulator` - the output where the result is accumulated to.
	fn sum_composition_evals(
		&mut self,
		inputs: &SlicesBatch<<Self::Mem as ComputeMemory<F>>::FSlice<'_>>,
		composition: &Self::ExprEval,
		batch_coeff: F,
		accumulator: &mut Self::Value,
	) -> Result<(), Error>;

	/// A kernel-local operation that performs point-wise addition of two input buffers into an
	/// output buffer.
	///
	/// ## Arguments
	///
	/// * `log_len` - the binary logarithm of the number of elements in all three buffers.
	/// * `src1` - the first input buffer.
	/// * `src2` - the second input buffer.
	/// * `dst` - the output buffer that receives the element-wise sum.
	fn add(
		&mut self,
		log_len: usize,
		src1: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
		src2: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
		dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;
}
458
/// A memory mapping specification for a kernel execution.
///
/// See [`ComputeLayer::accumulate_kernels`] for context on kernel execution.
pub enum KernelMemMap<'a, F, Mem: ComputeMemory<F>> {
	/// This maps a chunk of a buffer in global device memory to a read-only kernel buffer.
	Chunked {
		// The full read-only buffer to be chunked across kernels.
		data: Mem::FSlice<'a>,
		// Binary logarithm of the smallest allowed chunk size, in elements.
		log_min_chunk_size: usize,
	},
	/// This maps a chunk of a mutable buffer in global device memory to a read-write kernel
	/// buffer. When the kernel exits, the data in the kernel buffer is written back to the
	/// original location.
	ChunkedMut {
		// The full mutable buffer to be chunked across kernels.
		data: Mem::FSliceMut<'a>,
		// Binary logarithm of the smallest allowed chunk size, in elements.
		log_min_chunk_size: usize,
	},
	/// This allocates a kernel-local scratchpad buffer. The size specified in the mapping is the
	/// total size of all kernel scratchpads. This is so that the kernel's local scratchpad size
	/// scales up proportionally to the size of chunked buffers.
	Local { log_size: usize },
}
480
481impl<'a, F, Mem: ComputeMemory<F>> KernelMemMap<'a, F, Mem> {
482	/// Computes a range of possible number of chunks that data can be split into, given a sequence
483	/// of memory mappings.
484	pub fn log_chunks_range(mappings: &[Self]) -> Option<Range<usize>> {
485		mappings
486			.iter()
487			.map(|mapping| match mapping {
488				Self::Chunked {
489					data,
490					log_min_chunk_size,
491				} => {
492					let log_data_size = checked_log_2(data.len());
493					let log_min_chunk_size = (*log_min_chunk_size)
494						.max(checked_log_2(Mem::ALIGNMENT))
495						.min(log_data_size);
496					0..(log_data_size - log_min_chunk_size)
497				}
498				Self::ChunkedMut {
499					data,
500					log_min_chunk_size,
501				} => {
502					let log_data_size = checked_log_2(data.len());
503					let log_min_chunk_size = (*log_min_chunk_size)
504						.max(checked_log_2(Mem::ALIGNMENT))
505						.min(log_data_size);
506					0..(log_data_size - log_min_chunk_size)
507				}
508				Self::Local { log_size } => 0..*log_size,
509			})
510			.reduce(|range0, range1| range0.start.max(range1.start)..range0.end.min(range1.end))
511	}
512
513	// Split the memory mapping into `1 << log_chunks>` chunks.
514	pub fn chunks(self, log_chunks: usize) -> impl Iterator<Item = KernelMemMap<'a, F, Mem>> {
515		match self {
516			Self::Chunked {
517				data,
518				log_min_chunk_size,
519			} => Either::Left(Either::Left(
520				Mem::slice_chunks(data, checked_int_div(data.len(), 1 << log_chunks)).map(
521					move |data| KernelMemMap::Chunked {
522						data,
523						log_min_chunk_size,
524					},
525				),
526			)),
527			Self::ChunkedMut { data, .. } => {
528				let chunks_count = checked_int_div(data.len(), 1 << log_chunks);
529				Either::Left(Either::Right(Mem::slice_chunks_mut(data, chunks_count).map(
530					move |data| KernelMemMap::ChunkedMut {
531						data,
532						log_min_chunk_size: checked_log_2(chunks_count),
533					},
534				)))
535			}
536			Self::Local { log_size } => Either::Right(
537				std::iter::repeat_with(move || KernelMemMap::Local {
538					log_size: log_size - log_chunks,
539				})
540				.take(1 << log_chunks),
541			),
542		}
543	}
544}
545
/// A memory buffer mapped into a kernel.
///
/// See [`ComputeLayer::accumulate_kernels`] for context on kernel execution.
pub enum KernelBuffer<'a, F, Mem: ComputeMemory<F>> {
	/// A read-only kernel buffer.
	Ref(Mem::FSlice<'a>),
	/// A read-write kernel buffer.
	Mut(Mem::FSliceMut<'a>),
}
553
554impl<'a, F, Mem: ComputeMemory<F>> KernelBuffer<'a, F, Mem> {
555	/// Returns underlying data as an `FSlice`.
556	pub fn to_ref(&self) -> Mem::FSlice<'_> {
557		match self {
558			Self::Ref(slice) => Mem::narrow(slice),
559			Self::Mut(slice) => Mem::as_const(slice),
560		}
561	}
562}
563
564impl<'a, F, Mem: ComputeMemory<F>> SizedSlice for KernelBuffer<'a, F, Mem> {
565	fn len(&self) -> usize {
566		match self {
567			KernelBuffer::Ref(mem) => mem.len(),
568			KernelBuffer::Mut(mem) => mem.len(),
569		}
570	}
571}
572
/// Errors produced by [`ComputeLayer`] and [`KernelExecutor`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
	/// An argument violated a documented precondition (e.g. mismatched lengths).
	#[error("input validation: {0}")]
	InputValidation(String),
	/// A memory allocation failed.
	#[error("allocation error: {0}")]
	Alloc(#[from] AllocError),
	/// An opaque error reported by the underlying compute device.
	#[error("device error: {0}")]
	DeviceError(Box<dyn std::error::Error + Send + Sync + 'static>),
}
582
// Convenience types for the device memory.
/// Shorthand for the HAL's immutable device-memory slice type.
pub type FSlice<'a, F, HAL> = <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSlice<'a>;
/// Shorthand for the HAL's mutable device-memory slice type.
pub type FSliceMut<'a, F, HAL> =
	<<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSliceMut<'a>;
587
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{BinaryField128b, Field, TowerField, tower::CanonicalTowerFamily};
	use rand::{SeedableRng, prelude::StdRng};

	use super::*;
	use crate::{
		alloc::{BumpAllocator, ComputeAllocator, Error as AllocError, HostBumpAllocator},
		cpu::{CpuLayer, CpuMemory},
	};

	/// Test showing how to allocate host memory and create a sub-allocator over it.
	fn test_host_alloc<F: TowerField, HAL: ComputeLayer<F>>(hal: HAL) {
		let mut host_slice = hal.host_alloc(256);

		let bump = HostBumpAllocator::new(host_slice.as_mut());
		assert_eq!(bump.alloc(100).unwrap().len(), 100);
		assert_eq!(bump.alloc(100).unwrap().len(), 100);
		// Only 56 elements remain of the original 256; the third allocation must fail.
		assert_matches!(bump.alloc(100), Err(AllocError::OutOfMemory));
	}

	/// Test showing how to copy data between host and device buffers (H2D, D2D, then D2H) and
	/// verify a round trip.
	// TODO: This 'a lifetime bound on HAL is pretty annoying. I'd like to get rid of it.
	fn test_copy_host_device<'a, F: TowerField, HAL: ComputeLayer<F> + 'a>(
		hal: HAL,
		mut dev_mem: FSliceMut<'a, F, HAL>,
	) {
		let mut rng = StdRng::seed_from_u64(0);

		let mut host_slice = hal.host_alloc(256);

		let host_alloc = HostBumpAllocator::new(host_slice.as_mut());
		let dev_alloc = BumpAllocator::<F, HAL::DevMem>::from_ref(&mut dev_mem);

		let host_buf_1 = host_alloc.alloc(128).unwrap();
		let host_buf_2 = host_alloc.alloc(128).unwrap();
		let mut dev_buf_1 = dev_alloc.alloc(128).unwrap();
		let mut dev_buf_2 = dev_alloc.alloc(128).unwrap();

		for elem in &mut *host_buf_1 {
			*elem = F::random(&mut rng);
		}

		// Round trip: host -> device -> device -> host.
		hal.copy_h2d(host_buf_1, &mut dev_buf_1).unwrap();
		hal.copy_d2d(HAL::DevMem::as_const(&dev_buf_1), &mut dev_buf_2)
			.unwrap();
		hal.copy_d2h(HAL::DevMem::as_const(&dev_buf_2), host_buf_2)
			.unwrap();

		assert_eq!(host_buf_1, host_buf_2);
	}

	#[test]
	fn test_cpu_host_alloc() {
		test_host_alloc(CpuLayer::<CanonicalTowerFamily>::default());
	}

	#[test]
	fn test_cpu_copy_host_device() {
		let mut dev_mem = vec![BinaryField128b::ZERO; 256];
		test_copy_host_device(CpuLayer::<CanonicalTowerFamily>::default(), dev_mem.as_mut_slice());
	}

	#[test]
	fn test_log_chunks_range() {
		let mem_1 = vec![BinaryField128b::ZERO; 256];
		let mut mem_2 = vec![BinaryField128b::ZERO; 256];

		let mappings = vec![
			KernelMemMap::Chunked {
				data: mem_1.as_slice(),
				log_min_chunk_size: 4,
			},
			KernelMemMap::ChunkedMut {
				data: mem_2.as_mut_slice(),
				log_min_chunk_size: 6,
			},
			KernelMemMap::Local { log_size: 8 },
		];

		// The intersection is limited by the mutable mapping: 256 elements with a minimum chunk
		// size of 2^6 allows at most 2^2 chunks.
		let range =
			KernelMemMap::<BinaryField128b, CpuMemory>::log_chunks_range(&mappings).unwrap();
		assert_eq!(range.start, 0);
		assert_eq!(range.end, 2);
	}
}