// binius_compute/layer.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field};
6use binius_math::ArithCircuit;
7use binius_ntt::AdditiveNTT;
8use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
9use itertools::Either;
10
11use super::{
12	alloc::Error as AllocError,
13	memory::{ComputeMemory, SubfieldSlice},
14};
15use crate::{
16	alloc::ComputeAllocator,
17	cpu::CpuMemory,
18	memory::{SizedSlice, SlicesBatch},
19};
20
/// A hardware abstraction layer (HAL) for compute operations.
pub trait ComputeLayer<F: Field> {
	/// The device memory.
	type DevMem: ComputeMemory<F>;

	/// The executor that can execute operations on the device.
	type Exec<'a>: ComputeLayerExecutor<F, DevMem = Self::DevMem>
	where
		Self: 'a;

	/// Copy data from the host to the device.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	fn copy_h2d(&self, src: &[F], dst: &mut FSliceMut<'_, F, Self>) -> Result<(), Error>;

	/// Copy data from the device to the host.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	fn copy_d2h(&self, src: FSlice<'_, F, Self>, dst: &mut [F]) -> Result<(), Error>;

	/// Copy data between disjoint device buffers.
	///
	/// ## Preconditions
	///
	/// * `src` and `dst` must have the same length.
	fn copy_d2d(
		&self,
		src: FSlice<'_, F, Self>,
		dst: &mut FSliceMut<'_, F, Self>,
	) -> Result<(), Error>;

	/// Compiles an arithmetic expression to the evaluator.
	fn compile_expr(
		&self,
		expr: &ArithCircuit<F>,
	) -> Result<<Self::Exec<'_> as ComputeLayerExecutor<F>>::ExprEval, Error>;

	/// Executes an operation.
	///
	/// A HAL operation is an abstract function that runs with an executor reference.
	///
	/// The closure builds the operation sequence using the executor and returns its deferred
	/// [`ComputeLayerExecutor::OpValue`]s; `execute` resolves them into concrete field elements.
	fn execute<'a, 'b>(
		&'b self,
		f: impl FnOnce(
			&mut Self::Exec<'a>,
		) -> Result<Vec<<Self::Exec<'a> as ComputeLayerExecutor<F>>::OpValue>, Error>,
	) -> Result<Vec<F>, Error>
	where
		'b: 'a;

	/// Fills a mutable slice of field elements with a given value.
	///
	/// This operation takes a mutable slice (`FSliceMut<F>`) and a field element `value`,
	/// and sets each element in the slice to the given value.
	///
	/// ### Arguments
	///
	/// * `slice` - A mutable slice of field elements to be filled.
	/// * `value` - The field element used to fill each position in the slice.
	fn fill(
		&self,
		slice: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
		value: F,
	) -> Result<(), Error>;
}
89
/// An interface for executing a sequence of operations on an accelerated compute device
///
/// This component defines a sequence of accelerated data transformations that must appear to
/// execute in order on the selected compute device. Implementations may defer execution of any
/// component of the defined sequence under the condition that the store-to-load ordering of the
/// appended transformations is preserved.
///
/// The root [`ComputeLayerExecutor`] is obtained from [`ComputeLayer::execute`]. Nested instances
/// for parallel and sequential blocks can be obtained via [`ComputeLayerExecutor::join`] and
/// [`ComputeLayerExecutor::map`] respectively.
pub trait ComputeLayerExecutor<F: Field> {
	/// The evaluator for arithmetic expressions (polynomials).
	type ExprEval: Sync;

	/// The device memory.
	type DevMem: ComputeMemory<F>;

	/// The operation (scalar) value type.
	type OpValue: Send;

	/// The executor that can execute operations on a kernel-level granularity (i.e., a single
	/// core).
	type KernelExec: KernelExecutor<F, ExprEval = Self::ExprEval>;

	/// Creates an operation that depends on the concurrent execution of two inner operations.
	fn join<Out1: Send, Out2: Send>(
		&mut self,
		op1: impl Send + FnOnce(&mut Self) -> Result<Out1, Error>,
		op2: impl Send + FnOnce(&mut Self) -> Result<Out2, Error>,
	) -> Result<(Out1, Out2), Error> {
		// Default implementation runs the operations sequentially; implementations may
		// override this to execute them concurrently.
		let out1 = op1(self)?;
		let out2 = op2(self)?;
		Ok((out1, out2))
	}

	/// Creates an operation that depends on the concurrent execution of a sequence of operations.
	fn map<Out: Send, I: ExactSizeIterator<Item: Send> + Send>(
		&mut self,
		iter: I,
		map: impl Sync + Fn(&mut Self, I::Item) -> Result<Out, Error>,
	) -> Result<Vec<Out>, Error> {
		// Default implementation processes items sequentially, short-circuiting on the first
		// error; implementations may override this to execute them concurrently.
		iter.map(|item| map(self, item)).collect()
	}

	/// Launch many kernels in parallel and accumulate the scalar results with field addition.
	///
	/// This method provides low-level access to schedule parallel kernel executions on the compute
	/// platform. A _kernel_ is a program that executes synchronously in one thread, with access to
	/// local memory buffers. When the environment launches a kernel, it sets up the kernel's local
	/// memory according to the memory mapping specifications provided by the `mem_maps` parameter.
	/// The mapped buffers have type [`KernelBuffer`], and they may be read-write or read-only.
	/// When the kernel exits, it returns a small number of computed values as field elements. The
	/// vector of returned scalars is accumulated via binary field addition across all kernels and
	/// returned from the call.
	///
	/// This method is fairly general but also designed to fit the specific needs of the sumcheck
	/// protocol. That motivates the choice that returned values are small-length vectors that are
	/// accumulated with field addition.
	///
	/// ## Buffer chunking
	///
	/// The kernel local memory buffers are thought of as slices of a larger buffer, which may or
	/// may not exist somewhere else. Each kernel operates on a chunk of the larger buffers. For
	/// example, the [`KernelMemMap::Chunked`] mapping specifies that each kernel operates on a
	/// read-only chunk of a buffer in global device memory. The [`KernelMemMap::Local`] mapping
	/// specifies that a kernel operates on a local scratchpad initialized with zeros and discarded
	/// at the end of kernel execution (sort of like /dev/null).
	///
	/// This [`ComputeLayer`] object can decide how many kernels to launch and thus how large
	/// each kernel's buffer chunks are. The number of chunks must be a power of two. This
	/// information is provided to the kernel specification closure as an argument.
	///
	/// ## Kernel specification
	///
	/// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
	/// has three parameters:
	///
	/// * `kernel_exec` - the kernel execution environment.
	/// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
	/// * `buffers` - a vector of kernel-local buffers.
	///
	/// The closure must respect certain assumptions:
	///
	/// * The kernel closure control flow is identical on each invocation when `log_chunks` is
	///   unchanged.
	///
	/// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
	/// for example with different values for `log_chunks`.
	///
	/// ## Arguments
	///
	/// * `map` - the kernel specification closure. See the "Kernel specification" section above.
	/// * `mem_maps` - the memory mappings for the kernel-local buffers.
	fn accumulate_kernels(
		&mut self,
		map: impl Sync
		+ for<'a> Fn(
			&'a mut Self::KernelExec,
			usize,
			Vec<KernelBuffer<'a, F, <Self::KernelExec as KernelExecutor<F>>::Mem>>,
		) -> Result<Vec<<Self::KernelExec as KernelExecutor<F>>::Value>, Error>,
		mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
	) -> Result<Vec<Self::OpValue>, Error>;

	/// Launch many kernels in parallel to process buffers without accumulating results.
	///
	/// Similar to [`Self::accumulate_kernels`], this method provides low-level access to schedule
	/// parallel kernel executions on the compute platform. The key difference is that this method
	/// is focused on performing parallel operations on buffers without a reduction phase.
	/// Each kernel operates on its assigned chunk of data and writes its results directly to
	/// the mutable buffers provided in the memory mappings.
	///
	/// This method is suitable for operations where you need to transform data in parallel
	/// without aggregating results, such as element-wise transformations of large arrays.
	///
	/// ## Buffer chunking
	///
	/// The kernel local memory buffers follow the same chunking approach as
	/// [`Self::accumulate_kernels`]. Each kernel operates on a chunk of the larger buffers as
	/// specified by the memory mappings.
	///
	/// ## Kernel specification
	///
	/// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
	/// has three parameters:
	///
	/// * `kernel_exec` - the kernel execution environment.
	/// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
	/// * `buffers` - a vector of kernel-local buffers.
	///
	/// Unlike [`Self::accumulate_kernels`], this method does not expect the kernel to return any
	/// values for accumulation. Instead, the kernel should write its results directly to the
	/// mutable buffers provided in the `buffers` parameter.
	///
	/// The closure must respect certain assumptions:
	///
	/// * The kernel closure control flow is identical on each invocation when `log_chunks` is
	///   unchanged.
	///
	/// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
	/// for example with different values for `log_chunks`.
	///
	/// ## Arguments
	///
	/// * `map` - the kernel specification closure. See the "Kernel specification" section above.
	/// * `mem_maps` - the memory mappings for the kernel-local buffers.
	fn map_kernels(
		&mut self,
		map: impl Sync
		+ for<'a> Fn(
			&'a mut Self::KernelExec,
			usize,
			Vec<KernelBuffer<'a, F, <Self::KernelExec as KernelExecutor<F>>::Mem>>,
		) -> Result<(), Error>,
		mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
	) -> Result<(), Error>;

	/// Returns the inner product of a vector of subfield elements with big field elements.
	///
	/// ## Arguments
	///
	/// * `a_in` - the first input slice of subfield elements.
	/// * `b_in` - the second input slice of `F` elements.
	///
	/// ## Throws
	///
	/// * if the tower level of `a_in` is greater than `F::TOWER_LEVEL`
	/// * unless `a_in` and `b_in` contain the same number of elements, and the number is a power of
	///   two
	///
	/// ## Returns
	///
	/// Returns the inner product of `a_in` and `b_in`.
	fn inner_product(
		&mut self,
		a_in: SubfieldSlice<'_, F, Self::DevMem>,
		b_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
	) -> Result<Self::OpValue, Error>;

	/// Computes the iterative tensor product of the input with the given coordinates.
	///
	/// This operation modifies the data buffer in place.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts parameters
	///
	/// * $n \in \mathbb{N}$ (`log_n`),
	/// * $k \in \mathbb{N}$ (`coordinates.len()`),
	/// * $v \in L^{2^n}$ (`data[..1 << log_n]`),
	/// * $r \in L^k$ (`coordinates`),
	///
	/// and computes the vector
	///
	/// $$
	/// v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})
	/// $$
	///
	/// ## Throws
	///
	/// * unless `2**(log_n + coordinates.len())` equals `data.len()`
	fn tensor_expand(
		&mut self,
		log_n: usize,
		coordinates: &[F],
		data: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// Computes left matrix-vector multiplication of a subfield matrix with a big field vector.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts
	///
	/// * $n \in \mathbb{N}$ (`out.len()`),
	/// * $m \in \mathbb{N}$ (`vec.len()`),
	/// * $M \in K^{n \times m}$ (`mat`),
	/// * $v \in K^m$ (`vec`),
	///
	/// and computes the vector $Mv$.
	///
	/// ## Args
	///
	/// * `mat` - a slice of elements from a subfield of `F`.
	/// * `vec` - a slice of `F` elements.
	/// * `out` - a buffer for the output vector of `F` elements.
	///
	/// ## Throws
	///
	/// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
	/// * Returns an error if `mat` is not a subfield of `F`.
	fn fold_left(
		&mut self,
		mat: SubfieldSlice<'_, F, Self::DevMem>,
		vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// Computes right matrix-vector multiplication of a subfield matrix with a big field vector.
	///
	/// ## Mathematical Definition
	///
	/// This operation accepts
	///
	/// * $n \in \mathbb{N}$ (`vec.len()`),
	/// * $m \in \mathbb{N}$ (`out.len()`),
	/// * $M \in K^{n \times m}$ (`mat`),
	/// * $v \in K^m$ (`vec`),
	///
	/// and computes the vector $((v')M)'$. The prime denotes a transpose.
	///
	/// ## Args
	///
	/// * `mat` - a slice of elements from a subfield of `F`.
	/// * `vec` - a slice of `F` elements.
	/// * `out` - a buffer for the output vector of `F` elements.
	///
	/// ## Throws
	///
	/// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
	/// * Returns an error if `mat` is not a subfield of `F`.
	fn fold_right(
		&mut self,
		mat: SubfieldSlice<'_, F, Self::DevMem>,
		vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// FRI-fold the interleaved codeword using the given challenges.
	///
	/// The FRI-fold operation folds a length $2^{n+b+\eta}$ vector of field elements into a length
	/// $2^n$ vector of field elements. $n$ is the log block length of the code, $b$ is the log
	/// batch size, and $b + \eta$ is the number of challenge elements. The operation has the
	/// following mathematical structure:
	///
	/// 1. Split the challenge vector into two parts: $c_0$ with length $b$ and $c_1$ with length
	///    $\eta$.
	/// 2. Low fold the input data with the tensor expansion of $c_0$.
	/// 3. Apply $\eta$ layers of the inverse additive NTT to the data.
	/// 4. Low fold the input data with the tensor expansion of $c_1$.
	///
	/// The algorithm to perform steps 3 and 4 can be combined into a linear amount of work,
	/// whereas step 3 on its own would require $\eta$ independent linear passes.
	///
	/// See [DP24], Section 4.2 for more details.
	///
	/// This operation writes the result out-of-place into an output buffer.
	///
	/// ## Arguments
	///
	/// * `ntt` - the NTT instance, used to look up the twiddle values.
	/// * `log_len` - $n + \eta$, the binary logarithm of the code length.
	/// * `log_batch_size` - $b$, the binary logarithm of the interleaved code batch size.
	/// * `challenges` - the folding challenges, with length $b + \eta$.
	/// * `data_in` - an input vector, with length $2^{n + b + \eta}$.
	/// * `data_out` - an output buffer, with length $2^n$.
	///
	/// [DP24]: <https://eprint.iacr.org/2024/504>
	#[allow(clippy::too_many_arguments)]
	fn fri_fold<FSub>(
		&mut self,
		ntt: &(impl AdditiveNTT<FSub> + Sync),
		log_len: usize,
		log_batch_size: usize,
		challenges: &[F],
		data_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		data_out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>
	where
		FSub: BinaryField,
		F: ExtensionField<FSub>;

	/// Extrapolates a line between a vector of evaluations at 0 and evaluations at 1.
	///
	/// Given two values $y_0, y_1$, this operation computes the value $y_z = y_0 + (y_1 - y_0) z$,
	/// which is the value of the line that interpolates $(0, y_0), (1, y_1)$ at $z$. This computes
	/// this operation in parallel over two vectors of big field elements of equal sizes.
	///
	/// The operation writes the result back in-place into the `evals_0` buffer.
	///
	/// ## Args
	///
	/// * `evals_0` - this is both an input and output buffer. As in input, it is populated with the
	///   values $y_0$, which are the line's values at 0.
	/// * `evals_1` - an input buffer with the values $y_1$, which are the line's values at 1.
	/// * `z` - the scalar evaluation point.
	///
	/// ## Throws
	///
	/// * if `evals_0` and `evals_1` are not equal sizes.
	/// * if the sizes of `evals_0` and `evals_1` are not powers of two.
	fn extrapolate_line(
		&mut self,
		evals_0: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
		evals_1: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		z: F,
	) -> Result<(), Error>;

	/// Computes the elementwise application of a compiled arithmetic expression to multiple input
	/// slices.
	///
	/// This operation applies the composition expression to each row of input values, where a row
	/// consists of one element from each input slice at the same index position. The results are
	/// stored in the output slice.
	///
	/// ## Mathematical Definition
	///
	/// Given:
	/// - Multiple input slices $P_0, \ldots, P_{m-1}$, each of length $2^n$ elements
	/// - A composition function $C(X_0, \ldots, X_{m-1})$
	/// - An output slice $P_{\text{out}}$ of length $2^n$ elements
	///
	/// This operation computes:
	///
	/// $$
	/// P_{\text{out}}\[i\] = C(P_0\[i\], \ldots, P_{m-1}\[i\])
	/// \quad \forall i \in \{0, \ldots, 2^n- 1\}
	/// $$
	///
	/// ## Arguments
	///
	/// * `inputs` - A slice of input slices, where each slice contains field elements.
	/// * `output` - A mutable output slice where the results will be stored.
	/// * `composition` - The compiled arithmetic expression to apply.
	///
	/// ## Throws
	///
	/// * Returns an error if any input or output slice has a length that is not a power of two.
	/// * Returns an error if the input and output slices do not all have the same length.
	fn compute_composite(
		&mut self,
		inputs: &SlicesBatch<<Self::DevMem as ComputeMemory<F>>::FSlice<'_>>,
		output: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
		composition: &Self::ExprEval,
	) -> Result<(), Error>;

	/// Reduces a slice of elements to a single value by recursively applying pairwise
	/// multiplication.
	///
	/// Given an input slice `x` of length `n = 2^k` for some integer `k`,
	/// this function computes the result:
	///
	/// $$
	/// y = \prod_{i=0}^{n-1} x_i
	/// $$
	///
	/// However, instead of a flat left-to-right reduction, the computation proceeds
	/// in ⌈log₂(n)⌉ rounds, halving the number of elements each time:
	///
	/// - Round 0: $$ x_{i,0} = x_{2i} \cdot x_{2i+1} \quad \text{for } i = 0 \ldots \frac{n}{2} - 1
	///   $$
	///
	/// - Round 1: $$ x_{i,1} = x_{2i,0} \cdot x_{2i+1,0} $$
	///
	/// - ...
	///
	/// - Final round: $$ y = x_{0,k-1} = \prod_{i=0}^{n-1} x_i $$
	///
	/// This binary tree-style reduction is mathematically equivalent to the full product,
	/// but structured for efficient parallelization.
	///
	/// ## Arguments
	///
	/// * `input` - A slice of input field elements provided to the first reduction round
	/// * `round_outputs` - A mutable slice of preallocated output field elements for each reduction
	///   round. `round_outputs.len()` must equal log₂(input.len()). The length of the FSlice at
	///   index i must equal input.len() / 2**(i + 1) for i in 0..round_outputs.len().
	///
	/// ## Throws
	///
	/// * Returns an error if the length of `input` is not a power of 2.
	/// * Returns an error if the length of `input` is less than 2 (no reductions are possible).
	/// * Returns an error if `round_outputs.len()` != log₂(input.len())
	/// * Returns an error if any element in `round_outputs` does not satisfy
	///   `round_outputs[i].len() == input.len() / 2**(i + 1)` for i in 0..round_outputs.len()
	fn pairwise_product_reduce(
		&mut self,
		input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
		round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
	) -> Result<(), Error>;
}
511
/// An interface for defining execution kernels.
///
/// A _kernel_ is a program that executes synchronously in one thread, with access to
/// local memory buffers.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for more information.
pub trait KernelExecutor<F> {
	/// The type for kernel-local memory buffers.
	type Mem: ComputeMemory<F>;

	/// The kernel(core)-level operation (scalar) type. This is a promise for a returned value.
	type Value;

	/// The evaluator for arithmetic expressions (polynomials).
	type ExprEval: Sync;

	/// Declares a kernel-level value.
	fn decl_value(&mut self, init: F) -> Result<Self::Value, Error>;

	/// A kernel-local operation that evaluates a composition polynomial over several buffers,
	/// row-wise, and returns the sum of the evaluations, scaled by a batching coefficient.
	///
	/// Mathematically, let there be $m$ input buffers, $P_0, \ldots, P_{m-1}$, each of length
	/// $2^n$ elements. Let $c$ be the scaling coefficient (`batch_coeff`) and
	/// $C(X_0, \ldots, X_{m-1})$ be the composition polynomial. The operation computes
	///
	/// $$
	/// \sum_{i=0}^{2^n - 1} c C(P_0\[i\], \ldots, P_{m-1}\[i\]).
	/// $$
	///
	/// The result is added back to an accumulator value.
	///
	/// ## Arguments
	///
	/// * `inputs` - the input buffers. Each row contains the values for a single variable.
	/// * `composition` - the compiled composition polynomial expression. This is an output of
	///   [`ComputeLayer::compile_expr`].
	/// * `batch_coeff` - the scaling coefficient.
	/// * `accumulator` - the output where the result is accumulated to.
	fn sum_composition_evals(
		&mut self,
		inputs: &SlicesBatch<<Self::Mem as ComputeMemory<F>>::FSlice<'_>>,
		composition: &Self::ExprEval,
		batch_coeff: F,
		accumulator: &mut Self::Value,
	) -> Result<(), Error>;

	/// A kernel-local operation that performs point-wise addition of two input buffers into an
	/// output buffer.
	///
	/// ## Arguments
	///
	/// * `log_len` - the binary logarithm of the number of elements in all three buffers.
	/// * `src1` - the first input buffer.
	/// * `src2` - the second input buffer.
	/// * `dst` - the output buffer that receives the element-wise sum.
	fn add(
		&mut self,
		log_len: usize,
		src1: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
		src2: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
		dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;

	/// A kernel-local operation that adds a source buffer into a destination buffer, in place.
	///
	/// ## Arguments
	///
	/// * `log_len` - the binary logarithm of the number of elements in the two buffers.
	/// * `src` - the source buffer.
	/// * `dst` - the destination buffer.
	fn add_assign(
		&mut self,
		log_len: usize,
		src: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
		dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
	) -> Result<(), Error>;
}
591
/// A memory mapping specification for a kernel execution.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for context on kernel execution.
pub enum KernelMemMap<'a, F, Mem: ComputeMemory<F>> {
	/// This maps a chunk of a buffer in global device memory to a read-only kernel buffer.
	Chunked {
		data: Mem::FSlice<'a>,
		/// The binary logarithm of the smallest chunk size each kernel may receive.
		log_min_chunk_size: usize,
	},
	/// This maps a chunk of a mutable buffer in global device memory to a read-write kernel
	/// buffer. When the kernel exits, the data in the kernel buffer is written back to the
	/// original location.
	ChunkedMut {
		data: Mem::FSliceMut<'a>,
		/// The binary logarithm of the smallest chunk size each kernel may receive.
		log_min_chunk_size: usize,
	},
	/// This allocates a kernel-local scratchpad buffer. The size specified in the mapping is the
	/// total size of all kernel scratchpads. This is so that the kernel's local scratchpad size
	/// scales up proportionally to the size of chunked buffers.
	Local { log_size: usize },
}
613
614impl<'a, F, Mem: ComputeMemory<F>> KernelMemMap<'a, F, Mem> {
615	/// Computes a range of possible number of chunks that data can be split into, given a sequence
616	/// of memory mappings.
617	pub fn log_chunks_range(mappings: &[Self]) -> Option<Range<usize>> {
618		mappings
619			.iter()
620			.map(|mapping| match mapping {
621				Self::Chunked {
622					data,
623					log_min_chunk_size,
624				} => {
625					let log_data_size = checked_log_2(data.len());
626					let log_min_chunk_size = (*log_min_chunk_size)
627						.max(checked_log_2(Mem::ALIGNMENT))
628						.min(log_data_size);
629					0..(log_data_size - log_min_chunk_size)
630				}
631				Self::ChunkedMut {
632					data,
633					log_min_chunk_size,
634				} => {
635					let log_data_size = checked_log_2(data.len());
636					let log_min_chunk_size = (*log_min_chunk_size)
637						.max(checked_log_2(Mem::ALIGNMENT))
638						.min(log_data_size);
639					0..(log_data_size - log_min_chunk_size)
640				}
641				Self::Local { log_size } => 0..*log_size,
642			})
643			.reduce(|range0, range1| range0.start.max(range1.start)..range0.end.min(range1.end))
644	}
645
646	// Split the memory mapping into `1 << log_chunks>` chunks.
647	pub fn chunks(self, log_chunks: usize) -> impl Iterator<Item = KernelMemMap<'a, F, Mem>> {
648		match self {
649			Self::Chunked {
650				data,
651				log_min_chunk_size,
652			} => Either::Left(Either::Left(
653				Mem::slice_chunks(data, checked_int_div(data.len(), 1 << log_chunks)).map(
654					move |data| KernelMemMap::Chunked {
655						data,
656						log_min_chunk_size,
657					},
658				),
659			)),
660			Self::ChunkedMut { data, .. } => {
661				let chunks_count = checked_int_div(data.len(), 1 << log_chunks);
662				Either::Left(Either::Right(Mem::slice_chunks_mut(data, chunks_count).map(
663					move |data| KernelMemMap::ChunkedMut {
664						data,
665						log_min_chunk_size: checked_log_2(chunks_count),
666					},
667				)))
668			}
669			Self::Local { log_size } => Either::Right(
670				std::iter::repeat_with(move || KernelMemMap::Local {
671					log_size: log_size - log_chunks,
672				})
673				.take(1 << log_chunks),
674			),
675		}
676	}
677}
678
/// A memory buffer mapped into a kernel.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for context on kernel execution.
pub enum KernelBuffer<'a, F, Mem: ComputeMemory<F>> {
	/// A read-only kernel buffer.
	Ref(Mem::FSlice<'a>),
	/// A read-write kernel buffer.
	Mut(Mem::FSliceMut<'a>),
}
686
impl<'a, F, Mem: ComputeMemory<F>> KernelBuffer<'a, F, Mem> {
	/// Returns underlying data as an `FSlice`.
	///
	/// A mutable buffer is reborrowed immutably via `Mem::as_const`; a read-only buffer is
	/// narrowed to the borrow's lifetime via `Mem::narrow`.
	pub fn to_ref(&self) -> Mem::FSlice<'_> {
		match self {
			Self::Ref(slice) => Mem::narrow(slice),
			Self::Mut(slice) => Mem::as_const(slice),
		}
	}
}
696
697impl<'a, F, Mem: ComputeMemory<F>> SizedSlice for KernelBuffer<'a, F, Mem> {
698	fn len(&self) -> usize {
699		match self {
700			KernelBuffer::Ref(mem) => mem.len(),
701			KernelBuffer::Mut(mem) => mem.len(),
702		}
703	}
704}
705
/// Errors that can occur during compute layer operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
	/// A precondition on an operation's inputs was violated.
	#[error("input validation: {0}")]
	InputValidation(String),
	/// A memory allocation failure, propagated from the allocator.
	#[error("allocation error: {0}")]
	Alloc(#[from] AllocError),
	/// An error reported by the underlying compute device.
	#[error("device error: {0}")]
	DeviceError(Box<dyn std::error::Error + Send + Sync + 'static>),
	/// An error reported by a core library used by the implementation.
	#[error("core library error: {0}")]
	CoreLibError(Box<dyn std::error::Error + Send + Sync + 'static>),
}
717
// Convenience types for the device memory.
/// Immutable device-memory slice type of a [`ComputeLayer`] implementation.
pub type FSlice<'a, F, HAL> = <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSlice<'a>;
/// Mutable device-memory slice type of a [`ComputeLayer`] implementation.
pub type FSliceMut<'a, F, HAL> =
	<<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSliceMut<'a>;

/// Kernel-local memory type of a [`ComputeLayer`] implementation.
pub type KernelMem<'a, F, HAL> = <<<HAL as ComputeLayer<F>>::Exec<'a> as ComputeLayerExecutor<F>>::KernelExec as KernelExecutor<F>>::Mem;
/// Immutable kernel-local slice type of a [`ComputeLayer`] implementation.
pub type KernelSlice<'a, 'b, F, HAL> = <KernelMem<'b, F, HAL> as ComputeMemory<F>>::FSlice<'a>;
/// Mutable kernel-local slice type of a [`ComputeLayer`] implementation.
pub type KernelSliceMut<'a, 'b, F, HAL> =
	<KernelMem<'b, F, HAL> as ComputeMemory<F>>::FSliceMut<'a>;
727
/// This is a trait for a holder type for the popular triple:
/// * a compute layer (HAL),
/// * a host memory allocator,
/// * a device memory allocator.
pub trait ComputeHolder<F: Field, HAL: ComputeLayer<F>> {
	/// The allocator for host-side memory.
	type HostComputeAllocator<'a>: ComputeAllocator<F, CpuMemory>
	where
		Self: 'a;
	/// The allocator for device-side memory.
	type DeviceComputeAllocator<'a>: ComputeAllocator<F, HAL::DevMem>
	where
		Self: 'a;

	/// Borrows the holder's HAL and allocators as a [`ComputeData`] bundle.
	fn to_data<'a, 'b>(
		&'a mut self,
	) -> ComputeData<'a, F, HAL, Self::HostComputeAllocator<'b>, Self::DeviceComputeAllocator<'b>>
	where
		'a: 'b;
}
746
/// The triple held by a [`ComputeHolder`]: a HAL reference plus host and device allocators.
pub struct ComputeData<'a, F: Field, HAL: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>
where
	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
	DeviceAllocatorType: ComputeAllocator<F, HAL::DevMem>,
{
	/// The compute layer (HAL).
	pub hal: &'a HAL,
	/// The host memory allocator.
	pub host_alloc: HostAllocatorType,
	/// The device memory allocator.
	pub dev_alloc: DeviceAllocatorType,
	// Marks the field element type `F` as logically owned by this struct.
	_phantom_data: PhantomData<F>,
}
757
758impl<'a, F: Field, HAL: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>
759	ComputeData<'a, F, HAL, HostAllocatorType, DeviceAllocatorType>
760where
761	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
762	DeviceAllocatorType: ComputeAllocator<F, HAL::DevMem>,
763{
764	pub fn new(
765		hal: &'a HAL,
766		host_alloc: HostAllocatorType,
767		dev_alloc: DeviceAllocatorType,
768	) -> Self {
769		Self {
770			hal,
771			host_alloc,
772			dev_alloc,
773			_phantom_data: PhantomData::<F>,
774		}
775	}
776}
777
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, Field, TowerField};
	use binius_math::B128;
	use rand::{SeedableRng, prelude::StdRng};

	use super::*;
	use crate::{
		alloc::ComputeAllocator,
		cpu::{CpuMemory, layer::CpuLayerHolder},
	};

	/// Test showing how to allocate host memory and create a sub-allocator over it.
	// TODO: This 'a lifetime bound on HAL is pretty annoying. I'd like to get rid of it.
	fn test_copy_host_device<'a, F: TowerField, HAL: ComputeLayer<F> + 'a>(
		mut compute_holder: impl ComputeHolder<F, HAL>,
	) {
		let ComputeData {
			hal,
			host_alloc,
			dev_alloc,
			_phantom_data,
		} = compute_holder.to_data();

		let mut rng = StdRng::seed_from_u64(0);

		// Allocate two host buffers and two device buffers of equal length.
		let host_buf_1 = host_alloc.alloc(128).unwrap();
		let host_buf_2 = host_alloc.alloc(128).unwrap();
		let mut dev_buf_1 = dev_alloc.alloc(128).unwrap();
		let mut dev_buf_2 = dev_alloc.alloc(128).unwrap();

		// Fill the first host buffer with deterministic pseudo-random elements.
		for elem in host_buf_1.iter_mut() {
			*elem = F::random(&mut rng);
		}

		// Round-trip: host -> device, device -> device, device -> host; then compare.
		hal.copy_h2d(host_buf_1, &mut dev_buf_1).unwrap();
		hal.copy_d2d(HAL::DevMem::as_const(&dev_buf_1), &mut dev_buf_2)
			.unwrap();
		hal.copy_d2h(HAL::DevMem::as_const(&dev_buf_2), host_buf_2)
			.unwrap();

		assert_eq!(host_buf_1, host_buf_2);
	}

	#[test]
	fn test_cpu_copy_host_device() {
		test_copy_host_device(CpuLayerHolder::<B128>::new(512, 256));
	}

	#[test]
	fn test_log_chunks_range() {
		let mem_1 = vec![BinaryField128b::ZERO; 256];
		let mut mem_2 = vec![BinaryField128b::ZERO; 256];

		let mappings = vec![
			KernelMemMap::Chunked {
				data: mem_1.as_slice(),
				log_min_chunk_size: 4,
			},
			KernelMemMap::ChunkedMut {
				data: mem_2.as_mut_slice(),
				log_min_chunk_size: 6,
			},
			KernelMemMap::Local { log_size: 8 },
		];

		// log2(256) = 8; the mutable mapping's min chunk size of 6 gives the tightest
		// bound, 8 - 6 = 2, so the valid range of log_chunks is 0..2.
		let range =
			KernelMemMap::<BinaryField128b, CpuMemory>::log_chunks_range(&mappings).unwrap();
		assert_eq!(range.start, 0);
		assert_eq!(range.end, 2);
	}
}