binius_fast_compute/layer.rs

// Copyright 2025 Irreducible Inc.

use std::{
	any::TypeId,
	cell::RefCell,
	iter::zip,
	marker::PhantomData,
	mem::{MaybeUninit, transmute},
	slice,
};

use binius_compute::{
	ComputeData, ComputeHolder, ComputeLayerExecutor, KernelExecutor,
	alloc::{BumpAllocator, ComputeAllocator, HostBumpAllocator},
	cpu::layer::count_total_local_buffer_sizes,
	each_generic_tower_subfield as each_tower_subfield,
	layer::{ComputeLayer, Error, FSlice, FSliceMut, KernelBuffer, KernelMemMap},
	memory::{ComputeMemory, SizedSlice, SlicesBatch, SubfieldSlice},
};
use binius_field::{
	AESTowerField8b, AESTowerField128b, BinaryField8b, BinaryField128b, ByteSlicedUnderlier,
	ExtensionField, Field, PackedBinaryField1x128b, PackedBinaryField2x128b,
	PackedBinaryField4x128b, PackedExtension, PackedField,
	as_packed_field::{PackScalar, PackedType},
	linear_transformation::{PackedTransformationFactory, Transformation},
	make_aes_to_binary_packed_transformer, make_binary_to_aes_packed_transformer,
	tower::{PackedTop, TowerFamily},
	tower_levels::TowerLevel16,
	underlier::{NumCast, UnderlierWithBitOps, WithUnderlier},
	unpack_if_possible, unpack_if_possible_mut,
	util::inner_product_par,
};
use binius_math::{ArithCircuit, CompositionPoly, RowsBatchRef, tensor_prod_eq_ind};
use binius_maybe_rayon::{
	iter::{
		IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
		ParallelIterator,
	},
	prelude::ParallelBridge,
	slice::{ParallelSlice, ParallelSliceMut},
};
use binius_ntt::{AdditiveNTT, fri::fold_interleaved_allocated};
use binius_utils::{
	checked_arithmetics::{checked_int_div, strict_log_2},
	rayon::get_log_max_threads,
	strided_array::StridedArray2DViewMut,
};
use bytemuck::{Pod, zeroed_vec};
use itertools::izip;
use thread_local::ThreadLocal;

use crate::{
	arith_circuit::ArithCircuitPoly,
	memory::{PackedMemory, PackedMemorySliceMut},
};

/// Optimized CPU implementation of the compute layer.
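///
/// A minimal usage sketch (illustrative only; `CanonicalTowerFamily` and
/// `PackedBinaryField2x128b` stand in for any supported tower family and top packed field):
///
/// ```ignore
/// let layer = FastCpuLayer::<CanonicalTowerFamily, PackedBinaryField2x128b>::default();
/// // All device-side work is issued from within `execute`, which hands out an executor.
/// let results = layer
/// 	.execute(|_exec| Ok(vec![]))
/// 	.expect("compute layer execution failed");
/// ```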
#[derive(Debug)]
pub struct FastCpuLayer<T: TowerFamily, P: PackedTop<T>> {
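	/// Per-thread scratch buffers that are reused across kernel executions to avoid
	/// reallocating local kernel memory on every call.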
	kernel_buffers: ThreadLocal<RefCell<Vec<P>>>,
	_phantom: PhantomData<(P, T)>,
}

impl<T: TowerFamily, P: PackedTop<T>> Default for FastCpuLayer<T, P> {
	fn default() -> Self {
		Self {
			kernel_buffers: ThreadLocal::with_capacity(1 << get_log_max_threads()),
			_phantom: PhantomData,
		}
	}
}

impl<T: TowerFamily, P: PackedTop<T>> ComputeLayer<T::B128> for FastCpuLayer<T, P> {
	type Exec<'b> = FastCpuExecutor<'b, T, P>;
	type DevMem = PackedMemory<P>;

	fn copy_h2d(
		&self,
		src: &[T::B128],
		dst: &mut FSliceMut<'_, T::B128, Self>,
	) -> Result<(), Error> {
		if src.len() != dst.len() {
			return Err(Error::InputValidation(
				"precondition: src and dst buffers must have the same length".to_string(),
			));
		}

		unpack_if_possible_mut(
			dst.as_slice_mut(),
			|scalars| {
				scalars[..src.len()].copy_from_slice(src);
				Ok(())
			},
			|packed| {
				src.par_chunks_exact(P::WIDTH)
					.zip(packed.par_iter_mut())
					.for_each(|(input, output)| {
						*output = PackedField::from_scalars(input.iter().copied());
					});

				Ok(())
			},
		)
	}

	fn copy_d2h(&self, src: FSlice<'_, T::B128, Self>, dst: &mut [T::B128]) -> Result<(), Error> {
		if src.len() != dst.len() {
			return Err(Error::InputValidation(
				"precondition: src and dst buffers must have the same length".to_string(),
			));
		}

		let dst = RefCell::new(dst);
		unpack_if_possible(
			src.as_slice(),
			|scalars| {
				dst.borrow_mut().copy_from_slice(&scalars[..src.len()]);
				Ok(())
			},
			|packed: &[P]| {
				(*dst.borrow_mut())
					.par_chunks_exact_mut(P::WIDTH)
					.zip(packed.par_iter())
					.for_each(|(output, input)| {
						for (input, output) in input.iter().zip(output.iter_mut()) {
							*output = input;
						}
					});

				// Copy the tail that does not fill a whole packed element; the full-width
				// prefix was already handled by the parallel loop above.
				let bulk_len = (dst.borrow().len() / P::WIDTH) * P::WIDTH;
				for (input, output) in PackedField::iter_slice(packed)
					.zip(dst.borrow_mut().iter_mut())
					.skip(bulk_len)
				{
					*output = input;
				}
				Ok(())
			},
		)
	}

	fn copy_d2d(
		&self,
		src: FSlice<'_, T::B128, Self>,
		dst: &mut FSliceMut<'_, T::B128, Self>,
	) -> Result<(), Error> {
		if src.len() != dst.len() {
			return Err(Error::InputValidation(
				"precondition: src and dst buffers must have the same length".to_string(),
			));
		}

		dst.as_slice_mut().copy_from_slice(src.as_slice());

		Ok(())
	}

	fn compile_expr(
		&self,
		expr: &ArithCircuit<T::B128>,
	) -> Result<<Self::Exec<'_> as ComputeLayerExecutor<T::B128>>::ExprEval, Error> {
		let expr = ArithCircuitPoly::new(expr.clone());
		Ok(expr)
	}

	fn execute<'a, 'b>(
		&'b self,
		f: impl FnOnce(&mut Self::Exec<'a>) -> Result<Vec<T::B128>, Error>,
	) -> Result<Vec<T::B128>, Error>
	where
		'b: 'a,
	{
		f(&mut FastCpuExecutor::<'a, T, P>::new(&self.kernel_buffers))
	}

	fn fill(
		&self,
		slice: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
		value: T::B128,
	) -> Result<(), Error> {
		match slice {
			PackedMemorySliceMut::Slice(items) => {
				items.fill(P::broadcast(value));
			}
			PackedMemorySliceMut::SingleElement { owned, .. } => {
				owned.fill(value);
			}
		};
		Ok(())
	}
}

pub struct FastCpuExecutor<'a, T: TowerFamily, P: PackedTop<T>> {
	kernel_buffers: &'a ThreadLocal<RefCell<Vec<P>>>,
	_phantom_data: PhantomData<T>,
}

impl<'a, T: TowerFamily, P: PackedTop<T>> Clone for FastCpuExecutor<'a, T, P> {
	fn clone(&self) -> Self {
		Self {
			kernel_buffers: self.kernel_buffers,
			_phantom_data: PhantomData,
		}
	}
}

impl<'a, T: TowerFamily, P: PackedTop<T>> FastCpuExecutor<'a, T, P> {
	pub fn new(kernel_buffers: &'a ThreadLocal<RefCell<Vec<P>>>) -> Self {
		Self {
			kernel_buffers,
			_phantom_data: PhantomData,
		}
	}

	fn process_kernels_chunks<'c, R: Send>(
		&self,
		map: impl Sync
		+ for<'b> Fn(
			&'b mut FastKernelExecutor<T, P>,
			usize,
			Vec<KernelBuffer<'b, T::B128, PackedMemory<P>>>,
		) -> Result<R, Error>,
		reduce_op: impl Sync + Fn(R, R) -> R,
		mem_maps: Vec<KernelMemMap<'_, T::B128, PackedMemory<P>>>,
	) -> Result<Option<R>, Error> {
		let log_chunks_range = KernelMemMap::log_chunks_range(&mem_maps)
			.ok_or_else(|| Error::InputValidation("no chunks range found".to_string()))?;

		// Choose the number of chunks based on the range and the number of threads available.
		let log_chunks = (get_log_max_threads() + 1)
			.min(log_chunks_range.end)
			.max(log_chunks_range.start);
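		// E.g. with 8 worker threads (`get_log_max_threads()` = 3) this targets 2^4 = 16 chunks,
		// clamped into the valid `log_chunks_range`.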
		let total_alloc = count_total_local_buffer_sizes(&mem_maps, log_chunks);

		// Initialize the kernel memory for each chunk.
		let mem_maps_count = mem_maps.len();
		// We store the chunks in the layout [mem_map_0_chunk_0, mem_map_0_chunk_1, ...,
		// mem_map_0_chunk_N, mem_map_1_chunk_0, ...]. This allows us to initialize
		// `memory_chunks` in parallel, which is faster when `mem_maps_count << log_chunks` is
		// large. As a downside, we have to use `StridedArray2DViewMut` to access the memory for
		// different chunks later.
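		// For example, with `mem_maps_count = 2` and `log_chunks = 2` the layout is
		// [m0_c0, m0_c1, m0_c2, m0_c3, m1_c0, m1_c1, m1_c2, m1_c3].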
		let mut memory_chunks = Vec::with_capacity(mem_maps_count << log_chunks);
		let uninit = memory_chunks.spare_capacity_mut();
		uninit
			.par_chunks_exact_mut(1 << log_chunks)
			.zip(mem_maps)
			.for_each(|(chunk, mem_map)| {
				for (input, out) in chunk.iter_mut().zip(mem_map.chunks(log_chunks)) {
					input.write(out);
				}
			});
		// Safety:
		// - `memory_chunks` was allocated with capacity for `mem_maps_count << log_chunks`
		//   elements.
		// - Each of those elements is initialized by the `for_each` loop above.
		unsafe {
			memory_chunks.set_len(mem_maps_count << log_chunks);
		}
		let memory_chunks_view = StridedArray2DViewMut::without_stride(
			&mut memory_chunks,
			mem_maps_count,
			1 << log_chunks,
		)
		.expect("dimensions must be correct");

		memory_chunks_view
			.into_par_strides(1)
			.map(|mut chunk| {
				let buffer = self
					.kernel_buffers
					.get_or(|| RefCell::new(zeroed_vec(total_alloc)));
				let mut buffer = buffer.borrow_mut();
				if buffer.len() < total_alloc {
					buffer.resize(total_alloc, P::zero());
				}

				let buffer = PackedMemorySliceMut::new_slice(&mut buffer);
				let allocator = BumpAllocator::<T::B128, PackedMemory<P>>::new(buffer);

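				// Build the per-kernel buffers: chunked mappings borrow the caller's memory,
				// while `Local` mappings are carved out of the thread-local scratch allocator.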
				let kernel_data = chunk
					.iter_column_mut(0)
					.map(|mem_map| {
						match std::mem::replace(mem_map, KernelMemMap::Local { log_size: 0 }) {
							KernelMemMap::Chunked { data, .. } => KernelBuffer::Ref(data),
							KernelMemMap::ChunkedMut { data, .. } => KernelBuffer::Mut(data),
							KernelMemMap::Local { log_size } => {
								let data = allocator
									.alloc(1 << log_size)
									.expect("buffer must be large enough");

								KernelBuffer::Mut(data)
							}
						}
					})
					.collect::<Vec<_>>();

				map(&mut FastKernelExecutor::default(), log_chunks, kernel_data)
			})
			.reduce_with(|lhs, rhs| Ok(reduce_op(lhs?, rhs?)))
			.transpose()
	}
}

impl<'a, T: TowerFamily, P: PackedTop<T>> ComputeLayerExecutor<T::B128>
	for FastCpuExecutor<'a, T, P>
{
	type KernelExec = FastKernelExecutor<T, P>;
	type DevMem = PackedMemory<P>;
	type OpValue = T::B128;
	type ExprEval = ArithCircuitPoly<T::B128>;

	fn inner_product(
		&mut self,
		a_in: SubfieldSlice<'_, T::B128, Self::DevMem>,
		b_in: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
	) -> Result<Self::OpValue, Error> {
		if a_in.slice.len() << (<T::B128 as ExtensionField<T::B1>>::LOG_DEGREE - a_in.tower_level)
			!= b_in.len()
		{
			return Err(Error::InputValidation(
				"precondition: a_in and b_in must have the same length".to_string(),
			));
		}

		fn inner_product_par_impl<FSub: Field, P: PackedExtension<FSub>>(
			a_in: &[P],
			b_in: &[P],
		) -> P::Scalar {
			inner_product_par(b_in, PackedExtension::cast_bases(a_in))
		}

		let result = each_tower_subfield!(
			a_in.tower_level,
			T,
			inner_product_par_impl::<_, P>(a_in.slice.as_slice(), b_in.as_slice())
		);

		Ok(result)
	}

	fn tensor_expand(
		&mut self,
		log_n: usize,
		coordinates: &[T::B128],
		data: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error> {
		tensor_prod_eq_ind(log_n, data.as_slice_mut(), coordinates)
			.map_err(|_| Error::InputValidation("tensor dimensions are invalid".to_string()))
	}

	fn accumulate_kernels(
		&mut self,
		map: impl Sync
		+ for<'b> Fn(
			&'b mut Self::KernelExec,
			usize,
			Vec<KernelBuffer<'b, T::B128, Self::DevMem>>,
		) -> Result<Vec<T::B128>, Error>,
		mem_maps: Vec<KernelMemMap<'_, T::B128, Self::DevMem>>,
	) -> Result<Vec<Self::OpValue>, Error> {
		self.process_kernels_chunks(
			map,
			|mut out1, out2| {
				let mut out2_iter = out2.into_iter();
				for (out1_i, out2_i) in std::iter::zip(&mut out1, &mut out2_iter) {
					*out1_i += out2_i;
				}
				out1.extend(out2_iter);
				out1
			},
			mem_maps,
		)
		.map(|opt| opt.unwrap_or_default())
	}

	fn map_kernels(
		&mut self,
		map: impl Sync
		+ for<'b> Fn(
			&'b mut Self::KernelExec,
			usize,
			Vec<KernelBuffer<'b, T::B128, Self::DevMem>>,
		) -> Result<(), Error>,
		mem_maps: Vec<KernelMemMap<'_, T::B128, Self::DevMem>>,
	) -> Result<(), Error> {
		self.process_kernels_chunks(map, |_, _| {}, mem_maps)
			.map(|_| ())
	}

	fn fold_left(
		&mut self,
		mat: SubfieldSlice<'_, T::B128, Self::DevMem>,
		vec: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
		out: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error> {
		let log_evals_size = strict_log_2(mat.len()).ok_or_else(|| {
			Error::InputValidation("the length of `mat` must be a power of 2".to_string())
		})?;
		let log_query_size = strict_log_2(vec.len()).ok_or_else(|| {
			Error::InputValidation("the length of `vec` must be a power of 2".to_string())
		})?;

		let out = binius_utils::mem::slice_uninit_mut(out.as_slice_mut());

		fn fold_left<FSub: Field, P: PackedExtension<FSub>>(
			mat: &[P],
			log_evals_size: usize,
			vec: &[P],
			log_query_size: usize,
			out: &mut [MaybeUninit<P>],
		) -> Result<(), Error> {
			let mat = PackedExtension::cast_bases(mat);

			binius_math::fold_left(mat, log_evals_size, vec, log_query_size, out).map_err(|_| {
				Error::InputValidation("the input data dimensions are wrong".to_string())
			})
		}

		each_tower_subfield!(
			mat.tower_level,
			T,
			fold_left::<_, P>(
				mat.slice.as_slice(),
				log_evals_size,
				vec.as_slice(),
				log_query_size,
				out,
			)
		)
	}

	fn fold_right(
		&mut self,
		mat: SubfieldSlice<'_, T::B128, Self::DevMem>,
		vec: <Self::DevMem as binius_compute::memory::ComputeMemory<T::B128>>::FSlice<'_>,
		out: &mut <Self::DevMem as binius_compute::memory::ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error> {
		let log_evals_size = strict_log_2(mat.len()).ok_or_else(|| {
			Error::InputValidation("the length of `mat` must be a power of 2".to_string())
		})?;
		let log_query_size = strict_log_2(vec.len()).ok_or_else(|| {
			Error::InputValidation("the length of `vec` must be a power of 2".to_string())
		})?;

		fn fold_right<FSub: Field, P: PackedExtension<FSub>>(
			mat: &[P],
			log_evals_size: usize,
			vec: &[P],
			log_query_size: usize,
			out: &mut [P],
		) -> Result<(), Error> {
			let mat = PackedExtension::cast_bases(mat);

			binius_math::fold_right(mat, log_evals_size, vec, log_query_size, out).map_err(|_| {
				Error::InputValidation("the input data dimensions are wrong".to_string())
			})
		}

		each_tower_subfield!(
			mat.tower_level,
			T,
			fold_right::<_, P>(
				mat.slice.as_slice(),
				log_evals_size,
				vec.as_slice(),
				log_query_size,
				out.as_slice_mut()
			)
		)
	}

	fn fri_fold<FSub>(
		&mut self,
		ntt: &(impl AdditiveNTT<FSub> + Sync),
		log_len: usize,
		log_batch_size: usize,
		challenges: &[T::B128],
		data_in: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
		data_out: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error>
	where
		FSub: binius_field::BinaryField,
		T::B128: binius_field::ExtensionField<FSub>,
	{
		unpack_if_possible_mut(
			data_out.as_slice_mut(),
			|out| {
				fold_interleaved_allocated(
					ntt,
					data_in.as_slice(),
					challenges,
					log_len,
					log_batch_size,
					out,
				);
			},
			|packed| {
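				// Fold into a temporary scalar buffer, then re-pack the results into `packed`
				// in place below.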
				let mut out_scalars =
					zeroed_vec(1 << (log_len - (challenges.len() - log_batch_size)));
				fold_interleaved_allocated(
					ntt,
					packed,
					challenges,
					log_len,
					log_batch_size,
					&mut out_scalars,
				);

				let mut iter = out_scalars.iter().copied();
				for p in packed {
					*p = PackedField::from_scalars(&mut iter);
				}
			},
		);

		Ok(())
	}

	fn extrapolate_line(
		&mut self,
		evals_0: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
		evals_1: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
		z: T::B128,
	) -> Result<(), Error> {
		if evals_0.len() != evals_1.len() {
			return Err(Error::InputValidation(
				"precondition: evals_0 and evals_1 must have the same length".to_string(),
			));
		}

		let handled = try_extrapolate_line_byte_sliced::<_, PackedBinaryField1x128b>(
			evals_0.as_slice_mut(),
			evals_1.as_slice(),
			z,
		) || try_extrapolate_line_byte_sliced::<_, PackedBinaryField2x128b>(
			evals_0.as_slice_mut(),
			evals_1.as_slice(),
			z,
		) || try_extrapolate_line_byte_sliced::<_, PackedBinaryField4x128b>(
			evals_0.as_slice_mut(),
			evals_1.as_slice(),
			z,
		);

		if !handled {
			let z = P::broadcast(z);
			evals_0
				.as_slice_mut()
				.par_iter_mut()
				.zip(evals_1.as_slice().par_iter())
				.for_each(|(x0, x1)| *x0 += (*x1 - *x0) * z);
		}

		Ok(())
	}

	fn compute_composite(
		&mut self,
		inputs: &SlicesBatch<<Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>>,
		output: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
		composition: &ArithCircuitPoly<T::B128>,
	) -> Result<(), Error> {
		if inputs.row_len() != output.len() {
			return Err(Error::InputValidation(
				"inputs and output must have the same length".into(),
			));
		}

		if CompositionPoly::<P>::n_vars(composition) != inputs.n_rows() {
			return Err(Error::InputValidation(
				"composition arity does not match the number of input rows".into(),
			));
		}

		let rows = inputs
			.iter()
			.map(|slice| slice.as_slice())
			.collect::<Vec<_>>();

		let log_chunks = get_log_max_threads() + 1;

		let chunk_size = (output.len() >> log_chunks).max(1);

		let packed_row_len = checked_int_div(inputs.row_len(), P::WIDTH);

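		// Safety: all slices in `rows` have the same length, which is guaranteed by the
		// `SlicesBatch` struct.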
		let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, packed_row_len) };

		output
			.as_slice_mut()
			.par_chunks_mut(chunk_size)
			.enumerate()
			.for_each(|(chunk_idx, output_chunk)| {
				let offset = chunk_idx * chunk_size;
				let rows = rows_batch.columns_subrange(offset..offset + chunk_size);

				composition
					.batch_evaluate(&rows, output_chunk)
					.expect("dimensions are correct");
			});

		Ok(())
	}

	fn join<Out1: Send, Out2: Send>(
		&mut self,
		op1: impl Send + FnOnce(&mut Self) -> Result<Out1, Error>,
		op2: impl Send + FnOnce(&mut Self) -> Result<Out2, Error>,
	) -> Result<(Out1, Out2), Error> {
		let (out1, out2) =
			binius_maybe_rayon::join(|| op1(&mut self.clone()), || op2(&mut self.clone()));

		Ok((out1?, out2?))
	}

	fn map<Out: Send, I: ExactSizeIterator<Item: Send> + Send>(
		&mut self,
		iter: I,
		map: impl Sync + Fn(&mut Self, I::Item) -> Result<Out, Error>,
	) -> Result<Vec<Out>, Error> {
		// `par_bridge` doesn't preserve the order of the items in the iterator,
		// so we need to enumerate the items and sort them back after processing.
		let mut result = iter
			.enumerate()
			.par_bridge()
			.map(|(index, item)| (index, map(&mut self.clone(), item)))
			.collect::<Vec<_>>();
		result.sort_unstable_by_key(|(index, _)| *index);

		result.into_iter().map(|(_, out)| out).collect()
	}

	fn pairwise_product_reduce(
		&mut self,
		_input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
		_round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
	) -> Result<(), Error> {
		// TODO(CRY-490)
		todo!()
	}
}

/// When `P1` and `P2` are the same type, performs the extrapolation using the byte-sliced
/// representation of the packed field elements and returns `true`; otherwise returns `false`
/// so the caller can fall back to the generic path.
///
/// `P2` is expected to be one of the following types: `PackedBinaryField1x128b`,
/// `PackedBinaryField2x128b` or `PackedBinaryField4x128b`.
#[inline(always)]
fn try_extrapolate_line_byte_sliced<P1, P2>(
	evals_0: &mut [P1],
	evals_1: &[P1],
	z: P1::Scalar,
) -> bool
where
	P1: PackedField,
	P2: PackedField<Scalar = BinaryField128b> + WithUnderlier,
	P2::Underlier: UnderlierWithBitOps
		+ PackScalar<BinaryField128b, Packed = P2>
		+ PackScalar<AESTowerField128b>
		+ PackScalar<BinaryField8b>
		+ PackScalar<AESTowerField8b>
		+ From<u8>
		+ Pod,
	u8: NumCast<P2::Underlier>,
	ByteSlicedUnderlier<P2::Underlier, 16>: PackScalar<AESTowerField128b, Packed: Pod>,
	PackedType<P2::Underlier, BinaryField8b>:
		PackedTransformationFactory<PackedType<P2::Underlier, AESTowerField8b>>,
	PackedType<P2::Underlier, AESTowerField8b>:
		PackedTransformationFactory<PackedType<P2::Underlier, BinaryField8b>>,
{
	if TypeId::of::<P1>() == TypeId::of::<P2>() {
		// Safety: The transmute calls are safe because source and destination types are the same.
		extrapolate_line_byte_sliced::<P2::Underlier>(
			unsafe { transmute::<&mut [P1], &mut [P2]>(evals_0) },
			unsafe { transmute::<&[P1], &[P2]>(evals_1) },
			*unsafe { transmute::<&P1::Scalar, &BinaryField128b>(&z) },
		);

		true
	} else {
		false
	}
}

// Extrapolate line function that converts packed field elements to byte-sliced representation and
// back.
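// Per chunk, the pipeline is: binary basis -> AES basis (packed linear transform) -> transpose to
// the byte-sliced layout -> `x0 += (x1 - x0) * z` in the byte-sliced domain -> transpose back ->
// inverse transform to the binary basis.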
fn extrapolate_line_byte_sliced<Underlier>(
	evals_0: &mut [PackedType<Underlier, BinaryField128b>],
	evals_1: &[PackedType<Underlier, BinaryField128b>],
	z: BinaryField128b,
) where
	Underlier: UnderlierWithBitOps
		+ PackScalar<BinaryField128b>
		+ PackScalar<AESTowerField128b>
		+ PackScalar<BinaryField8b>
		+ PackScalar<AESTowerField8b>
		+ From<u8>
		+ Pod,
	u8: NumCast<Underlier>,
	ByteSlicedUnderlier<Underlier, 16>: PackScalar<AESTowerField128b, Packed: Pod>,
	PackedType<Underlier, BinaryField8b>:
		PackedTransformationFactory<PackedType<Underlier, AESTowerField8b>>,
	PackedType<Underlier, AESTowerField8b>:
		PackedTransformationFactory<PackedType<Underlier, BinaryField8b>>,
{
	let fwd_transform = make_binary_to_aes_packed_transformer::<
		PackedType<Underlier, BinaryField128b>,
		PackedType<Underlier, AESTowerField128b>,
	>();
	let inv_transform = make_aes_to_binary_packed_transformer::<
		PackedType<Underlier, AESTowerField128b>,
		PackedType<Underlier, BinaryField128b>,
	>();

	// Process the chunks that have the size of a full byte-sliced packed field.
	const BYTES_COUNT: usize = 16;
	let byte_sliced_z =
		PackedType::<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>::broadcast(z.into());
	evals_0
		.par_chunks_exact_mut(BYTES_COUNT)
		.zip(evals_1.par_chunks_exact(BYTES_COUNT))
		.for_each(|(x0, x1)| {
			// Transform x0 to byte-sliced representation
			let mut x0_aes =
				[MaybeUninit::<PackedType<Underlier, AESTowerField128b>>::uninit(); BYTES_COUNT];
			for (x0_aes, x0) in x0_aes.iter_mut().zip(x0.iter()) {
				_ = *x0_aes.write(fwd_transform.transform(x0));
			}
			// Safety: all array elements are initialized at this moment
			let x0_aes = unsafe {
				transmute::<
					&mut [MaybeUninit<PackedType<Underlier, AESTowerField128b>>; BYTES_COUNT],
					&mut [Underlier; BYTES_COUNT],
				>(&mut x0_aes)
			};
			Underlier::transpose_bytes_to_byte_sliced::<TowerLevel16>(x0_aes);

			// Transform x1 to byte-sliced representation
			let mut x1_aes =
				[MaybeUninit::<PackedType<Underlier, AESTowerField128b>>::uninit(); BYTES_COUNT];
			for (x1_aes, x1) in x1_aes.iter_mut().zip(x1.iter()) {
				_ = *x1_aes.write(fwd_transform.transform(x1));
			}
			// Safety: all array elements are initialized at this moment
			let x1_aes = unsafe {
				transmute::<
					&mut [MaybeUninit<PackedType<Underlier, AESTowerField128b>>; BYTES_COUNT],
					&mut [Underlier; BYTES_COUNT],
				>(&mut x1_aes)
			};
			Underlier::transpose_bytes_to_byte_sliced::<TowerLevel16>(x1_aes);

			// Perform the extrapolation in the byte-sliced representation
			{
				let x0_bytes_sliced = bytemuck::must_cast_mut::<
					_,
					PackedType<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>,
				>(x0_aes);
				let x1_bytes_sliced = bytemuck::must_cast_ref::<
					_,
					PackedType<ByteSlicedUnderlier<Underlier, 16>, AESTowerField128b>,
				>(x1_aes);

				*x0_bytes_sliced += (*x1_bytes_sliced - *x0_bytes_sliced) * byte_sliced_z;
			}

			// Transform x0 back to the original packed representation
			Underlier::transpose_bytes_from_byte_sliced::<TowerLevel16>(x0_aes);
			for (x0, x0_aes) in x0.iter_mut().zip(x0_aes.iter()) {
				*x0 = inv_transform.transform(
					PackedType::<Underlier, AESTowerField128b>::from_underlier_ref(x0_aes),
				);
			}
		});

	// Process the remainder
	let packed_z = PackedType::<Underlier, BinaryField128b>::broadcast(z);
	for (x0, x1) in evals_0
		.chunks_exact_mut(BYTES_COUNT)
		.into_remainder()
		.iter_mut()
		.zip(evals_1.chunks_exact(BYTES_COUNT).remainder())
	{
		*x0 += (*x1 - *x0) * packed_z;
	}
}

#[derive(Debug)]
pub struct FastKernelExecutor<T, P>(PhantomData<(T, P)>);

impl<T, P> Default for FastKernelExecutor<T, P> {
	fn default() -> Self {
		Self(PhantomData)
	}
}

impl<T: TowerFamily, P: PackedTop<T>> KernelExecutor<T::B128> for FastKernelExecutor<T, P> {
	type Mem = PackedMemory<P>;
	type Value = T::B128;
	type ExprEval = ArithCircuitPoly<T::B128>;

	#[inline(always)]
	fn decl_value(&mut self, init: T::B128) -> Result<Self::Value, Error> {
		Ok(init)
	}

	fn sum_composition_evals(
		&mut self,
		inputs: &SlicesBatch<<Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>>,
		composition: &Self::ExprEval,
		batch_coeff: T::B128,
		accumulator: &mut Self::Value,
	) -> Result<(), Error> {
		// The batch size is chosen to balance the amount of additional memory needed for each
		// operation against the per-call overhead. The current value is chosen based on intuition
		// and may be changed in the future based on performance measurements.
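		// For example, with a 512-bit packed type `P`, the `output` scratch array below takes
		// BATCH_SIZE * 64 = 4096 bytes.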
		const BATCH_SIZE: usize = 64;

		let rows = inputs
			.iter()
			.map(|slice| slice.as_slice())
			.collect::<Vec<_>>();
		if inputs.row_len() >= P::WIDTH {
			let packed_row_len = checked_int_div(inputs.row_len(), P::WIDTH);

			// Safety: `rows` is guaranteed to be valid as all slices have the same length
			// (this is guaranteed by the `SlicesBatch` struct).
			let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, packed_row_len) };
			let mut result = P::zero();
			let mut output = [P::zero(); BATCH_SIZE];
			for offset in (0..packed_row_len).step_by(BATCH_SIZE) {
				let batch_size = packed_row_len.saturating_sub(offset).min(BATCH_SIZE);
				let rows = rows_batch.columns_subrange(offset..offset + batch_size);
				composition
					.batch_evaluate(&rows, &mut output[..batch_size])
					.expect("dimensions are correct");

				result += output[..batch_size].iter().copied().sum::<P>();
			}

			*accumulator += batch_coeff * result.into_iter().sum::<T::B128>();
		} else {
			let rows_batch = unsafe { RowsBatchRef::new_unchecked(&rows, 1) };

			let mut output = P::zero();
			composition
				.batch_evaluate(&rows_batch, slice::from_mut(&mut output))
				.expect("dimensions are correct");

			*accumulator +=
				batch_coeff * output.into_iter().take(inputs.row_len()).sum::<T::B128>();
		}

		Ok(())
	}

	fn add(
		&mut self,
		log_len: usize,
		src1: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
		src2: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
		dst: &mut <Self::Mem as ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error> {
		if src1.len() != 1 << log_len {
			return Err(Error::InputValidation(
				"src1 length must be equal to 2^log_len".to_string(),
			));
		}
		if src2.len() != 1 << log_len {
			return Err(Error::InputValidation(
				"src2 length must be equal to 2^log_len".to_string(),
			));
		}
		if dst.len() != 1 << log_len {
			return Err(Error::InputValidation(
				"dst length must be equal to 2^log_len".to_string(),
			));
		}

		for (dst_i, &src1_i, &src2_i) in
			izip!(dst.as_slice_mut().iter_mut(), src1.as_slice(), src2.as_slice())
		{
			*dst_i = src1_i + src2_i;
		}

		Ok(())
	}

	fn add_assign(
		&mut self,
		log_len: usize,
		src: <Self::Mem as ComputeMemory<T::B128>>::FSlice<'_>,
		dst: &mut <Self::Mem as ComputeMemory<T::B128>>::FSliceMut<'_>,
	) -> Result<(), Error> {
		if src.len() != 1 << log_len {
			return Err(Error::InputValidation(
				"src length must be equal to 2^log_len".to_string(),
			));
		}
		if dst.len() != 1 << log_len {
			return Err(Error::InputValidation(
				"dst length must be equal to 2^log_len".to_string(),
			));
		}

		for (dst_i, &src_i) in zip(dst.as_slice_mut().iter_mut(), src.as_slice()) {
			*dst_i += src_i;
		}

		Ok(())
	}
}

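/// [`ComputeHolder`] implementation that owns a [`FastCpuLayer`] along with host and device
/// scratch memory.
///
/// A minimal sketch of how it could be wired up (the tower family and packed type are
/// placeholders; sizes are in field elements):
///
/// ```ignore
/// let mut holder =
/// 	FastCpuLayerHolder::<CanonicalTowerFamily, PackedBinaryField2x128b>::new(1 << 16, 1 << 20);
/// let compute_data = holder.to_data();
/// // `compute_data` bundles the layer with bump allocators over the two buffers.
/// ```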
pub struct FastCpuLayerHolder<T: TowerFamily, P: PackedTop<T>> {
	layer: FastCpuLayer<T, P>,
	host_mem: Vec<T::B128>,
	dev_mem: Vec<P>,
}

impl<T: TowerFamily, P: PackedTop<T>> FastCpuLayerHolder<T, P> {
	pub fn new(host_mem_size: usize, dev_mem_size: usize) -> Self {
		let layer = FastCpuLayer::default();
		let host_mem = vec![T::B128::zero(); host_mem_size];
		let dev_mem = vec![P::zero(); (dev_mem_size >> P::LOG_WIDTH).max(1)];

		Self {
			layer,
			host_mem,
			dev_mem,
		}
	}
}

impl<T: TowerFamily, P: PackedTop<T>> ComputeHolder<T::B128, FastCpuLayer<T, P>>
	for FastCpuLayerHolder<T, P>
{
	type HostComputeAllocator<'a> = HostBumpAllocator<'a, T::B128>;
	type DeviceComputeAllocator<'a> =
		BumpAllocator<'a, T::B128, <FastCpuLayer<T, P> as ComputeLayer<T::B128>>::DevMem>;

	fn to_data<'a, 'b>(
		&'a mut self,
	) -> ComputeData<
		'a,
		T::B128,
		FastCpuLayer<T, P>,
		Self::HostComputeAllocator<'b>,
		Self::DeviceComputeAllocator<'b>,
	>
	where
		'a: 'b,
	{
		ComputeData::new(
			&self.layer,
			BumpAllocator::new(self.host_mem.as_mut_slice()),
			BumpAllocator::new(PackedMemorySliceMut::new_slice(&mut self.dev_mem)),
		)
	}
}