// binius_compute/layer.rs
// Copyright 2025 Irreducible Inc.
2
use std::ops::Range;

use binius_field::{BinaryField, ExtensionField, Field};
use binius_math::ArithCircuit;
use binius_ntt::AdditiveNTT;
use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
use itertools::Either;

use super::{
    alloc::Error as AllocError,
    memory::{ComputeMemory, SubfieldSlice},
};
use crate::memory::{SizedSlice, SlicesBatch};
16
/// A hardware abstraction layer (HAL) for compute operations.
pub trait ComputeLayer<F: Field>: 'static + Sync {
    /// The device memory.
    type DevMem: ComputeMemory<F>;

    /// The executor that can execute operations on the device.
    type Exec;

    /// The executor that can execute operations on a kernel-level granularity (i.e., a single
    /// core).
    type KernelExec: KernelExecutor<F, Mem = Self::DevMem, ExprEval = Self::ExprEval>;

    /// The operation (scalar) value type.
    type OpValue;

    /// The evaluator for arithmetic expressions (polynomials).
    type ExprEval: Sync;

    /// Allocates a slice of memory on the host that is prepared for transfers to/from the device.
    ///
    /// Depending on the compute layer, this may perform steps beyond just allocating memory. For
    /// example, it may allocate huge pages or map the allocated memory to the IOMMU.
    ///
    /// The returned buffer is lifetime bound to the compute layer, allowing return types to have
    /// drop methods referencing data in the compute layer.
    fn host_alloc(&self, n: usize) -> impl AsMut<[F]> + '_;

    /// Copy data from the host to the device.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    /// * `src` must be a slice of a buffer returned by [`Self::host_alloc`].
    fn copy_h2d(&self, src: &[F], dst: &mut FSliceMut<'_, F, Self>) -> Result<(), Error>;

    /// Copy data from the device to the host.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    /// * `dst` must be a slice of a buffer returned by [`Self::host_alloc`].
    fn copy_d2h(&self, src: FSlice<'_, F, Self>, dst: &mut [F]) -> Result<(), Error>;

    /// Copy data between disjoint device buffers.
    ///
    /// ## Preconditions
    ///
    /// * `src` and `dst` must have the same length.
    fn copy_d2d(
        &self,
        src: FSlice<'_, F, Self>,
        dst: &mut FSliceMut<'_, F, Self>,
    ) -> Result<(), Error>;

    /// Executes an operation.
    ///
    /// A HAL operation is an abstract function that runs with an executor reference.
    fn execute(
        &self,
        f: impl FnOnce(&mut Self::Exec) -> Result<Vec<Self::OpValue>, Error>,
    ) -> Result<Vec<F>, Error>;

    /// Creates an operation that depends on the concurrent execution of two inner operations.
    ///
    /// The default implementation runs the two operations sequentially; implementations may
    /// override it to schedule them concurrently.
    fn join<Out1, Out2>(
        &self,
        exec: &mut Self::Exec,
        op1: impl FnOnce(&mut Self::Exec) -> Result<Out1, Error>,
        op2: impl FnOnce(&mut Self::Exec) -> Result<Out2, Error>,
    ) -> Result<(Out1, Out2), Error> {
        let out1 = op1(exec)?;
        let out2 = op2(exec)?;
        Ok((out1, out2))
    }

    /// Creates an operation that depends on the concurrent execution of a sequence of operations.
    ///
    /// The default implementation maps the items sequentially, short-circuiting on the first
    /// error; implementations may override it to schedule the operations concurrently.
    fn map<Out, I: ExactSizeIterator>(
        &self,
        exec: &mut Self::Exec,
        iter: I,
        map: impl Fn(&mut Self::Exec, I::Item) -> Result<Out, Error>,
    ) -> Result<Vec<Out>, Error> {
        iter.map(|item| map(exec, item)).collect()
    }

    /// Compiles an arithmetic expression to the evaluator.
    fn compile_expr(&self, expr: &ArithCircuit<F>) -> Result<Self::ExprEval, Error>;

    /// Launch many kernels in parallel and accumulate the scalar results with field addition.
    ///
    /// This method provides low-level access to schedule parallel kernel executions on the compute
    /// platform. A _kernel_ is a program that executes synchronously in one thread, with access to
    /// local memory buffers. When the environment launches a kernel, it sets up the kernel's local
    /// memory according to the memory mapping specifications provided by the `mem_maps` parameter.
    /// The mapped buffers have type [`KernelBuffer`], and they may be read-write or read-only.
    /// When the kernel exits, it returns a small number of computed values as field elements. The
    /// vector of returned scalars is accumulated via binary field addition across all kernels and
    /// returned from the call.
    ///
    /// This method is fairly general but also designed to fit the specific needs of the sumcheck
    /// protocol. That motivates the choice that returned values are small-length vectors that are
    /// accumulated with field addition.
    ///
    /// ## Buffer chunking
    ///
    /// The kernel local memory buffers are thought of as slices of a larger buffer, which may or
    /// may not exist somewhere else. Each kernel operates on a chunk of the larger buffers. For
    /// example, the [`KernelMemMap::Chunked`] mapping specifies that each kernel operates on a
    /// read-only chunk of a buffer in global device memory. The [`KernelMemMap::Local`] mapping
    /// specifies that a kernel operates on a local scratchpad initialized with zeros and discarded
    /// at the end of kernel execution (sort of like /dev/null).
    ///
    /// This [`ComputeLayer`] object can decide how many kernels to launch and thus how large
    /// each kernel's buffer chunks are. The number of chunks must be a power of two. This
    /// information is provided to the kernel specification closure as an argument.
    ///
    /// ## Kernel specification
    ///
    /// The kernel logic is constructed within a closure, which is the `map` parameter. The closure
    /// has three parameters:
    ///
    /// * `kernel_exec` - the kernel execution environment.
    /// * `log_chunks` - the binary logarithm of the number of kernels that are launched.
    /// * `buffers` - a vector of kernel-local buffers.
    ///
    /// The closure must respect certain assumptions:
    ///
    /// * The kernel closure control flow is identical on each invocation when `log_chunks` is
    ///   unchanged.
    ///
    /// [`ComputeLayer`] implementations are free to call the specification closure multiple times,
    /// for example with different values for `log_chunks`.
    ///
    /// ## Arguments
    ///
    /// * `map` - the kernel specification closure. See the "Kernel specification" section above.
    /// * `mem_maps` - the memory mappings for the kernel-local buffers.
    fn accumulate_kernels(
        &self,
        exec: &mut Self::Exec,
        map: impl Sync
            + for<'a> Fn(
                &'a mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'a, F, Self::DevMem>>,
            ) -> Result<Vec<<Self::KernelExec as KernelExecutor<F>>::Value>, Error>,
        mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
    ) -> Result<Vec<Self::OpValue>, Error>;

    /// Returns the inner product of a vector of subfield elements with big field elements.
    ///
    /// ## Arguments
    ///
    /// * `a_in` - the first input slice of subfield elements.
    /// * `b_in` - the second input slice of `F` elements.
    ///
    /// ## Throws
    ///
    /// * if the tower level of `a_in` is greater than `F::TOWER_LEVEL`
    /// * unless `a_in` and `b_in` contain the same number of elements, and the number is a power of
    ///   two
    ///
    /// ## Returns
    ///
    /// Returns the inner product of `a_in` and `b_in`.
    fn inner_product(
        &self,
        exec: &mut Self::Exec,
        a_in: SubfieldSlice<'_, F, Self::DevMem>,
        b_in: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
    ) -> Result<Self::OpValue, Error>;

    /// Computes the iterative tensor product of the input with the given coordinates.
    ///
    /// This operation modifies the data buffer in place.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts parameters
    ///
    /// * $n \in \mathbb{N}$ (`log_n`),
    /// * $k \in \mathbb{N}$ (`coordinates.len()`),
    /// * $v \in L^{2^n}$ (`data[..1 << log_n]`),
    /// * $r \in L^k$ (`coordinates`),
    ///
    /// and computes the vector
    ///
    /// $$
    /// v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})
    /// $$
    ///
    /// ## Throws
    ///
    /// * unless `2**(log_n + coordinates.len())` equals `data.len()`
    fn tensor_expand(
        &self,
        exec: &mut Self::Exec,
        log_n: usize,
        coordinates: &[F],
        data: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// Computes left matrix-vector multiplication of a subfield matrix with a big field vector.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts
    ///
    /// * $n \in \mathbb{N}$ (`out.len()`),
    /// * $m \in \mathbb{N}$ (`vec.len()`),
    /// * $M \in K^{n \times m}$ (`mat`),
    /// * $v \in K^m$ (`vec`),
    ///
    /// and computes the vector $Mv$.
    ///
    /// ## Args
    ///
    /// * `mat` - a slice of elements from a subfield of `F`.
    /// * `vec` - a slice of `F` elements.
    /// * `out` - a buffer for the output vector of `F` elements.
    ///
    /// ## Throws
    ///
    /// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
    /// * Returns an error if `mat` is not a subfield of `F`.
    fn fold_left<'a>(
        &'a self,
        exec: &'a mut Self::Exec,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// Computes right matrix-vector multiplication of a subfield matrix with a big field vector.
    ///
    /// ## Mathematical Definition
    ///
    /// This operation accepts
    ///
    /// * $n \in \mathbb{N}$ (`vec.len()`),
    /// * $m \in \mathbb{N}$ (`out.len()`),
    /// * $M \in K^{n \times m}$ (`mat`),
    /// * $v \in K^n$ (`vec`),
    ///
    /// and computes the vector $((v')M)'$. The prime denotes a transpose.
    ///
    /// ## Args
    ///
    /// * `mat` - a slice of elements from a subfield of `F`.
    /// * `vec` - a slice of `F` elements.
    /// * `out` - a buffer for the output vector of `F` elements.
    ///
    /// ## Throws
    ///
    /// * Returns an error if `mat.len()` does not equal `vec.len() * out.len()`.
    /// * Returns an error if `mat` is not a subfield of `F`.
    fn fold_right<'a>(
        &'a self,
        exec: &'a mut Self::Exec,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;

    /// FRI-fold the interleaved codeword using the given challenges.
    ///
    /// The FRI-fold operation folds a length $2^{n+b+\eta}$ vector of field elements into a length
    /// $2^n$ vector of field elements. $n$ is the log block length of the code, $b$ is the log
    /// batch size, and $b + \eta$ is the number of challenge elements. The operation has the
    /// following mathematical structure:
    ///
    /// 1. Split the challenge vector into two parts: $c_0$ with length $b$ and $c_1$ with length
    ///    $\eta$.
    /// 2. Low fold the input data with the tensor expansion of $c_0$.
    /// 3. Apply $\eta$ layers of the inverse additive NTT to the data.
    /// 4. Low fold the input data with the tensor expansion of $c_1$.
    ///
    /// The algorithm to perform steps 3 and 4 can be combined into a linear amount of work,
    /// whereas step 3 on its own would require $\eta$ independent linear passes.
    ///
    /// See [DP24], Section 4.2 for more details.
    ///
    /// This operation writes the result out-of-place into an output buffer.
    ///
    /// ## Arguments
    ///
    /// * `ntt` - the NTT instance, used to look up the twiddle values.
    /// * `log_len` - $n + \eta$, the binary logarithm of the code length.
    /// * `log_batch_size` - $b$, the binary logarithm of the interleaved code batch size.
    /// * `challenges` - the folding challenges, with length $b + \eta$.
    /// * `data_in` - an input vector, with length $2^{n + b + \eta}$.
    /// * `data_out` - an output buffer, with length $2^n$.
    ///
    /// [DP24]: <https://eprint.iacr.org/2024/504>
    #[allow(clippy::too_many_arguments)]
    fn fri_fold<FSub>(
        &self,
        exec: &mut Self::Exec,
        ntt: &(impl AdditiveNTT<FSub> + Sync),
        log_len: usize,
        log_batch_size: usize,
        challenges: &[F],
        data_in: FSlice<F, Self>,
        data_out: &mut FSliceMut<F, Self>,
    ) -> Result<(), Error>
    where
        FSub: BinaryField,
        F: ExtensionField<FSub>;

    /// Extrapolates a line between a vector of evaluations at 0 and evaluations at 1.
    ///
    /// Given two values $y_0, y_1$, this operation computes the value $y_z = y_0 + (y_1 - y_0) z$,
    /// which is the value of the line that interpolates $(0, y_0), (1, y_1)$ at $z$. This computes
    /// this operation in parallel over two vectors of big field elements of equal sizes.
    ///
    /// The operation writes the result back in-place into the `evals_0` buffer.
    ///
    /// ## Args
    ///
    /// * `evals_0` - this is both an input and output buffer. As an input, it is populated with
    ///   the values $y_0$, which are the line's values at 0.
    /// * `evals_1` - an input buffer with the values $y_1$, which are the line's values at 1.
    /// * `z` - the scalar evaluation point.
    ///
    /// ## Throws
    ///
    /// * if `evals_0` and `evals_1` are not equal sizes.
    /// * if the sizes of `evals_0` and `evals_1` are not powers of two.
    fn extrapolate_line(
        &self,
        exec: &mut Self::Exec,
        evals_0: &mut FSliceMut<F, Self>,
        evals_1: FSlice<F, Self>,
        z: F,
    ) -> Result<(), Error>;

    /// Computes the elementwise application of a compiled arithmetic expression to multiple input
    /// slices.
    ///
    /// This operation applies the composition expression to each row of input values, where a row
    /// consists of one element from each input slice at the same index position. The results are
    /// stored in the output slice.
    ///
    /// ## Mathematical Definition
    ///
    /// Given:
    /// - Multiple input slices $P_0, \ldots, P_{m-1}$, each of length $2^n$ elements
    /// - A composition function $C(X_0, \ldots, X_{m-1})$
    /// - An output slice $P_{\text{out}}$ of length $2^n$ elements
    ///
    /// This operation computes:
    ///
    /// $$
    /// P_{\text{out}}\[i\] = C(P_0\[i\], \ldots, P_{m-1}\[i\])
    /// \quad \forall i \in \{0, \ldots, 2^n- 1\}
    /// $$
    ///
    /// ## Arguments
    ///
    /// * `exec` - The execution environment.
    /// * `inputs` - A slice of input slices, where each slice contains field elements.
    /// * `output` - A mutable output slice where the results will be stored.
    /// * `composition` - The compiled arithmetic expression to apply.
    ///
    /// ## Throws
    ///
    /// * Returns an error if any input or output slice has a length that is not a power of two.
    /// * Returns an error if the input and output slices do not all have the same length.
    fn compute_composite(
        &self,
        exec: &mut Self::Exec,
        inputs: &SlicesBatch<FSlice<'_, F, Self>>,
        output: &mut FSliceMut<'_, F, Self>,
        composition: &Self::ExprEval,
    ) -> Result<(), Error>;
}
392
/// An interface for defining execution kernels.
///
/// A _kernel_ is a program that executes synchronously in one thread, with access to
/// local memory buffers.
///
/// See [`ComputeLayer::accumulate_kernels`] for more information.
pub trait KernelExecutor<F> {
    /// The type for kernel-local memory buffers.
    type Mem: ComputeMemory<F>;

    /// The kernel(core)-level operation (scalar) type. This is a promise for a returned value.
    type Value;

    /// The evaluator for arithmetic expressions (polynomials).
    type ExprEval: Sync;

    /// Declares a kernel-level value, initialized to `init`.
    fn decl_value(&mut self, init: F) -> Result<Self::Value, Error>;

    /// A kernel-local operation that evaluates a composition polynomial over several buffers,
    /// row-wise, and returns the sum of the evaluations, scaled by a batching coefficient.
    ///
    /// Mathematically, let there be $m$ input buffers, $P_0, \ldots, P_{m-1}$, each of length
    /// $2^n$ elements. Let $c$ be the scaling coefficient (`batch_coeff`) and
    /// $C(X_0, \ldots, X_{m-1})$ be the composition polynomial. The operation computes
    ///
    /// $$
    /// \sum_{i=0}^{2^n - 1} c C(P_0\[i\], \ldots, P_{m-1}\[i\]).
    /// $$
    ///
    /// The result is added back to an accumulator value.
    ///
    /// ## Arguments
    ///
    /// * `inputs` - the input buffers. Each row contains the values for a single variable. The
    ///   common row length $2^n$ is carried by the batch itself.
    /// * `composition` - the compiled composition polynomial expression. This is an output of
    ///   [`ComputeLayer::compile_expr`].
    /// * `batch_coeff` - the scaling coefficient.
    /// * `accumulator` - the output where the result is accumulated to.
    fn sum_composition_evals(
        &mut self,
        inputs: &SlicesBatch<<Self::Mem as ComputeMemory<F>>::FSlice<'_>>,
        composition: &Self::ExprEval,
        batch_coeff: F,
        accumulator: &mut Self::Value,
    ) -> Result<(), Error>;

    /// A kernel-local operation that performs point-wise addition of two input buffers into an
    /// output buffer.
    ///
    /// ## Arguments
    ///
    /// * `log_len` - the binary logarithm of the number of elements in all three buffers.
    /// * `src1` - the first input buffer.
    /// * `src2` - the second input buffer.
    /// * `dst` - the output buffer that receives the element-wise sum.
    fn add(
        &mut self,
        log_len: usize,
        src1: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
        src2: <Self::Mem as ComputeMemory<F>>::FSlice<'_>,
        dst: &mut <Self::Mem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error>;
}
458
/// A memory mapping specification for a kernel execution.
///
/// See [`ComputeLayer::accumulate_kernels`] for context on kernel execution.
pub enum KernelMemMap<'a, F, Mem: ComputeMemory<F>> {
    /// This maps a chunk of a buffer in global device memory to a read-only kernel buffer.
    Chunked {
        // The full read-only buffer; each kernel sees one chunk of it.
        data: Mem::FSlice<'a>,
        // Binary log of the smallest chunk size this mapping allows.
        log_min_chunk_size: usize,
    },
    /// This maps a chunk of a mutable buffer in global device memory to a read-write kernel
    /// buffer. When the kernel exits, the data in the kernel buffer is written back to the
    /// original location.
    ChunkedMut {
        // The full mutable buffer; each kernel sees one chunk of it.
        data: Mem::FSliceMut<'a>,
        // Binary log of the smallest chunk size this mapping allows.
        log_min_chunk_size: usize,
    },
    /// This allocates a kernel-local scratchpad buffer. The size specified in the mapping is the
    /// total size of all kernel scratchpads. This is so that the kernel's local scratchpad size
    /// scales up proportionally to the size of chunked buffers.
    Local { log_size: usize },
}
480
481impl<'a, F, Mem: ComputeMemory<F>> KernelMemMap<'a, F, Mem> {
482 /// Computes a range of possible number of chunks that data can be split into, given a sequence
483 /// of memory mappings.
484 pub fn log_chunks_range(mappings: &[Self]) -> Option<Range<usize>> {
485 mappings
486 .iter()
487 .map(|mapping| match mapping {
488 Self::Chunked {
489 data,
490 log_min_chunk_size,
491 } => {
492 let log_data_size = checked_log_2(data.len());
493 let log_min_chunk_size = (*log_min_chunk_size)
494 .max(checked_log_2(Mem::ALIGNMENT))
495 .min(log_data_size);
496 0..(log_data_size - log_min_chunk_size)
497 }
498 Self::ChunkedMut {
499 data,
500 log_min_chunk_size,
501 } => {
502 let log_data_size = checked_log_2(data.len());
503 let log_min_chunk_size = (*log_min_chunk_size)
504 .max(checked_log_2(Mem::ALIGNMENT))
505 .min(log_data_size);
506 0..(log_data_size - log_min_chunk_size)
507 }
508 Self::Local { log_size } => 0..*log_size,
509 })
510 .reduce(|range0, range1| range0.start.max(range1.start)..range0.end.min(range1.end))
511 }
512
513 // Split the memory mapping into `1 << log_chunks>` chunks.
514 pub fn chunks(self, log_chunks: usize) -> impl Iterator<Item = KernelMemMap<'a, F, Mem>> {
515 match self {
516 Self::Chunked {
517 data,
518 log_min_chunk_size,
519 } => Either::Left(Either::Left(
520 Mem::slice_chunks(data, checked_int_div(data.len(), 1 << log_chunks)).map(
521 move |data| KernelMemMap::Chunked {
522 data,
523 log_min_chunk_size,
524 },
525 ),
526 )),
527 Self::ChunkedMut { data, .. } => {
528 let chunks_count = checked_int_div(data.len(), 1 << log_chunks);
529 Either::Left(Either::Right(Mem::slice_chunks_mut(data, chunks_count).map(
530 move |data| KernelMemMap::ChunkedMut {
531 data,
532 log_min_chunk_size: checked_log_2(chunks_count),
533 },
534 )))
535 }
536 Self::Local { log_size } => Either::Right(
537 std::iter::repeat_with(move || KernelMemMap::Local {
538 log_size: log_size - log_chunks,
539 })
540 .take(1 << log_chunks),
541 ),
542 }
543 }
544}
545
/// A memory buffer mapped into a kernel.
///
/// See [`ComputeLayer::accumulate_kernels`] for context on kernel execution.
pub enum KernelBuffer<'a, F, Mem: ComputeMemory<F>> {
    /// A read-only kernel buffer.
    Ref(Mem::FSlice<'a>),
    /// A read-write kernel buffer.
    Mut(Mem::FSliceMut<'a>),
}
553
554impl<'a, F, Mem: ComputeMemory<F>> KernelBuffer<'a, F, Mem> {
555 /// Returns underlying data as an `FSlice`.
556 pub fn to_ref(&self) -> Mem::FSlice<'_> {
557 match self {
558 Self::Ref(slice) => Mem::narrow(slice),
559 Self::Mut(slice) => Mem::as_const(slice),
560 }
561 }
562}
563
564impl<'a, F, Mem: ComputeMemory<F>> SizedSlice for KernelBuffer<'a, F, Mem> {
565 fn len(&self) -> usize {
566 match self {
567 KernelBuffer::Ref(mem) => mem.len(),
568 KernelBuffer::Mut(mem) => mem.len(),
569 }
570 }
571}
572
/// Errors returned by [`ComputeLayer`] and [`KernelExecutor`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A precondition on an operation's arguments was violated.
    #[error("input validation: {0}")]
    InputValidation(String),
    /// A memory allocation failed; wraps the allocator's error.
    #[error("allocation error: {0}")]
    Alloc(#[from] AllocError),
    /// An error reported by the underlying compute device/backend.
    #[error("device error: {0}")]
    DeviceError(Box<dyn std::error::Error + Send + Sync + 'static>),
}
582
// Convenience types for the device memory.

/// Immutable device-memory slice type of the HAL's [`ComputeMemory`].
pub type FSlice<'a, F, HAL> = <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSlice<'a>;
/// Mutable device-memory slice type of the HAL's [`ComputeMemory`].
pub type FSliceMut<'a, F, HAL> =
    <<HAL as ComputeLayer<F>>::DevMem as ComputeMemory<F>>::FSliceMut<'a>;
587
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField128b, Field, TowerField, tower::CanonicalTowerFamily};
    use rand::{SeedableRng, prelude::StdRng};

    use super::*;
    use crate::{
        alloc::{BumpAllocator, ComputeAllocator, Error as AllocError, HostBumpAllocator},
        cpu::{CpuLayer, CpuMemory},
    };

    /// Test showing how to allocate host memory and create a sub-allocator over it.
    fn test_host_alloc<F: TowerField, HAL: ComputeLayer<F>>(hal: HAL) {
        let mut host_slice = hal.host_alloc(256);

        let bump = HostBumpAllocator::new(host_slice.as_mut());
        assert_eq!(bump.alloc(100).unwrap().len(), 100);
        assert_eq!(bump.alloc(100).unwrap().len(), 100);
        // Only 56 elements remain of the 256-element arena, so a third allocation fails.
        assert_matches!(bump.alloc(100), Err(AllocError::OutOfMemory));
    }

    /// Test showing how to copy data between host and device buffers, round-tripping
    /// host -> device -> device -> host and checking the data survives unchanged.
    // TODO: This 'a lifetime bound on HAL is pretty annoying. I'd like to get rid of it.
    fn test_copy_host_device<'a, F: TowerField, HAL: ComputeLayer<F> + 'a>(
        hal: HAL,
        mut dev_mem: FSliceMut<'a, F, HAL>,
    ) {
        let mut rng = StdRng::seed_from_u64(0);

        let mut host_slice = hal.host_alloc(256);

        let host_alloc = HostBumpAllocator::new(host_slice.as_mut());
        let dev_alloc = BumpAllocator::<F, HAL::DevMem>::from_ref(&mut dev_mem);

        let host_buf_1 = host_alloc.alloc(128).unwrap();
        let host_buf_2 = host_alloc.alloc(128).unwrap();
        let mut dev_buf_1 = dev_alloc.alloc(128).unwrap();
        let mut dev_buf_2 = dev_alloc.alloc(128).unwrap();

        // Fill the source host buffer with deterministic random data.
        for elem in &mut *host_buf_1 {
            *elem = F::random(&mut rng);
        }

        hal.copy_h2d(host_buf_1, &mut dev_buf_1).unwrap();
        hal.copy_d2d(HAL::DevMem::as_const(&dev_buf_1), &mut dev_buf_2)
            .unwrap();
        hal.copy_d2h(HAL::DevMem::as_const(&dev_buf_2), host_buf_2)
            .unwrap();

        assert_eq!(host_buf_1, host_buf_2);
    }

    #[test]
    fn test_cpu_host_alloc() {
        test_host_alloc(CpuLayer::<CanonicalTowerFamily>::default());
    }

    #[test]
    fn test_cpu_copy_host_device() {
        let mut dev_mem = vec![BinaryField128b::ZERO; 256];
        test_copy_host_device(CpuLayer::<CanonicalTowerFamily>::default(), dev_mem.as_mut_slice());
    }

    #[test]
    fn test_log_chunks_range() {
        let mem_1 = vec![BinaryField128b::ZERO; 256];
        let mut mem_2 = vec![BinaryField128b::ZERO; 256];

        let mappings = vec![
            KernelMemMap::Chunked {
                data: mem_1.as_slice(),
                log_min_chunk_size: 4,
            },
            KernelMemMap::ChunkedMut {
                data: mem_2.as_mut_slice(),
                log_min_chunk_size: 6,
            },
            KernelMemMap::Local { log_size: 8 },
        ];

        // The mutable mapping is the most restrictive: log2(256) - 6 = 2.
        let range =
            KernelMemMap::<BinaryField128b, CpuMemory>::log_chunks_range(&mappings).unwrap();
        assert_eq!(range.start, 0);
        assert_eq!(range.end, 2);
    }
}
674}