use std::{
    any::TypeId,
    cell::RefCell,
    iter::zip,
    marker::PhantomData,
    mem::{MaybeUninit, transmute},
    slice,
};

use binius_compute::{
    ComputeData, ComputeHolder, ComputeLayerExecutor, KernelExecutor,
    alloc::{BumpAllocator, ComputeAllocator, HostBumpAllocator},
    cpu::layer::count_total_local_buffer_sizes,
    each_generic_tower_subfield as each_tower_subfield,
    layer::{ComputeLayer, Error, FSlice, FSliceMut, KernelBuffer, KernelMemMap},
    memory::{ComputeMemory, SizedSlice, SlicesBatch, SubfieldSlice},
};
use binius_field::{
    AESTowerField8b, AESTowerField128b, BinaryField8b, BinaryField128b, ByteSlicedUnderlier,
    ExtensionField, Field, PackedBinaryField1x128b, PackedBinaryField2x128b,
    PackedBinaryField4x128b, PackedExtension, PackedField,
    as_packed_field::{PackScalar, PackedType},
    linear_transformation::{PackedTransformationFactory, Transformation},
    make_aes_to_binary_packed_transformer, make_binary_to_aes_packed_transformer,
    tower::{PackedTop, TowerFamily},
    tower_levels::TowerLevel16,
    underlier::{NumCast, UnderlierWithBitOps, WithUnderlier},
    unpack_if_possible, unpack_if_possible_mut,
    util::inner_product_par,
};
use binius_math::{ArithCircuit, CompositionPoly, RowsBatchRef, tensor_prod_eq_ind};
use binius_maybe_rayon::{
    iter::{
        IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
        ParallelIterator,
    },
    prelude::ParallelBridge,
    slice::{ParallelSlice, ParallelSliceMut},
};
use binius_ntt::{AdditiveNTT, fri::fold_interleaved_allocated};
use binius_utils::{
    checked_arithmetics::{checked_int_div, strict_log_2},
    rayon::get_log_max_threads,
    strided_array::StridedArray2DViewMut,
};
use bytemuck::{Pod, zeroed_vec};
use itertools::izip;
use thread_local::ThreadLocal;

use crate::{
    arith_circuit::ArithCircuitPoly,
    memory::{PackedMemory, PackedMemorySliceMut},
};

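/// A CPU implementation of [`ComputeLayer`] that operates directly on packed field elements
/// of type `P`, reusing per-thread scratch buffers for kernel-local allocations.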
#[derive(Debug)]
pub struct FastCpuLayer<T: TowerFamily, P: PackedTop<T>> {
    kernel_buffers: ThreadLocal<RefCell<Vec<P>>>,
    _phantom: PhantomData<(P, T)>,
}

impl<T: TowerFamily, P: PackedTop<T>> Default for FastCpuLayer<T, P> {
    fn default() -> Self {
        Self {
            kernel_buffers: ThreadLocal::with_capacity(1 << get_log_max_threads()),
            _phantom: PhantomData,
        }
    }
}

impl<T: TowerFamily, P: PackedTop<T>> ComputeLayer<T::B128> for FastCpuLayer<T, P> {
    type Exec<'b> = FastCpuExecutor<'b, T, P>;
    type DevMem = PackedMemory<P>;

    fn copy_h2d(
        &self,
        src: &[T::B128],
        dst: &mut FSliceMut<'_, T::B128, Self>,
    ) -> Result<(), Error> {
        if src.len() != dst.len() {
            return Err(Error::InputValidation(
                "precondition: src and dst buffers must have the same length".to_string(),
            ));
        }

        unpack_if_possible_mut(
            dst.as_slice_mut(),
            |scalars| {
                scalars[..src.len()].copy_from_slice(src);
                Ok(())
            },
            |packed| {
                src.par_chunks_exact(P::WIDTH)
                    .zip(packed.par_iter_mut())
                    .for_each(|(input, output)| {
                        *output = PackedField::from_scalars(input.iter().copied());
                    });

                Ok(())
            },
        )
    }

    fn copy_d2h(&self, src: FSlice<'_, T::B128, Self>, dst: &mut [T::B128]) -> Result<(), Error> {
        if src.len() != dst.len() {
            return Err(Error::InputValidation(
                "precondition: src and dst buffers must have the same length".to_string(),
            ));
        }

        // Both closures below capture `dst`, so route the mutable access through a `RefCell`.
        let dst = RefCell::new(dst);
        unpack_if_possible(
            src.as_slice(),
            |scalars| {
                dst.borrow_mut().copy_from_slice(&scalars[..src.len()]);
                Ok(())
            },
            |packed: &[P]| {
                let mut dst = dst.borrow_mut();

                // Unpack the scalars that fill whole packed elements in parallel.
                dst.par_chunks_exact_mut(P::WIDTH)
                    .zip(packed.par_iter())
                    .for_each(|(output, input)| {
                        for (input, output) in input.iter().zip(output.iter_mut()) {
                            *output = input;
                        }
                    });

                // Copy the remaining scalars that do not fill a whole packed element.
                let tail_start = (dst.len() / P::WIDTH) * P::WIDTH;
                for (input, output) in PackedField::iter_slice(packed)
                    .skip(tail_start)
                    .zip(dst[tail_start..].iter_mut())
                {
                    *output = input;
                }
                Ok(())
            },
        )
    }

    fn copy_d2d(
        &self,
        src: FSlice<'_, T::B128, Self>,
        dst: &mut FSliceMut<'_, T::B128, Self>,
    ) -> Result<(), Error> {
        if src.len() != dst.len() {
            return Err(Error::InputValidation(
                "precondition: src and dst buffers must have the same length".to_string(),
            ));
        }

        dst.as_slice_mut().copy_from_slice(src.as_slice());

        Ok(())
    }

    fn compile_expr(
        &self,
        expr: &ArithCircuit<T::B128>,
    ) -> Result<<Self::Exec<'_> as ComputeLayerExecutor<T::B128>>::ExprEval, Error> {
        let expr = ArithCircuitPoly::new(expr.clone());
        Ok(expr)
    }

    fn execute<'a, 'b>(
        &'b self,
        f: impl FnOnce(&mut Self::Exec<'a>) -> Result<Vec<T::B128>, Error>,
    ) -> Result<Vec<T::B128>, Error>
    where
        'b: 'a,
    {
        f(&mut FastCpuExecutor::<'a, T, P>::new(&self.kernel_buffers))
    }

    fn fill(
        &self,
        slice: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
        value: T::B128,
    ) -> Result<(), Error> {
        match slice {
            PackedMemorySliceMut::Slice(items) => {
                items.fill(P::broadcast(value));
            }
            PackedMemorySliceMut::SingleElement { owned, .. } => {
                owned.fill(value);
            }
        };
        Ok(())
    }
}

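/// Executor created by [`FastCpuLayer::execute`]. Cloning is cheap: all clones share the same
/// set of thread-local kernel buffers.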
pub struct FastCpuExecutor<'a, T: TowerFamily, P: PackedTop<T>> {
    kernel_buffers: &'a ThreadLocal<RefCell<Vec<P>>>,
    _phantom_data: PhantomData<T>,
}

// Implemented manually so that `Clone` is not conditioned on `T: Clone`.
impl<'a, T: TowerFamily, P: PackedTop<T>> Clone for FastCpuExecutor<'a, T, P> {
    fn clone(&self) -> Self {
        Self {
            kernel_buffers: self.kernel_buffers,
            _phantom_data: PhantomData,
        }
    }
}

impl<'a, T: TowerFamily, P: PackedTop<T>> FastCpuExecutor<'a, T, P> {
    pub fn new(kernel_buffers: &'a ThreadLocal<RefCell<Vec<P>>>) -> Self {
        Self {
            kernel_buffers,
            _phantom_data: PhantomData,
        }
    }

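    /// Splits the kernel memory mappings into `1 << log_chunks` column chunks, runs `map` over
    /// each chunk in parallel, and folds the per-chunk results with `reduce_op`. Returns
    /// `Ok(None)` when there are no results to reduce.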
    fn process_kernels_chunks<'c, R: Send>(
        &self,
        map: impl Sync
            + for<'b> Fn(
                &'b mut FastKernelExecutor<T, P>,
                usize,
                Vec<KernelBuffer<'b, T::B128, PackedMemory<P>>>,
            ) -> Result<R, Error>,
        reduce_op: impl Sync + Fn(R, R) -> R,
        mem_maps: Vec<KernelMemMap<'_, T::B128, PackedMemory<P>>>,
    ) -> Result<Option<R>, Error> {
        let log_chunks_range = KernelMemMap::log_chunks_range(&mem_maps)
            .ok_or_else(|| Error::InputValidation("no chunks range found".to_string()))?;

        // Use roughly twice as many chunks as threads to smooth over uneven per-chunk work,
        // clamped to the range supported by the memory mappings.
        let log_chunks = (get_log_max_threads() + 1)
            .min(log_chunks_range.end)
            .max(log_chunks_range.start);
        let total_alloc = count_total_local_buffer_sizes(&mem_maps, log_chunks);

        // Build a row-major `mem_maps_count x (1 << log_chunks)` matrix of memory chunks,
        // initializing the vector's spare capacity in parallel.
        let mem_maps_count = mem_maps.len();
        let mut memory_chunks = Vec::with_capacity(mem_maps_count << log_chunks);
        let uninit = memory_chunks.spare_capacity_mut();
        uninit
            .par_chunks_exact_mut(1 << log_chunks)
            .zip(mem_maps)
            .for_each(|(chunk, mem_map)| {
                for (input, out) in chunk.iter_mut().zip(mem_map.chunks(log_chunks)) {
                    input.write(out);
                }
            });
        // SAFETY: all `mem_maps_count << log_chunks` elements were initialized above.
        unsafe {
            memory_chunks.set_len(mem_maps_count << log_chunks);
        }
        let memory_chunks_view = StridedArray2DViewMut::without_stride(
            &mut memory_chunks,
            mem_maps_count,
            1 << log_chunks,
        )
        .expect("dimensions must be correct");

        memory_chunks_view
            .into_par_strides(1)
            .map(|mut chunk| {
                // Reuse (or grow) this thread's scratch buffer for kernel-local allocations.
                let buffer = self
                    .kernel_buffers
                    .get_or(|| RefCell::new(zeroed_vec(total_alloc)));
                let mut buffer = buffer.borrow_mut();
                if buffer.len() < total_alloc {
                    buffer.resize(total_alloc, P::zero());
                }

                let buffer = PackedMemorySliceMut::new_slice(&mut buffer);
                let allocator = BumpAllocator::<T::B128, PackedMemory<P>>::new(buffer);

                let kernel_data = chunk
                    .iter_column_mut(0)
                    .map(|mem_map| {
                        // Take the mapping out of the matrix, leaving a trivial placeholder
                        // behind.
                        match std::mem::replace(mem_map, KernelMemMap::Local { log_size: 0 }) {
                            KernelMemMap::Chunked { data, .. } => KernelBuffer::Ref(data),
                            KernelMemMap::ChunkedMut { data, .. } => KernelBuffer::Mut(data),
                            KernelMemMap::Local { log_size } => {
                                let data = allocator
                                    .alloc(1 << log_size)
                                    .expect("buffer must be large enough");

                                KernelBuffer::Mut(data)
                            }
                        }
                    })
                    .collect::<Vec<_>>();

                map(&mut FastKernelExecutor::default(), log_chunks, kernel_data)
            })
            .reduce_with(|lhs, rhs| Ok(reduce_op(lhs?, rhs?)))
            .transpose()
    }
}

impl<'a, T: TowerFamily, P: PackedTop<T>> ComputeLayerExecutor<T::B128>
    for FastCpuExecutor<'a, T, P>
{
    type KernelExec = FastKernelExecutor<T, P>;
    type DevMem = PackedMemory<P>;
    type OpValue = T::B128;
    type ExprEval = ArithCircuitPoly<T::B128>;

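    // `a_in` packs subfield scalars into `B128` elements, so its scalar count is
    // `a_in.slice.len() << (LOG_DEGREE - tower_level)`; that count must equal `b_in.len()`.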
    fn inner_product(
        &mut self,
        a_in: SubfieldSlice<'_, T::B128, Self::DevMem>,
        b_in: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
    ) -> Result<Self::OpValue, Error> {
        if a_in.slice.len() << (<T::B128 as ExtensionField<T::B1>>::LOG_DEGREE - a_in.tower_level)
            != b_in.len()
        {
            return Err(Error::InputValidation(
                "precondition: a_in and b_in must have the same number of scalars".to_string(),
            ));
        }

        fn inner_product_par_impl<FSub: Field, P: PackedExtension<FSub>>(
            a_in: &[P],
            b_in: &[P],
        ) -> P::Scalar {
            inner_product_par(b_in, PackedExtension::cast_bases(a_in))
        }

        let result = each_tower_subfield!(
            a_in.tower_level,
            T,
            inner_product_par_impl::<_, P>(a_in.slice.as_slice(), b_in.as_slice())
        );

        Ok(result)
    }

    fn tensor_expand(
        &mut self,
        log_n: usize,
        coordinates: &[T::B128],
        data: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        tensor_prod_eq_ind(log_n, data.as_slice_mut(), coordinates)
            .map_err(|_| Error::InputValidation("tensor dimensions are invalid".to_string()))
    }

    fn accumulate_kernels(
        &mut self,
        map: impl Sync
            + for<'b> Fn(
                &'b mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'b, T::B128, Self::DevMem>>,
            ) -> Result<Vec<T::B128>, Error>,
        mem_maps: Vec<KernelMemMap<'_, T::B128, Self::DevMem>>,
    ) -> Result<Vec<Self::OpValue>, Error> {
        self.process_kernels_chunks(
            map,
            |mut out1, out2| {
                // Element-wise sum; if one result is longer, keep its tail.
                let mut out2_iter = out2.into_iter();
                for (out1_i, out2_i) in zip(&mut out1, &mut out2_iter) {
                    *out1_i += out2_i;
                }
                out1.extend(out2_iter);
                out1
            },
            mem_maps,
        )
        .map(|opt| opt.unwrap_or_default())
    }

    fn map_kernels(
        &mut self,
        map: impl Sync
            + for<'b> Fn(
                &'b mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'b, T::B128, Self::DevMem>>,
            ) -> Result<(), Error>,
        mem_maps: Vec<KernelMemMap<'_, T::B128, Self::DevMem>>,
    ) -> Result<(), Error> {
        self.process_kernels_chunks(map, |_, _| {}, mem_maps)
            .map(|_| ())
    }

    fn fold_left(
        &mut self,
        mat: SubfieldSlice<'_, T::B128, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        let log_evals_size = strict_log_2(mat.len()).ok_or_else(|| {
            Error::InputValidation("the length of `mat` must be a power of 2".to_string())
        })?;
        let log_query_size = strict_log_2(vec.len()).ok_or_else(|| {
            Error::InputValidation("the length of `vec` must be a power of 2".to_string())
        })?;

        // `binius_math::fold_left` overwrites every output element, so the output may be
        // treated as uninitialized memory here.
        let out = binius_utils::mem::slice_uninit_mut(out.as_slice_mut());

        fn fold_left<FSub: Field, P: PackedExtension<FSub>>(
            mat: &[P],
            log_evals_size: usize,
            vec: &[P],
            log_query_size: usize,
            out: &mut [MaybeUninit<P>],
        ) -> Result<(), Error> {
            let mat = PackedExtension::cast_bases(mat);

            binius_math::fold_left(mat, log_evals_size, vec, log_query_size, out).map_err(|_| {
                Error::InputValidation("the input data dimensions are wrong".to_string())
            })
        }

        each_tower_subfield!(
            mat.tower_level,
            T,
            fold_left::<_, P>(
                mat.slice.as_slice(),
                log_evals_size,
                vec.as_slice(),
                log_query_size,
                out,
            )
        )
    }

    fn fold_right(
        &mut self,
        mat: SubfieldSlice<'_, T::B128, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        let log_evals_size = strict_log_2(mat.len()).ok_or_else(|| {
            Error::InputValidation("the length of `mat` must be a power of 2".to_string())
        })?;
        let log_query_size = strict_log_2(vec.len()).ok_or_else(|| {
            Error::InputValidation("the length of `vec` must be a power of 2".to_string())
        })?;

        fn fold_right<FSub: Field, P: PackedExtension<FSub>>(
            mat: &[P],
            log_evals_size: usize,
            vec: &[P],
            log_query_size: usize,
            out: &mut [P],
        ) -> Result<(), Error> {
            let mat = PackedExtension::cast_bases(mat);

            binius_math::fold_right(mat, log_evals_size, vec, log_query_size, out).map_err(|_| {
                Error::InputValidation("the input data dimensions are wrong".to_string())
            })
        }

        each_tower_subfield!(
            mat.tower_level,
            T,
            fold_right::<_, P>(
                mat.slice.as_slice(),
                log_evals_size,
                vec.as_slice(),
                log_query_size,
                out.as_slice_mut()
            )
        )
    }

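    // Folds the interleaved codeword in `data_in` with `challenges`, writing
    // `1 << (log_len - (challenges.len() - log_batch_size))` field elements to `data_out`.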
    fn fri_fold<FSub>(
        &mut self,
        ntt: &(impl AdditiveNTT<FSub> + Sync),
        log_len: usize,
        log_batch_size: usize,
        challenges: &[T::B128],
        data_in: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
        data_out: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error>
    where
        FSub: binius_field::BinaryField,
        T::B128: binius_field::ExtensionField<FSub>,
    {
        unpack_if_possible_mut(
            data_out.as_slice_mut(),
            |out| {
                fold_interleaved_allocated(
                    ntt,
                    data_in.as_slice(),
                    challenges,
                    log_len,
                    log_batch_size,
                    out,
                );
            },
            |packed| {
                // The output cannot be viewed as a scalar slice, so fold into a temporary
                // scalar buffer and repack it afterwards.
                let mut out_scalars =
                    zeroed_vec(1 << (log_len - (challenges.len() - log_batch_size)));
                fold_interleaved_allocated(
                    ntt,
                    data_in.as_slice(),
                    challenges,
                    log_len,
                    log_batch_size,
                    &mut out_scalars,
                );

                let mut iter = out_scalars.iter().copied();
                for p in packed {
                    *p = PackedField::from_scalars(&mut iter);
                }
            },
        );

        Ok(())
    }

    fn extrapolate_line(
        &mut self,
        evals_0: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
        evals_1: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
        z: T::B128,
    ) -> Result<(), Error> {
        if evals_0.len() != evals_1.len() {
            return Err(Error::InputValidation(
                "precondition: evals_0 and evals_1 must have the same length".to_string(),
            ));
        }

        // Try the byte-sliced fast path for the packed widths that support it; fall back to
        // a straightforward parallel evaluation otherwise.
        let handled = try_extrapolate_line_byte_sliced::<_, PackedBinaryField1x128b>(
            evals_0.as_slice_mut(),
            evals_1.as_slice(),
            z,
        ) || try_extrapolate_line_byte_sliced::<_, PackedBinaryField2x128b>(
            evals_0.as_slice_mut(),
            evals_1.as_slice(),
            z,
        ) || try_extrapolate_line_byte_sliced::<_, PackedBinaryField4x128b>(
            evals_0.as_slice_mut(),
            evals_1.as_slice(),
            z,
        );

        if !handled {
            let z = P::broadcast(z);
            evals_0
                .as_slice_mut()
                .par_iter_mut()
                .zip(evals_1.as_slice().par_iter())
                .for_each(|(x0, x1)| *x0 += (*x1 - *x0) * z);
        }

        Ok(())
    }

    fn compute_composite(
        &mut self,
        inputs: &SlicesBatch<<Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>>,
        output: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
        composition: &ArithCircuitPoly<T::B128>,
    ) -> Result<(), Error> {
        if inputs.row_len() != output.len() {
            return Err(Error::InputValidation("inputs and output must be the same length".into()));
        }

        if CompositionPoly::<P>::n_vars(composition) != inputs.n_rows() {
            return Err(Error::InputValidation(
                "the number of variables in `composition` must match the number of input rows"
                    .into(),
            ));
        }

        let rows = inputs
            .iter()
            .map(|slice| slice.as_slice())
            .collect::<Vec<_>>();

        // Split the work into roughly twice as many chunks as there are threads, counted in
        // packed elements.
        let log_chunks = get_log_max_threads() + 1;
        let packed_row_len = checked_int_div(inputs.row_len(), P::WIDTH);
        let chunk_size = (packed_row_len >> log_chunks).max(1);

        // SAFETY: every row in `rows` has exactly `packed_row_len` packed elements.
        let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, packed_row_len) };

        output
            .as_slice_mut()
            .par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_idx, output_chunk)| {
                let offset = chunk_idx * chunk_size;
                // Use the chunk's actual length so a short final chunk stays in bounds.
                let rows = rows_batch.columns_subrange(offset..offset + output_chunk.len());

                composition
                    .batch_evaluate(&rows, output_chunk)
                    .expect("dimensions are correct");
            });

        Ok(())
    }

    fn join<Out1: Send, Out2: Send>(
        &mut self,
        op1: impl Send + FnOnce(&mut Self) -> Result<Out1, Error>,
        op2: impl Send + FnOnce(&mut Self) -> Result<Out2, Error>,
    ) -> Result<(Out1, Out2), Error> {
        let (out1, out2) =
            binius_maybe_rayon::join(|| op1(&mut self.clone()), || op2(&mut self.clone()));

        Ok((out1?, out2?))
    }

    fn map<Out: Send, I: ExactSizeIterator<Item: Send> + Send>(
        &mut self,
        iter: I,
        map: impl Sync + Fn(&mut Self, I::Item) -> Result<Out, Error>,
    ) -> Result<Vec<Out>, Error> {
        // `par_bridge` does not preserve the iteration order, so tag each item with its
        // index and restore the order afterwards.
        let mut result = iter
            .enumerate()
            .par_bridge()
            .map(|(index, item)| (index, map(&mut self.clone(), item)))
            .collect::<Vec<_>>();
        result.sort_unstable_by_key(|(index, _)| *index);

        result.into_iter().map(|(_, out)| out).collect()
    }

    fn pairwise_product_reduce(
        &mut self,
        _input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
        _round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
    ) -> Result<(), Error> {
        todo!()
    }
}

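/// Dispatches to the byte-sliced fast path [`extrapolate_line_byte_sliced`] when `P1` is
/// exactly `P2`, as witnessed by a `TypeId` check. Returns `true` if the fast path ran.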
#[inline(always)]
fn try_extrapolate_line_byte_sliced<P1, P2>(
    evals_0: &mut [P1],
    evals_1: &[P1],
    z: P1::Scalar,
) -> bool
where
    P1: PackedField,
    P2: PackedField<Scalar = BinaryField128b> + WithUnderlier,
    P2::Underlier: UnderlierWithBitOps
        + PackScalar<BinaryField128b, Packed = P2>
        + PackScalar<AESTowerField128b>
        + PackScalar<BinaryField8b>
        + PackScalar<AESTowerField8b>
        + From<u8>
        + Pod,
    u8: NumCast<P2::Underlier>,
    ByteSlicedUnderlier<P2::Underlier, 16>: PackScalar<AESTowerField128b, Packed: Pod>,
    PackedType<P2::Underlier, BinaryField8b>:
        PackedTransformationFactory<PackedType<P2::Underlier, AESTowerField8b>>,
    PackedType<P2::Underlier, AESTowerField8b>:
        PackedTransformationFactory<PackedType<P2::Underlier, BinaryField8b>>,
{
    if TypeId::of::<P1>() == TypeId::of::<P2>() {
        // SAFETY: `P1` and `P2` are the same type, as witnessed by the `TypeId` check above,
        // so these transmutes only rename the type.
        extrapolate_line_byte_sliced::<P2::Underlier>(
            unsafe { transmute::<&mut [P1], &mut [P2]>(evals_0) },
            unsafe { transmute::<&[P1], &[P2]>(evals_1) },
            *unsafe { transmute::<&P1::Scalar, &BinaryField128b>(&z) },
        );

        true
    } else {
        false
    }
}

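/// Evaluates the line `x0 + (x1 - x0) * z` over 128-bit binary field elements by converting
/// them to the AES basis, transposing into byte-sliced form (where multiplication is cheap on
/// wide SIMD underliers), extrapolating there, and converting back.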
fn extrapolate_line_byte_sliced<Underlier>(
    evals_0: &mut [PackedType<Underlier, BinaryField128b>],
    evals_1: &[PackedType<Underlier, BinaryField128b>],
    z: BinaryField128b,
) where
    Underlier: UnderlierWithBitOps
        + PackScalar<BinaryField128b>
        + PackScalar<AESTowerField128b>
        + PackScalar<BinaryField8b>
        + PackScalar<AESTowerField8b>
        + From<u8>
        + Pod,
    u8: NumCast<Underlier>,
    ByteSlicedUnderlier<Underlier, 16>: PackScalar<AESTowerField128b, Packed: Pod>,
    PackedType<Underlier, BinaryField8b>:
        PackedTransformationFactory<PackedType<Underlier, AESTowerField8b>>,
    PackedType<Underlier, AESTowerField8b>:
        PackedTransformationFactory<PackedType<Underlier, BinaryField8b>>,
{
    let fwd_transform = make_binary_to_aes_packed_transformer::<
        PackedType<Underlier, BinaryField128b>,
        PackedType<Underlier, AESTowerField128b>,
    >();
    let inv_transform = make_aes_to_binary_packed_transformer::<
        PackedType<Underlier, AESTowerField128b>,
        PackedType<Underlier, BinaryField128b>,
    >();

    // Process 16 packed elements at a time: one full byte-sliced 128-bit element.
    const BYTES_COUNT: usize = 16;
    let byte_sliced_z =
        PackedType::<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>::broadcast(z.into());
    evals_0
        .par_chunks_exact_mut(BYTES_COUNT)
        .zip(evals_1.par_chunks_exact(BYTES_COUNT))
        .for_each(|(x0, x1)| {
            // Convert `x0` to the AES basis and transpose it into byte-sliced form.
            let mut x0_aes =
                [MaybeUninit::<PackedType<Underlier, AESTowerField128b>>::uninit(); BYTES_COUNT];
            for (x0_aes, x0) in x0_aes.iter_mut().zip(x0.iter()) {
                _ = *x0_aes.write(fwd_transform.transform(x0));
            }
            // SAFETY: all `BYTES_COUNT` elements were initialized above, and the packed field
            // is a transparent wrapper over its underlier.
            let x0_aes = unsafe {
                transmute::<
                    &mut [MaybeUninit<PackedType<Underlier, AESTowerField128b>>; BYTES_COUNT],
                    &mut [Underlier; BYTES_COUNT],
                >(&mut x0_aes)
            };
            Underlier::transpose_bytes_to_byte_sliced::<TowerLevel16>(x0_aes);

            // Same conversion for `x1`.
            let mut x1_aes =
                [MaybeUninit::<PackedType<Underlier, AESTowerField128b>>::uninit(); BYTES_COUNT];
            for (x1_aes, x1) in x1_aes.iter_mut().zip(x1.iter()) {
                _ = *x1_aes.write(fwd_transform.transform(x1));
            }
            // SAFETY: same as for `x0_aes`.
            let x1_aes = unsafe {
                transmute::<
                    &mut [MaybeUninit<PackedType<Underlier, AESTowerField128b>>; BYTES_COUNT],
                    &mut [Underlier; BYTES_COUNT],
                >(&mut x1_aes)
            };
            Underlier::transpose_bytes_to_byte_sliced::<TowerLevel16>(x1_aes);

            // Evaluate the line in byte-sliced form: x0 += (x1 - x0) * z.
            {
                let x0_bytes_sliced = bytemuck::must_cast_mut::<
                    _,
                    PackedType<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>,
                >(x0_aes);
                let x1_bytes_sliced = bytemuck::must_cast_ref::<
                    _,
                    PackedType<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>,
                >(x1_aes);

                *x0_bytes_sliced += (*x1_bytes_sliced - *x0_bytes_sliced) * byte_sliced_z;
            }

            // Transpose back and return to the binary basis.
            Underlier::transpose_bytes_from_byte_sliced::<TowerLevel16>(x0_aes);
            for (x0, x0_aes) in x0.iter_mut().zip(x0_aes.iter()) {
                *x0 = inv_transform.transform(
                    PackedType::<Underlier, AESTowerField128b>::from_underlier_ref(x0_aes),
                );
            }
        });

    // Handle the tail that does not fill a whole byte-sliced chunk using plain packed
    // arithmetic.
    let packed_z = PackedType::<Underlier, BinaryField128b>::broadcast(z);
    for (x0, x1) in evals_0
        .chunks_exact_mut(BYTES_COUNT)
        .into_remainder()
        .iter_mut()
        .zip(evals_1.chunks_exact(BYTES_COUNT).remainder())
    {
        *x0 += (*x1 - *x0) * packed_z;
    }
}

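/// Kernel executor used by [`FastCpuExecutor`] to evaluate composition polynomials over
/// packed slices of device memory.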
#[derive(Debug)]
pub struct FastKernelExecutor<T, P>(PhantomData<(T, P)>);

impl<T, P> Default for FastKernelExecutor<T, P> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

impl<T: TowerFamily, P: PackedTop<T>> KernelExecutor<T::B128> for FastKernelExecutor<T, P> {
    type Mem = PackedMemory<P>;
    type Value = T::B128;
    type ExprEval = ArithCircuitPoly<T::B128>;

    #[inline(always)]
    fn decl_value(&mut self, init: T::B128) -> Result<Self::Value, Error> {
        Ok(init)
    }

    fn sum_composition_evals(
        &mut self,
        inputs: &SlicesBatch<<Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>>,
        composition: &Self::ExprEval,
        batch_coeff: T::B128,
        accumulator: &mut Self::Value,
    ) -> Result<(), Error> {
        // Evaluate the composition in batches of this many packed elements so the
        // intermediate results stay in a fixed-size stack buffer.
        const BATCH_SIZE: usize = 64;

        let rows = inputs
            .iter()
            .map(|slice| slice.as_slice())
            .collect::<Vec<_>>();
        if inputs.row_len() >= P::WIDTH {
            let packed_row_len = checked_int_div(inputs.row_len(), P::WIDTH);

            // SAFETY: every row in `rows` has exactly `packed_row_len` packed elements.
            let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, packed_row_len) };
            let mut result = P::zero();
            let mut output = [P::zero(); BATCH_SIZE];
            for offset in (0..packed_row_len).step_by(BATCH_SIZE) {
                let batch_size = packed_row_len.saturating_sub(offset).min(BATCH_SIZE);
                let rows = rows_batch.columns_subrange(offset..offset + batch_size);
                composition
                    .batch_evaluate(&rows, &mut output[..batch_size])
                    .expect("dimensions are correct");

                result += output[..batch_size].iter().copied().sum::<P>();
            }

            *accumulator += batch_coeff * result.into_iter().sum::<T::B128>();
        } else {
            // Fewer scalars than a single packed element: evaluate one packed element and
            // sum only the first `row_len` scalars.
            let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, 1) };

            let mut output = P::zero();
            composition
                .batch_evaluate(&rows_batch, slice::from_mut(&mut output))
                .expect("dimensions are correct");

            *accumulator +=
                batch_coeff * output.into_iter().take(inputs.row_len()).sum::<T::B128>();
        }

        Ok(())
    }

    fn add(
        &mut self,
        log_len: usize,
        src1: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
        src2: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
        dst: &mut <Self::Mem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        if src1.len() != 1 << log_len {
            return Err(Error::InputValidation(
                "src1 length must be equal to 2^log_len".to_string(),
            ));
        }
        if src2.len() != 1 << log_len {
            return Err(Error::InputValidation(
                "src2 length must be equal to 2^log_len".to_string(),
            ));
        }
        if dst.len() != 1 << log_len {
            return Err(Error::InputValidation(
                "dst length must be equal to 2^log_len".to_string(),
            ));
        }

        for (dst_i, &src1_i, &src2_i) in
            izip!(dst.as_slice_mut().iter_mut(), src1.as_slice(), src2.as_slice())
        {
            *dst_i = src1_i + src2_i;
        }

        Ok(())
    }

    fn add_assign(
        &mut self,
        log_len: usize,
        src: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
        dst: &mut <Self::Mem as ComputeMemory<T::B128>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        if src.len() != 1 << log_len {
            return Err(Error::InputValidation(
                "src length must be equal to 2^log_len".to_string(),
            ));
        }
        if dst.len() != 1 << log_len {
            return Err(Error::InputValidation(
                "dst length must be equal to 2^log_len".to_string(),
            ));
        }

        for (dst_i, &src_i) in zip(dst.as_slice_mut().iter_mut(), src.as_slice()) {
            *dst_i += src_i;
        }

        Ok(())
    }
}

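/// Owns a [`FastCpuLayer`] together with its host and device memory pools.
///
/// A minimal usage sketch (not compiled as a doctest), assuming the canonical tower with
/// `PackedBinaryField2x128b` as the top-level packed field:
///
/// ```ignore
/// use binius_field::{PackedBinaryField2x128b, tower::CanonicalTowerFamily};
///
/// // Pools sized in B128 elements: 2^10 host scalars, 2^12 device scalars.
/// let mut holder =
///     FastCpuLayerHolder::<CanonicalTowerFamily, PackedBinaryField2x128b>::new(1 << 10, 1 << 12);
/// let compute_data = holder.to_data();
/// ```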
pub struct FastCpuLayerHolder<T: TowerFamily, P: PackedTop<T>> {
    layer: FastCpuLayer<T, P>,
    host_mem: Vec<T::B128>,
    dev_mem: Vec<P>,
}

impl<T: TowerFamily, P: PackedTop<T>> FastCpuLayerHolder<T, P> {
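    /// Creates a holder with `host_mem_size` host scalars and `dev_mem_size` device scalars,
    /// both counted in `B128` elements; the device pool is rounded down to whole packed
    /// elements, with a minimum of one.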
    pub fn new(host_mem_size: usize, dev_mem_size: usize) -> Self {
        let layer = FastCpuLayer::default();
        let host_mem = vec![T::B128::zero(); host_mem_size];
        let dev_mem = vec![P::zero(); (dev_mem_size >> P::LOG_WIDTH).max(1)];

        Self {
            layer,
            host_mem,
            dev_mem,
        }
    }
}

impl<T: TowerFamily, P: PackedTop<T>> ComputeHolder<T::B128, FastCpuLayer<T, P>>
    for FastCpuLayerHolder<T, P>
{
    type HostComputeAllocator<'a> = HostBumpAllocator<'a, T::B128>;
    type DeviceComputeAllocator<'a> =
        BumpAllocator<'a, T::B128, <FastCpuLayer<T, P> as ComputeLayer<T::B128>>::DevMem>;

    fn to_data<'a, 'b>(
        &'a mut self,
    ) -> ComputeData<
        'a,
        T::B128,
        FastCpuLayer<T, P>,
        Self::HostComputeAllocator<'b>,
        Self::DeviceComputeAllocator<'b>,
    >
    where
        'a: 'b,
    {
        ComputeData::new(
            &self.layer,
            BumpAllocator::new(self.host_mem.as_mut_slice()),
            BumpAllocator::new(PackedMemorySliceMut::new_slice(&mut self.dev_mem)),
        )
    }
}