// binius_compute/layer.rs

// Copyright 2025 Irreducible Inc.

use std::{marker::PhantomData, ops::Range};

use binius_field::{BinaryField, ExtensionField, Field};
use binius_math::ArithCircuit;
use binius_ntt::AdditiveNTT;
use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
use itertools::Either;

use super::{
    alloc::Error as AllocError,
    memory::{ComputeMemory, SubfieldSlice},
};
use crate::{
    alloc::ComputeAllocator,
    cpu::CpuMemory,
    memory::{SizedSlice, SlicesBatch},
};

/// A hardware abstraction layer (HAL) for compute operations.
///
/// Implementations provide device memory, host/device data transfer, expression compilation, and
/// an executor for running sequences of device operations.
pub trait ComputeLayer<F: Field> {
    /// The device memory.
    type DevMem: ComputeMemory<F>;

    /// The executor that can execute operations on the device.
    type Exec<'a>: ComputeLayerExecutor<F, DevMem = Self::DevMem>
    where
        Self: 'a;

    /// Copy data from the host to the device.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    fn copy_h2d(&self, src: &[F], dst: &mut FSliceMut<'_, F, Self>) -> Result<(), Error>;

    /// Copy data from the device to the host.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    fn copy_d2h(&self, src: FSlice<'_, F, Self>, dst: &mut [F]) -> Result<(), Error>;

    /// Copy data between disjoint device buffers.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    fn copy_d2d(
        &self,
        src: FSlice<'_, F, Self>,
        dst: &mut FSliceMut<'_, F, Self>,
    ) -> Result<(), Error>;

    /// Compiles an arithmetic expression to the evaluator.
    ///
    /// The returned evaluator can be passed to kernel operations such as
    /// [`KernelExecutor::sum_composition_evals`] and
    /// [`ComputeLayerExecutor::compute_composite`].
    fn compile_expr(
        &self,
        expr: &ArithCircuit<F>,
    ) -> Result<<Self::Exec<'_> as ComputeLayerExecutor<F>>::ExprEval, Error>;

    /// Executes an operation.
    ///
    /// A HAL operation is an abstract function that runs with an executor reference. The returned
    /// vector contains the resolved scalar values of the operation's outputs.
    fn execute<'a, 'b>(
        &'b self,
        f: impl FnOnce(
            &mut Self::Exec<'a>,
        ) -> Result<Vec<<Self::Exec<'a> as ComputeLayerExecutor<F>>::OpValue>, Error>,
    ) -> Result<Vec<F>, Error>
    where
        'b: 'a;

    /// Fills a mutable slice of field elements with a given value.
    ///
    /// This operation takes a mutable slice (`FSliceMut<F>`) and a field element `value`,
    /// and sets each element in the slice to the given value.
    ///
    /// ### Arguments
    ///
    /// * `slice` - A mutable slice of field elements to be filled.
    /// * `value` - The field element used to fill each position in the slice.
    fn fill(
        &self,
        slice: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
        value: F,
    ) -> Result<(), Error>;
}

/// An interface for executing a sequence of operations on an accelerated compute device
///
/// This component defines a sequence of accelerated data transformations that must appear to
/// execute in order on the selected compute device. Implementations may defer execution of any
/// component of the defined sequence under the condition that the store-to-load ordering of the
/// appended transformations is preserved.
///
/// The root [`ComputeLayerExecutor`] is obtained from [`ComputeLayer::execute`]. Nested instances
/// for parallel and sequential blocks can be obtained via [`ComputeLayerExecutor::join`] and
/// [`ComputeLayerExecutor::map`] respectively.
pub trait ComputeLayerExecutor<F: Field> {
    /// The evaluator for arithmetic expressions (polynomials).
    type ExprEval: Sync;

    /// The device memory.
    type DevMem: ComputeMemory<F>;

    /// The operation (scalar) value type.
    type OpValue: Send;

    /// The executor that can execute operations on a kernel-level granularity (i.e., a single
    /// core).
    type KernelExec: KernelExecutor<F, ExprEval = Self::ExprEval>;

    /// Creates an operation that depends on the concurrent execution of two inner operations.
    ///
    /// The default implementation runs the two operations sequentially; implementations may
    /// execute them concurrently.
    fn join<Out1: Send, Out2: Send>(
        &mut self,
        op1: impl Send + FnOnce(&mut Self) -> Result<Out1, Error>,
        op2: impl Send + FnOnce(&mut Self) -> Result<Out2, Error>,
    ) -> Result<(Out1, Out2), Error> {
        let out1 = op1(self)?;
        let out2 = op2(self)?;
        Ok((out1, out2))
    }

    /// Creates an operation that depends on the concurrent execution of a sequence of operations.
    ///
    /// The default implementation maps the items sequentially, short-circuiting on the first
    /// error; implementations may execute the operations concurrently.
    fn map<Out: Send, I: ExactSizeIterator<Item: Send> + Send>(
        &mut self,
        iter: I,
        map: impl Sync + Fn(&mut Self, I::Item) -> Result<Out, Error>,
    ) -> Result<Vec<Out>, Error> {
        iter.map(|item| map(self, item)).collect()
    }

    /// Launch many kernels in parallel and accumulate the scalar results with field addition.
    ///
    /// This method provides low-level access to schedule parallel kernel executions on the compute
    /// platform. A _kernel_ is a program that executes synchronously in one thread, with access to
    /// local memory buffers. When the environment launches a kernel, it sets up the kernel's local
    /// memory according to the memory mapping specifications provided by the `mem_maps` parameter.
    /// The mapped buffers have type [`KernelBuffer`], and they may be read-write or read-only.
    /// When the kernel exits, it returns a small number of computed values as field elements. The
    /// vector of returned scalars is accumulated via binary field addition across all kernels and
    /// returned from the call.
    ///
    /// This method is fairly general but also designed to fit the specific needs of the sumcheck
    /// protocol. That motivates the choice that returned values are small-length vectors that are
    /// accumulated with field addition.
    ///
    /// ## Buffer chunking
    ///
    /// The kernel local memory buffers are thought of as slices of a larger buffer, which may or
    /// may not exist somewhere else. Each kernel operates on a chunk of the larger buffers. For
    /// example, the [`KernelMemMap::Chunked`] mapping specifies that each kernel operates on a
    /// read-only chunk of a buffer in global device memory. The [`KernelMemMap::Local`] mapping
    /// specifies that a kernel operates on a local scratchpad initialized with zeros and discarded
    /// at the end of kernel execution (sort of like /dev/null).
    ///
    /// This [`ComputeLayer`] object can decide how many kernels to launch and thus how large
    /// each kernel's buffer chunks are. The number of chunks must be a power of two. This
    /// information is provided to the kernel specification closure as an argument.
    ///
    /// ## Kernel specification
    ///
    /// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
    /// has three parameters:
    ///
    /// * `kernel_exec` - the kernel execution environment.
    /// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
    /// * `buffers` - a vector of kernel-local buffers.
    ///
    /// The closure must respect certain assumptions:
    ///
    /// * The kernel closure control flow is identical on each invocation when `log_chunks` is
    ///   unchanged.
    ///
    /// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
    /// for example with different values for `log_chunks`.
    ///
    /// ## Arguments
    ///
    /// * `map` - the kernel specification closure. See the "Kernel specification" section above.
    /// * `mem_maps` - the memory mappings for the kernel-local buffers.
    fn accumulate_kernels(
        &mut self,
        map: impl Sync
            + for<'a> Fn(
                &'a mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'a, F, <Self::KernelExec as KernelExecutor<F>>::Mem>>,
            ) -> Result<Vec<<Self::KernelExec as KernelExecutor<F>>::Value>, Error>,
        mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
    ) -> Result<Vec<Self::OpValue>, Error>;

    /// Launch many kernels in parallel to process buffers without accumulating results.
    ///
    /// Similar to [`Self::accumulate_kernels`], this method provides low-level access to schedule
    /// parallel kernel executions on the compute platform. The key difference is that this method
    /// is focused on performing parallel operations on buffers without a reduction phase.
    /// Each kernel operates on its assigned chunk of data and writes its results directly to
    /// the mutable buffers provided in the memory mappings.
    ///
    /// This method is suitable for operations where you need to transform data in parallel
    /// without aggregating results, such as element-wise transformations of large arrays.
    ///
    /// ## Buffer chunking
    ///
    /// The kernel local memory buffers follow the same chunking approach as
    /// [`Self::accumulate_kernels`]. Each kernel operates on a chunk of the larger buffers as
    /// specified by the memory mappings.
    ///
    /// ## Kernel specification
    ///
    /// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
    /// has three parameters:
    ///
    /// * `kernel_exec` - the kernel execution environment.
    /// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
    /// * `buffers` - a vector of kernel-local buffers.
    ///
    /// Unlike [`Self::accumulate_kernels`], this method does not expect the kernel to return any
    /// values for accumulation. Instead, the kernel should write its results directly to the
    /// mutable buffers provided in the `buffers` parameter.
    ///
    /// The closure must respect certain assumptions:
    ///
    /// * The kernel closure control flow is identical on each invocation when `log_chunks` is
    ///   unchanged.
    ///
    /// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
    /// for example with different values for `log_chunks`.
    ///
    /// ## Arguments
    ///
    /// * `map` - the kernel specification closure. See the "Kernel specification" section above.
    /// * `mem_maps` - the memory mappings for the kernel-local buffers.
    fn map_kernels(
        &mut self,
        map: impl Sync
            + for<'a> Fn(
                &'a mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'a, F, <Self::KernelExec as KernelExecutor<F>>::Mem>>,
            ) -> Result<(), Error>,
        mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
    ) -> Result<(), Error>;

    /// Returns the inner product of a vector of subfield elements with big field elements.
    ///
    /// ## Arguments
    ///
    /// * `a_in` - the first input slice of subfield elements.
    /// * `b_in` - the second input slice of `F` elements.
    ///
    /// ## Throws
    ///
    /// * if the tower level of `a_in` is greater than `F::TOWER_LEVEL`
    /// * unless `a_in` and `b_in` contain the same number of elements, and the number is a power of
    ///   two
    ///
    /// ## Returns
    ///
    /// Returns the inner product of `a_in` and `b_in`.
    fn inner_product(
        &mut self,
        a_in: SubfieldSlice<'_, F, Self::DevMem>,
        b_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
    ) -> Result<Self::OpValue, Error>;

    /// Computes the iterative tensor product of the input with the given coordinates.
    ///
    /// This operation modifies the data buffer in place.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts parameters
    ///
    /// * $n \in \mathbb{N}$ (`log_n`),
    /// * $k \in \mathbb{N}$ (`coordinates.len()`),
    /// * $v \in L^{2^n}$ (`data[..1 << log_n]`),
    /// * $r \in L^k$ (`coordinates`),
    ///
    /// and computes the vector
    ///
    /// $$
    /// v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})
    /// $$
    ///
    /// ## Throws
    ///
    /// * unless `2**(log_n + coordinates.len())` equals `data.len()`
    fn tensor_expand(
        &mut self,
        log_n: usize,
        coordinates: &[F],
        data: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// Computes left matrix-vector multiplication of a subfield matrix with a big field vector.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts
    ///
    /// * $n \in \mathbb{N}$ (`out.len()`),
    /// * $m \in \mathbb{N}$ (`vec.len()`),
    /// * $M \in K^{n \times m}$ (`mat`),
    /// * $v \in K^m$ (`vec`),
    ///
    /// and computes the vector $Mv$.
    ///
    /// ## Args
    ///
    /// * `mat` - a slice of elements from a subfield of `F`.
    /// * `vec` - a slice of `F` elements.
    /// * `out` - a buffer for the output vector of `F` elements.
    ///
    /// ## Throws
    ///
    /// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
    /// * Returns an error if `mat` is not a subfield of `F`.
    fn fold_left(
        &mut self,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// Computes right matrix-vector multiplication of a subfield matrix with a big field vector.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts
    ///
    /// * $n \in \mathbb{N}$ (`vec.len()`),
    /// * $m \in \mathbb{N}$ (`out.len()`),
    /// * $M \in K^{n \times m}$ (`mat`),
    /// * $v \in K^m$ (`vec`),
    ///
    /// and computes the vector $((v')M)'$. The prime denotes a transpose
    ///
    /// ## Args
    ///
    /// * `mat` - a slice of elements from a subfield of `F`.
    /// * `vec` - a slice of `F` elements.
    /// * `out` - a buffer for the output vector of `F` elements.
    ///
    /// ## Throws
    ///
    /// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
    /// * Returns an error if `mat` is not a subfield of `F`.
    fn fold_right(
        &mut self,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// FRI-fold the interleaved codeword using the given challenges.
    ///
    /// The FRI-fold operation folds a length $2^{n+b+\eta}$ vector of field elements into a length
    /// $2^n$ vector of field elements. $n$ is the log block length of the code, $b$ is the log
    /// batch size, and $b + \eta$ is the number of challenge elements. The operation has the
    /// following mathematical structure:
    ///
    /// 1. Split the challenge vector into two parts: $c_0$ with length $b$ and $c_1$ with length
    ///    $\eta$.
    /// 2. Low fold the input data with the tensor expansion of $c_0$.
    /// 3. Apply $\eta$ layers of the inverse additive NTT to the data.
    /// 4. Low fold the input data with the tensor expansion of $c_1$.
    ///
    /// The algorithm to perform steps 3 and 4 can be combined into a linear amount of work,
    /// whereas step 3 on its own would require $\eta$ independent linear passes.
    ///
    /// See [DP24], Section 4.2 for more details.
    ///
    /// This operation writes the result out-of-place into an output buffer.
    ///
    /// ## Arguments
    ///
    /// * `ntt` - the NTT instance, used to look up the twiddle values.
    /// * `log_len` - $n + \eta$, the binary logarithm of the code length.
    /// * `log_batch_size` - $b$, the binary logarithm of the interleaved code batch size.
    /// * `challenges` - the folding challenges, with length $b + \eta$.
    /// * `data_in` - an input vector, with length $2^{n + b + \eta}$.
    /// * `data_out` - an output buffer, with length $2^n$.
    ///
    /// [DP24]: <https://eprint.iacr.org/2024/504>
    #[allow(clippy::too_many_arguments)]
    fn fri_fold<FSub>(
        &mut self,
        ntt: &(impl AdditiveNTT<FSub> + Sync),
        log_len: usize,
        log_batch_size: usize,
        challenges: &[F],
        data_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        data_out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>
    where
        FSub: BinaryField,
        F: ExtensionField<FSub>;

    /// Extrapolates a line between a vector of evaluations at 0 and evaluations at 1.
    ///
    /// Given two values $y_0, y_1$, this operation computes the value $y_z = y_0 + (y_1 - y_0) z$,
    /// which is the value of the line that interpolates $(0, y_0), (1, y_1)$ at $z$. This computes
    /// this operation in parallel over two vectors of big field elements of equal sizes.
    ///
    /// The operation writes the result back in-place into the `evals_0` buffer.
    ///
    /// ## Args
    ///
    /// * `evals_0` - this is both an input and output buffer. As in input, it is populated with the
    ///   values $y_0$, which are the line's values at 0.
    /// * `evals_1` - an input buffer with the values $y_1$, which are the line's values at 1.
    /// * `z` - the scalar evaluation point.
    ///
    /// ## Throws
    ///
    /// * if `evals_0` and `evals_1` are not equal sizes.
    /// * if the sizes of `evals_0` and `evals_1` are not powers of two.
    fn extrapolate_line(
        &mut self,
        evals_0: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
        evals_1: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        z: F,
    ) -> Result<(), Error>;

    /// Computes the elementwise application of a compiled arithmetic expression to multiple input
    /// slices.
    ///
    /// This operation applies the composition expression to each row of input values, where a row
    /// consists of one element from each input slice at the same index position. The results are
    /// stored in the output slice.
    ///
    /// ## Mathematical Definition
    ///
    /// Given:
    /// - Multiple input slices $P_0, \ldots, P_{m-1}$, each of length $2^n$ elements
    /// - A composition function $C(X_0, \ldots, X_{m-1})$
    /// - An output slice $P_{\text{out}}$ of length $2^n$ elements
    ///
    /// This operation computes:
    ///
    /// $$
    /// P_{\text{out}}\[i\] = C(P_0\[i\], \ldots, P_{m-1}\[i\])
    /// \quad \forall i \in \{0, \ldots, 2^n- 1\}
    /// $$
    ///
    /// ## Arguments
    ///
    /// * `inputs` - A slice of input slices, where each slice contains field elements.
    /// * `output` - A mutable output slice where the results will be stored.
    /// * `composition` - The compiled arithmetic expression to apply.
    ///
    /// ## Throws
    ///
    /// * Returns an error if any input or output slice has a length that is not a power of two.
    /// * Returns an error if the input and output slices do not all have the same length.
    fn compute_composite(
        &mut self,
        inputs: &SlicesBatch<<Self::DevMem as ComputeMemory<F>>::FSlice<'_>>,
        output: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
        composition: &Self::ExprEval,
    ) -> Result<(), Error>;

    /// Reduces a slice of elements to a single value by recursively applying pairwise
    /// multiplication.
    ///
    /// Given an input slice `x` of length `n = 2^k` for some integer `k`,
    /// this function computes the result:
    ///
    /// $$
    /// y = \prod_{i=0}^{n-1} x_i
    /// $$
    ///
    /// However, instead of a flat left-to-right reduction, the computation proceeds
    /// in ⌈log₂(n)⌉ rounds, halving the number of elements each time:
    ///
    /// - Round 0: $$ x_{i,0} = x_{2i} \cdot x_{2i+1} \quad \text{for } i = 0 \ldots \frac{n}{2} - 1
    ///   $$
    ///
    /// - Round 1: $$ x_{i,1} = x_{2i,0} \cdot x_{2i+1,0} $$
    ///
    /// - ...
    ///
    /// - Final round: $$ y = x_{0,k} = \prod_{i=0}^{n-1} x_i $$
    ///
    /// This binary tree-style reduction is mathematically equivalent to the full product,
    /// but structured for efficient parallelization.
    ///
    /// ## Arguments
    ///
    /// * `input` - A slice of input field elements provided to the first reduction round
    /// * `round_outputs` - A mutable slice of preallocated output field elements for each reduction
    ///   round. `round_outputs.len()` must equal log₂(input.len()). The length of the FSlice at
    ///   index i must equal input.len() / 2**(i + 1) for i in 0..round_outputs.len(), so the last
    ///   round's output holds the single final product.
    ///
    /// ## Throws
    ///
    /// * Returns an error if the length of `input` is not a power of 2.
    /// * Returns an error if the length of `input` is less than 2 (no reductions are possible).
    /// * Returns an error if `round_outputs.len()` != log₂(input.len())
    /// * Returns an error if any element in `round_outputs` does not satisfy
    ///   `round_outputs[i].len() == input.len() / 2**(i + 1)` for i in 0..round_outputs.len()
    fn pairwise_product_reduce(
        &mut self,
        input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
    ) -> Result<(), Error>;
}

/// An interface for defining execution kernels.
///
/// A _kernel_ is a program that executes synchronously in one thread, with access to
/// local memory buffers.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for more information.
pub trait KernelExecutor<F> {
    /// The type for kernel-local memory buffers.
    type Mem: ComputeMemory<F>;

    /// The kernel(core)-level operation (scalar) type. This is a promise for a returned value.
    type Value;

    /// The evaluator for arithmetic expressions (polynomials).
    type ExprEval: Sync;

    /// Declares a kernel-level value, initialized to `init`.
    fn decl_value(&mut self, init: F) -> Result<Self::Value, Error>;

    /// A kernel-local operation that evaluates a composition polynomial over several buffers,
    /// row-wise, and returns the sum of the evaluations, scaled by a batching coefficient.
    ///
    /// Mathematically, let there be $m$ input buffers, $P_0, \ldots, P_{m-1}$, each of length
    /// $2^n$ elements. Let $c$ be the scaling coefficient (`batch_coeff`) and
    /// $C(X_0, \ldots, X_{m-1})$ be the composition polynomial. The operation computes
    ///
    /// $$
    /// \sum_{i=0}^{2^n - 1} c C(P_0\[i\], \ldots, P_{m-1}\[i\]).
    /// $$
    ///
    /// The result is added back to an accumulator value.
    ///
    /// ## Arguments
    ///
    /// * `inputs` - the input buffers. Each row contains the values for a single variable; the
    ///   number of rows determines the summation length.
    /// * `composition` - the compiled composition polynomial expression. This is an output of
    ///   [`ComputeLayer::compile_expr`].
    /// * `batch_coeff` - the scaling coefficient.
    /// * `accumulator` - the output where the result is accumulated to.
    fn sum_composition_evals(
        &mut self,
        inputs: &SlicesBatch<<Self::Mem as ComputeMemory<F>>::FSlice<'_>>,
        composition: &Self::ExprEval,
        batch_coeff: F,
        accumulator: &mut Self::Value,
    ) -> Result<(), Error>;

    /// A kernel-local operation that performs point-wise addition of two input buffers into an
    /// output buffer.
    ///
    /// ## Arguments
    ///
    /// * `log_len` - the binary logarithm of the number of elements in all three buffers.
    /// * `src1` - the first input buffer.
    /// * `src2` - the second input buffer.
    /// * `dst` - the output buffer that receives the element-wise sum.
    fn add(
        &mut self,
        log_len: usize,
        src1: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
        src2: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
        dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// A kernel-local operation that adds a source buffer into a destination buffer, in place.
    ///
    /// ## Arguments
    ///
    /// * `log_len` - the binary logarithm of the number of elements in the two buffers.
    /// * `src` - the source buffer.
    /// * `dst` - the destination buffer.
    fn add_assign(
        &mut self,
        log_len: usize,
        src: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
        dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;
}

/// A memory mapping specification for a kernel execution.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for context on kernel execution.
pub enum KernelMemMap<'a, F, Mem: ComputeMemory<F>> {
    /// This maps a chunk of a buffer in global device memory to a read-only kernel buffer.
    Chunked {
        /// The full read-only buffer to be chunked across kernels.
        data: Mem::FSlice<'a>,
        /// Binary logarithm of the smallest permissible chunk size.
        log_min_chunk_size: usize,
    },
    /// This maps a chunk of a mutable buffer in global device memory to a read-write kernel
    /// buffer. When the kernel exits, the data in the kernel buffer is written back to the
    /// original location.
    ChunkedMut {
        /// The full mutable buffer to be chunked across kernels.
        data: Mem::FSliceMut<'a>,
        /// Binary logarithm of the smallest permissible chunk size.
        log_min_chunk_size: usize,
    },
    /// This allocates a kernel-local scratchpad buffer. The size specified in the mapping is the
    /// total size of all kernel scratchpads. This is so that the kernel's local scratchpad size
    /// scales up proportionally to the size of chunked buffers.
    Local {
        /// Binary logarithm of the combined size of all kernels' scratchpads.
        log_size: usize,
    },
}

614impl<'a, F, Mem: ComputeMemory<F>> KernelMemMap<'a, F, Mem> {
615 /// Computes a range of possible number of chunks that data can be split into, given a sequence
616 /// of memory mappings.
617 pub fn log_chunks_range(mappings: &[Self]) -> Option<Range<usize>> {
618 mappings
619 .iter()
620 .map(|mapping| match mapping {
621 Self::Chunked {
622 data,
623 log_min_chunk_size,
624 } => {
625 let log_data_size = checked_log_2(data.len());
626 let log_min_chunk_size = (*log_min_chunk_size)
627 .max(checked_log_2(Mem::ALIGNMENT))
628 .min(log_data_size);
629 0..(log_data_size - log_min_chunk_size)
630 }
631 Self::ChunkedMut {
632 data,
633 log_min_chunk_size,
634 } => {
635 let log_data_size = checked_log_2(data.len());
636 let log_min_chunk_size = (*log_min_chunk_size)
637 .max(checked_log_2(Mem::ALIGNMENT))
638 .min(log_data_size);
639 0..(log_data_size - log_min_chunk_size)
640 }
641 Self::Local { log_size } => 0..*log_size,
642 })
643 .reduce(|range0, range1| range0.start.max(range1.start)..range0.end.min(range1.end))
644 }
645
646 // Split the memory mapping into `1 << log_chunks>` chunks.
647 pub fn chunks(self, log_chunks: usize) -> impl Iterator<Item = KernelMemMap<'a, F, Mem>> {
648 match self {
649 Self::Chunked {
650 data,
651 log_min_chunk_size,
652 } => Either::Left(Either::Left(
653 Mem::slice_chunks(data, checked_int_div(data.len(), 1 << log_chunks)).map(
654 move |data| KernelMemMap::Chunked {
655 data,
656 log_min_chunk_size,
657 },
658 ),
659 )),
660 Self::ChunkedMut { data, .. } => {
661 let chunks_count = checked_int_div(data.len(), 1 << log_chunks);
662 Either::Left(Either::Right(Mem::slice_chunks_mut(data, chunks_count).map(
663 move |data| KernelMemMap::ChunkedMut {
664 data,
665 log_min_chunk_size: checked_log_2(chunks_count),
666 },
667 )))
668 }
669 Self::Local { log_size } => Either::Right(
670 std::iter::repeat_with(move || KernelMemMap::Local {
671 log_size: log_size - log_chunks,
672 })
673 .take(1 << log_chunks),
674 ),
675 }
676 }
677}
678
/// A memory buffer mapped into a kernel.
///
/// See [`ComputeLayerExecutor::accumulate_kernels`] for context on kernel execution.
pub enum KernelBuffer<'a, F, Mem: ComputeMemory<F>> {
    /// A read-only kernel buffer.
    Ref(Mem::FSlice<'a>),
    /// A read-write kernel buffer.
    Mut(Mem::FSliceMut<'a>),
}

687impl<'a, F, Mem: ComputeMemory<F>> KernelBuffer<'a, F, Mem> {
688 /// Returns underlying data as an `FSlice`.
689 pub fn to_ref(&self) -> Mem::FSlice<'_> {
690 match self {
691 Self::Ref(slice) => Mem::narrow(slice),
692 Self::Mut(slice) => Mem::as_const(slice),
693 }
694 }
695}
696
697impl<'a, F, Mem: ComputeMemory<F>> SizedSlice for KernelBuffer<'a, F, Mem> {
698 fn len(&self) -> usize {
699 match self {
700 KernelBuffer::Ref(mem) => mem.len(),
701 KernelBuffer::Mut(mem) => mem.len(),
702 }
703 }
704}
705
/// Errors that can be returned by compute-layer operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An operation input failed validation (e.g. mismatched lengths).
    #[error("input validation: {0}")]
    InputValidation(String),
    /// A host or device memory allocation failed.
    #[error("allocation error: {0}")]
    Alloc(#[from] AllocError),
    /// An error reported by the underlying compute device.
    #[error("device error: {0}")]
    DeviceError(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// An error reported by a core library routine.
    #[error("core library error: {0}")]
    CoreLibError(Box<dyn std::error::Error + Send + Sync + 'static>),
}

// Convenience types for the device memory.

/// A read-only device-memory slice for the HAL `HAL`.
pub type FSlice<'a, F, HAL> = <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSlice<'a>;
/// A mutable device-memory slice for the HAL `HAL`.
pub type FSliceMut<'a, F, HAL> =
    <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSliceMut<'a>;

/// The kernel-local memory type of the HAL's kernel executor.
pub type KernelMem<'a, F, HAL> = <<<HAL as ComputeLayer<F>>::Exec<'a> as ComputeLayerExecutor<F>>::KernelExec as KernelExecutor<F>>::Mem;
/// A read-only kernel-local memory slice.
pub type KernelSlice<'a, 'b, F, HAL> = <KernelMem<'b, F, HAL> as ComputeMemory<F>>::FSlice<'a>;
/// A mutable kernel-local memory slice.
pub type KernelSliceMut<'a, 'b, F, HAL> =
    <KernelMem<'b, F, HAL> as ComputeMemory<F>>::FSliceMut<'a>;

/// This is a trait for a holder type for the popular triple:
/// * a compute layer (HAL),
/// * a host memory allocator,
/// * a device memory allocator.
pub trait ComputeHolder<F: Field, HAL: ComputeLayer<F>> {
    /// The allocator type for host (CPU) memory.
    type HostComputeAllocator<'a>: ComputeAllocator<F, CpuMemory>
    where
        Self: 'a;
    /// The allocator type for device memory.
    type DeviceComputeAllocator<'a>: ComputeAllocator<F, HAL::DevMem>
    where
        Self: 'a;

    /// Borrows the holder's HAL and allocators as a [`ComputeData`] bundle.
    fn to_data<'a, 'b>(
        &'a mut self,
    ) -> ComputeData<'a, F, HAL, Self::HostComputeAllocator<'b>, Self::DeviceComputeAllocator<'b>>
    where
        'a: 'b;
}

/// A bundle of a compute layer (HAL) together with host and device memory allocators.
///
/// See [`ComputeHolder::to_data`], which produces this bundle from a holder.
pub struct ComputeData<'a, F: Field, HAL: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>
where
    HostAllocatorType: ComputeAllocator<F, CpuMemory>,
    DeviceAllocatorType: ComputeAllocator<F, HAL::DevMem>,
{
    /// The hardware abstraction layer used to execute operations.
    pub hal: &'a HAL,
    /// Allocator for host (CPU) memory.
    pub host_alloc: HostAllocatorType,
    /// Allocator for device memory.
    pub dev_alloc: DeviceAllocatorType,
    // Marks the field type `F` as used; it appears only in the generic bounds.
    _phantom_data: PhantomData<F>,
}

758impl<'a, F: Field, HAL: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>
759 ComputeData<'a, F, HAL, HostAllocatorType, DeviceAllocatorType>
760where
761 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
762 DeviceAllocatorType: ComputeAllocator<F, HAL::DevMem>,
763{
764 pub fn new(
765 hal: &'a HAL,
766 host_alloc: HostAllocatorType,
767 dev_alloc: DeviceAllocatorType,
768 ) -> Self {
769 Self {
770 hal,
771 host_alloc,
772 dev_alloc,
773 _phantom_data: PhantomData::<F>,
774 }
775 }
776}
777
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, Field, TowerField};
    use binius_math::B128;
    use rand::{SeedableRng, prelude::StdRng};

    use super::*;
    use crate::{
        alloc::ComputeAllocator,
        cpu::{CpuMemory, layer::CpuLayerHolder},
    };

    /// Test showing how to allocate host memory and create a sub-allocator over it.
    ///
    /// Round-trips random data host -> device -> device -> host and checks it is unchanged.
    // TODO: This 'a lifetime bound on HAL is pretty annoying. I'd like to get rid of it.
    fn test_copy_host_device<'a, F: TowerField, HAL: ComputeLayer<F> + 'a>(
        mut compute_holder: impl ComputeHolder<F, HAL>,
    ) {
        let ComputeData {
            hal,
            host_alloc,
            dev_alloc,
            _phantom_data,
        } = compute_holder.to_data();

        // Fixed seed so the test is deterministic.
        let mut rng = StdRng::seed_from_u64(0);

        let host_buf_1 = host_alloc.alloc(128).unwrap();
        let host_buf_2 = host_alloc.alloc(128).unwrap();
        let mut dev_buf_1 = dev_alloc.alloc(128).unwrap();
        let mut dev_buf_2 = dev_alloc.alloc(128).unwrap();

        // Fill the source host buffer with random field elements.
        for elem in host_buf_1.iter_mut() {
            *elem = F::random(&mut rng);
        }

        // host -> device, device -> device, device -> host.
        hal.copy_h2d(host_buf_1, &mut dev_buf_1).unwrap();
        hal.copy_d2d(HAL::DevMem::as_const(&dev_buf_1), &mut dev_buf_2)
            .unwrap();
        hal.copy_d2h(HAL::DevMem::as_const(&dev_buf_2), host_buf_2)
            .unwrap();

        // The round-tripped data must match the original.
        assert_eq!(host_buf_1, host_buf_2);
    }

    #[test]
    fn test_cpu_copy_host_device() {
        // 512 host elements and 256 device elements are enough for the 4 buffers of 128 above.
        test_copy_host_device(CpuLayerHolder::<B128>::new(512, 256));
    }

    #[test]
    fn test_log_chunks_range() {
        let mem_1 = vec![BinaryField128b::ZERO; 256];
        let mut mem_2 = vec![BinaryField128b::ZERO; 256];

        let mappings = vec![
            KernelMemMap::Chunked {
                data: mem_1.as_slice(),
                log_min_chunk_size: 4,
            },
            KernelMemMap::ChunkedMut {
                data: mem_2.as_mut_slice(),
                log_min_chunk_size: 6,
            },
            KernelMemMap::Local { log_size: 8 },
        ];

        // mem_2 is the binding constraint: log2(256) - 6 = 2.
        let range =
            KernelMemMap::<BinaryField128b, CpuMemory>::log_chunks_range(&mappings).unwrap();
        assert_eq!(range.start, 0);
        assert_eq!(range.end, 2);
    }
}
849}