binius_math/fold.rs

// Copyright 2024-2025 Irreducible Inc.

use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
	arch::{ArchOptimal, OptimalUnderlier},
	byte_iteration::{
		can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes,
		ByteIteratorCallback, PackedSlice,
	},
	packed::{
		get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
	},
	underlier::{UnderlierWithBitOps, WithUnderlier},
	AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField,
	Field, PackedField,
};
use binius_utils::bail;
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;

use crate::Error;

/// Execute the right fold operation.
///
/// Each group of `1 << log_query_size` consecutive scalar values is dot-producted with the
/// corresponding query elements. The result is stored in the `out` slice of packed values.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
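///
/// A minimal usage sketch (the concrete packed field type below is an illustrative
/// assumption, not a requirement of this API):
///
/// ```ignore
/// use binius_field::PackedBinaryField16x8b;
///
/// // 16 scalar evaluations folded with a 2-element query (log_query_size = 1) yield
/// // 8 scalars: out[i] = query[0] * evals[2 * i] + query[1] * evals[2 * i + 1].
/// let evals = [PackedBinaryField16x8b::default()];
/// let query = [PackedBinaryField16x8b::default()];
/// let mut out = [PackedBinaryField16x8b::default()];
/// fold_right(&evals, 4, &query, 1, &mut out)?;
/// ```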
pub fn fold_right<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	// Try to execute the optimized version for 1-bit values if possible
	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single variable multilinear queries.
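	// When query[0] + query[1] == 1, the two query values are the tensor expansion (1 - r, r)
	// of a single point r, so the dot product with (f(0||w), f(1||w)) collapses to the lerp
	// f(0||w) + r * (f(1||w) - f(0||w)) computed below.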
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
	} else {
		fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the left fold operation.
///
/// `evals` is treated as a matrix with `1 << log_query_size` rows, and each column is
/// dot-producted with the query. The result is written to the `out` slice of packed values.
/// If the function returns `Ok(())`, then `out` can be safely interpreted as initialized.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
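///
/// A minimal usage sketch of the `MaybeUninit` output handling (the packed field type below is
/// an illustrative assumption):
///
/// ```ignore
/// use std::mem::MaybeUninit;
/// use binius_field::PackedBinaryField16x8b;
///
/// let evals = [PackedBinaryField16x8b::default()];
/// let query = [PackedBinaryField16x8b::default()];
/// // 1 << (log_evals_size - log_query_size) = 8 result scalars fit into one packed element.
/// let mut out = [MaybeUninit::<PackedBinaryField16x8b>::uninit()];
/// fold_left(&evals, 4, &query, 1, &mut out)?;
/// // On `Ok(())` every element of `out` has been written.
/// let out = out.map(|v| unsafe { v.assume_init() });
/// ```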
pub fn fold_left<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single variable multilinear queries.
	// Unlike right folds, left folds are often used with packed fields which are not indexable
	// and have quite slow scalar indexing methods. For now, we specialize the practically important
	// case of single variable fold when PE == P.
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
		&& TypeId::of::<P>() == TypeId::of::<PE>();

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		// Safety: P == PE checked above.
		let out_p =
			unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

		let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
		fold_left_lerp(
			evals,
			1 << log_evals_size,
			P::Scalar::ZERO,
			log_evals_size,
			lerp_query_p,
			out_p,
		)?;
	} else {
		fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the middle fold operation.
///
/// Each group of `1 << start_index` consecutive scalar values is considered a row. Each column
/// is then dot-producted in chunks (of size `1 << log_query_size`) with the `query` slice.
/// The results are written to `out`.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
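///
/// As a scalar-level reference for the indexing (a sketch, not the actual implementation),
/// output scalar `s` is computed as:
///
/// ```ignore
/// let i = s % (1 << start_index); // column within a row
/// let j = s / (1 << start_index); // row of the reshaped `evals`
/// out[s] = (0..1 << log_query_size)
/// 	.map(|q| query[q] * evals[(j << (start_index + log_query_size)) | (q << start_index) | i])
/// 	.sum();
/// ```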
pub fn fold_middle<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	start_index: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if log_evals_size < log_query_size + start_index {
		bail!(Error::IncorrectStartIndex {
			expected: log_evals_size
		});
	}

	let lower_indices_size = 1 << start_index;
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			// Index of the Scalar at the start of the current Packed element in `out`.
			let outer_word_start = outer_index << PE::LOG_WIDTH;
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				// Index of the current Scalar in `out`.
				let outer_scalar_index = outer_word_start + inner_index;
				// Column index, within the reshaped `evals`, where we start the dot-product with the query elements.
				let inner_i = outer_scalar_index % lower_indices_size;
				// Row index.
				let inner_j = outer_scalar_index / lower_indices_size;
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							// The reshaped `evals` "matrix" is subdivided in chunks of size
							// `1 << (start_index + log_query_size)` for the dot-product.
							let offset = inner_j << (start_index + log_query_size) | inner_i;
							let eval_index = offset + (query_index << start_index);
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}
			out_val.write(res);
		});

	Ok(())
}

#[inline]
fn check_fold_arguments<P, PE, POut>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < log_query_size {
		bail!(Error::IncorrectQuerySize {
			expected: log_evals_size
		});
	}

	if P::WIDTH * evals.len() < 1 << log_evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: 1 << log_evals_size
		});
	}

	if PE::WIDTH * query.len() < 1 << log_query_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "query".into(),
			expected: 1 << log_query_size
		});
	}

	if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: 1 << (log_evals_size - log_query_size)
		});
	}

	Ok(())
}

#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if P::WIDTH * evals.len() < evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: evals_size
		});
	}

	if PE::WIDTH * out.len() * 2 < evals_size {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: evals_size.div_ceil(2)
		});
	}

	Ok(())
}

#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	non_const_prefix: usize,
	log_evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size == 0 {
		bail!(Error::IncorrectQuerySize { expected: 1 });
	}

	if non_const_prefix > 1 << log_evals_size {
		bail!(Error::IncorrectNonzeroScalarPrefix {
			expected: 1 << log_evals_size,
		});
	}

	if P::WIDTH * evals.len() < non_const_prefix {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: non_const_prefix,
		});
	}

	let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

	if PE::WIDTH * out.len() < folded_non_const_prefix {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: folded_non_const_prefix,
		});
	}

	Ok(())
}

/// Optimized version for 1-bit values with small log query size (0-2)
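///
/// The idea (sketched for `LOG_QUERY_SIZE = 1`): precompute all `2^(2^LOG_QUERY_SIZE)` possible
/// dot products of a bit pattern with the query, e.g. the table `[0, q0, q1, q0 + q1]`, so that
/// every group of `1 << LOG_QUERY_SIZE` evaluation bits is folded with a single table lookup
/// instead of field multiplications.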
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
		return false;
	}

	// Cache the table for a single evaluation
	let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
		.map(|i| {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << LOG_QUERY_SIZE {
				if i >> j & 1 == 1 {
					result += get_packed_slice(query, j);
				}
			}
			result
		})
		.collect::<Vec<_>>();

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_table: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
			let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
			let mut current_index = 0;
			for byte in iterator {
				for k in 0..values_in_byte {
					let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
					// Safety: `current_index + k` is a valid scalar index into `self.out`
					unsafe {
						set_packed_slice_unchecked(
							self.out,
							current_index + k,
							self.cached_table[index as usize],
						);
					}
				}

				current_index += values_in_byte;
			}
		}
	}

	let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
		out,
		cached_table: &cached_table,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Optimized version for 1-bit values with medium log query size (3-7)
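///
/// Here every byte of `evals` (8 one-bit evaluations) indexes a 256-entry table of precomputed
/// partial sums of 8 consecutive query elements; for `LOG_QUERY_SIZE > 3` the contributions of
/// `1 << (LOG_QUERY_SIZE - 3)` such tables are accumulated per output scalar.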
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE < 3 {
		return false;
	}

	if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
		return false;
	}

	let cached_tables =
		create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE));

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_tables: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let log_tables_count = LOG_QUERY_SIZE - 3;
			let tables_count = 1 << log_tables_count;
			let mut current_index = 0;
			let mut current_table = 0;
			let mut current_value = PE::Scalar::ZERO;
			for byte in iterator {
				current_value += self.cached_tables[(current_table << 8) + byte as usize];
				current_table += 1;

				if current_table == tables_count {
					// Safety: `current_index` is a valid scalar index into `self.out`
					unsafe {
						set_packed_slice_unchecked(self.out, current_index, current_value);
					}
					current_index += 1;
					current_table = 0;
					current_value = PE::Scalar::ZERO;
				}
			}
		}
	}

	let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
		out,
		cached_tables: &cached_tables,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Try to run the optimized version for 1-bit values.
/// Returns `true` if the optimized calculation was performed.
/// Otherwise, returns `false` and the fallback should be used.
fn fold_right_1bit_evals<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
		return false;
	}

	// We pass `log_query_size` as a generic parameter because that allows the compiler to produce
	// more efficient code in the tight loops inside the functions.
	match log_query_size {
		0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
		1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
		2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
		3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
		4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
		5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
		6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
		7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
		_ => false,
	}
}

/// Specialized implementation for a single-variable right fold using linear interpolation
/// instead of tensor expansion, resulting in a single multiplication instead of two:
///   f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w).
///
/// The same approach may be generalized to higher variable counts, with diminishing returns.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
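///
/// In scalar terms (a sketch of the semantics; `suffix_eval` stands in for the missing
/// evaluation when `evals_size` is odd):
///
/// ```ignore
/// out[i] = evals[2 * i] + lerp_query * (evals[2 * i + 1] - evals[2 * i]);
/// ```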
pub fn fold_right_lerp<P, PE>(
	evals: &[P],
	evals_size: usize,
	lerp_query: PE::Scalar,
	suffix_eval: P::Scalar,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

	let folded_evals_size = evals_size >> 1;
	out[..folded_evals_size.div_ceil(PE::WIDTH)]
		.iter_mut()
		.enumerate()
		.for_each(|(i, packed_result_eval)| {
			for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
				let index = (i << PE::LOG_WIDTH) | j;

				let (eval0, eval1) = unsafe {
					(
						get_packed_slice_unchecked(evals, index << 1),
						get_packed_slice_unchecked(evals, (index << 1) | 1),
					)
				};

				let result_eval =
					PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

				// Safety: `j` < `PE::WIDTH`
				unsafe {
					packed_result_eval.set_unchecked(j, result_eval);
				}
			}
		});

	if evals_size % 2 == 1 {
		let eval0 = get_packed_slice(evals, folded_evals_size << 1);
		set_packed_slice(
			out,
			folded_evals_size,
			PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
		);
	}

	Ok(())
}

/// Left linear interpolation (lerp, single variable) fold
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
///
/// Also note that left folds are often intended to be used with non-indexable packed fields that
/// have inefficient scalar access. Fully generic handling of all interesting cases that could
/// leverage spread multiplication requires dynamically checking the `PackedExtension` relations,
/// so for now we just handle the simplest yet important case of a single-variable left fold in a
/// packed field `P` with a lerp query of its scalar (and not a nontrivial extension field!).
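///
/// In scalar terms (a sketch; scalars at positions `>= non_const_prefix` are treated as the
/// constant `suffix_eval`):
///
/// ```ignore
/// out[i] = evals[i] + lerp_query * (evals[i + (1 << (log_evals_size - 1))] - evals[i]);
/// ```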
pub fn fold_left_lerp<P>(
	evals: &[P],
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
	out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at(packed_len);
			for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
				out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
			out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
		}

		for out in &mut out[upper_bound..] {
			out.write(P::zero());
		}
	} else if non_const_prefix > 0 {
		let only_packed = *evals.first().expect("log_evals_size > 0");
		let mut folded = P::zero();

		for i in 0..1 << (log_evals_size - 1) {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		out.first_mut().expect("log_evals_size > 0").write(folded);
	}

	Ok(())
}

/// Inplace left linear interpolation (lerp, single variable) fold
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
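///
/// Behaves like [`fold_left_lerp`] but writes the result back into `evals`, truncating it to at
/// most the folded (half-size) packed length when it spans more than one packed element.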
pub fn fold_left_lerp_inplace<P>(
	evals: &mut Vec<P>,
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at_mut(packed_len);
			for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
				*eval_0 += (*eval_1 - *eval_0) * lerp_query;
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for eval in &mut evals[pivot..upper_bound] {
			*eval += (broadcast_suffix_eval - *eval) * lerp_query;
		}

		evals.truncate(upper_bound);
	} else if non_const_prefix > 0 {
		let only_packed = evals.first_mut().expect("log_evals_size > 0");
		let mut folded = P::zero();
		let half_size = 1 << (log_evals_size - 1);

		for i in 0..half_size {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | half_size);
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		*only_packed = folded;
	}

	Ok(())
}

/// Fallback implementation for fold that can be executed for any field types and sizes.
fn fold_right_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	for (k, packed_result_eval) in out.iter_mut().enumerate() {
		for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
			let index = (k << PE::LOG_WIDTH) | j;

			let offset = index << log_query_size;

			let mut result_eval = PE::Scalar::ZERO;
			for (t, query_expansion) in PackedField::iter_slice(query)
				.take(1 << log_query_size)
				.enumerate()
			{
				result_eval += query_expansion * get_packed_slice(evals, t + offset);
			}

			// Safety: `j` < `PE::WIDTH`
			unsafe {
				packed_result_eval.set_unchecked(j, result_eval);
			}
		}
	}
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
	TypeId::of::<ArchOptimalType<F>>()
}

/// Use the optimized algorithm for 1-bit evaluations with 128-bit query values packed with the optimal underlier.
/// Returns true if the optimized algorithm was used. In this case `out` can be safely interpreted
/// as initialized.
///
/// We could potentially think of specializing for other query fields, but that would require
/// separate implementations.
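///
/// The trick: multiplying a 128-bit query element by a `BinaryField1b` scalar yields either the
/// element itself or zero, so each evaluation bit selects an all-ones or all-zeros mask from a
/// precomputed lookup table and the product becomes a bitwise AND (see `fold_left_1b_128b_impl`).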
fn fold_left_1b_128b<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
		return false;
	}

	let log_row_size = log_evals_size - log_query_size;
	if log_row_size < 3 {
		return false;
	}

	if PE::LOG_WIDTH > 3 {
		return false;
	}

	// Safety: the cast is sound because the `is_sequential_bytes::<P>()` check above guarantees
	// that `P` has a sequential byte layout.
	let evals_u8: &[u8] = unsafe {
		std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
	};

	// Try to run the optimized version for the specific 128-bit field type.
	// This is a workaround for the lack of specialization in Rust.
	#[inline]
	fn try_run_specialization<PE, F>(
		lookup_table: &[OptimalUnderlier],
		evals_u8: &[u8],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [MaybeUninit<PE>],
	) -> bool
	where
		PE: PackedField,
		F: ArchOptimal,
	{
		if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
			let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
			let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

			fold_left_1b_128b_impl(
				lookup_table,
				evals_u8,
				log_evals_size,
				query,
				log_query_size,
				out,
			);
			true
		} else {
			false
		}
	}

	let lookup_table = &*LOOKUP_TABLE;
	try_run_specialization::<_, BinaryField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, AESTowerField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, BinaryField128bPolyval>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	)
}

/// Cast slice from unknown type to the known one assuming that the types are the same.
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is safe because the assert above checks that `T` and `U` are the same type
	unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}

/// Cast slice from unknown type to the known one assuming that the types are the same.
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is safe because the assert above checks that `T` and `U` are the same type
	unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}

/// Initialize the lookup table u8 -> [U; 8 / <number of 128-bit elements in U>] where
/// each bit in `u8` corresponds to a 128-bit element in `U` filled with ones or zeros
/// depending on the bit value. We use the values as bit masks for fast multiplication
/// by packed BinaryField1b values.
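///
/// For example, with `U = u128` (one 128-bit element per underlier) the table has `256 * 8`
/// entries: entry `8 * b + j` is all ones if bit `j` of the byte `b` is set and zero otherwise.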
fn init_lookup_table_width<U>() -> Vec<U>
where
	U: UnderlierWithBitOps + From<u128>,
{
	let items_b128 = U::BITS / u128::BITS as usize;
	assert!(items_b128 <= 8);
	let items_in_byte = 8 / items_b128;

	let mut result = Vec::with_capacity(256 * items_in_byte);
	for i in 0..256 {
		for j in 0..items_in_byte {
			let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
			let mut value = U::ZERO;
			for k in 0..items_b128 {
				if (bits >> k) & 1 == 1 {
					unsafe {
						value.set_subvalue(k, u128::ONES);
					}
				}
			}
			result.push(value);
		}
	}

	result
}

lazy_static! {
	static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}

#[inline]
fn fold_left_1b_128b_impl<PE, U>(
	lookup_table: &[U],
	evals: &[u8],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	PE: PackedField + WithUnderlier<Underlier = U>,
	U: UnderlierWithBitOps,
{
	let out = unsafe { slice_assume_init_mut(out) };
	fill_zeroes(out);

	let items_in_byte = 8 / PE::WIDTH;
	let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
	for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
		let query_val = PE::broadcast(query_val).to_underlier();
		for (byte_index, byte) in row_bytes.iter().enumerate() {
			let mask_offset = *byte as usize * items_in_byte;
			let out_offset = byte_index * items_in_byte;
			for i in 0..items_in_byte {
				let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
				let multiplied = query_val & *mask;
				let out = unsafe { out.get_unchecked_mut(out_offset + i) };
				*out += PE::from_underlier(multiplied);
			}
		}
	}
}

fn fold_left_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							let eval_index = (query_index << new_n_vars)
								| (outer_index << PE::LOG_WIDTH)
								| inner_index;
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}

			out_val.write(res);
		});
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b,
		PackedBinaryField16x8b, PackedBinaryField32x1b, PackedBinaryField512x1b,
		PackedBinaryField64x8b,
	};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	fn fold_right_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result +=
					get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_right<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
		let mut out = reference_out.clone();

		fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
		fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 0, &query, 0);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 2, &query, 1);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 7, &query, 7);
	}

	#[test]
	fn test_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
			.take(8)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_small_poly() {
		const LOG_EVALS_SIZE: usize = 5;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_many_evals() {
		const LOG_EVALS_SIZE: usize = 13;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_lower_higher_bits() {
		const LOG_EVALS_SIZE: usize = 7;
		let mut rng = StdRng::seed_from_u64(0);
		// We have four elements of 32 bits each. Overall, that gives us 128 bits.
		let evals = [
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
		];

		// We get the lower bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let log_query_size = 1;
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000001u32,
		)];
		let mut out = vec![
			PackedBinaryField32x1b::zero();
			(1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
		];
		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the lower bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_lower = (evals[2 * i].0 as u16) as u32;
			let second_lower = (evals[2 * i + 1].0 as u16) as u32;
			let expected_out = first_lower + (second_lower << 16);
			assert!(out_val.0 == expected_out);
		}

		// We get the higher bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000010u32,
		)];

		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the higher bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_higher = evals[2 * i].0 >> 16;
			let second_higher = evals[2 * i + 1].0 >> 16;
			let expected_out = first_higher + (second_higher << 16);
			assert!(out_val.0 == expected_out);
		}
	}

	fn fold_left_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result += get_packed_slice(query, j)
					* get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_left<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

		let mut out = reference_out.clone();
		out.clear();
		fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 0, &query, 0);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 2, &query, 1);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 7, &query, 7);
	}

	#[test]
	fn test_fold_left_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = vec![PackedBinaryField512x1b::random(&mut rng)];

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	type B128bOptimal = ArchOptimalType<BinaryField128b>;

	#[test]
	fn test_fold_left_1b_128b_optimal() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << (10 - B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_fold_left_128b_128b() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << 10)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
		}
	}

	#[test]
	fn test_fold_left_lerp_inplace_conforms_reference() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();
		let lerp_query = <BinaryField128b as Field>::random(&mut rng);
		let mut query =
			vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
		set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
		set_packed_slice(&mut query, 1, lerp_query);

		for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
			let mut out = vec![
				MaybeUninit::uninit();
				1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
			];
			fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
			fold_left_lerp_inplace(
				&mut evals,
				1 << log_evals_size,
				Field::ZERO,
				log_evals_size,
				lerp_query,
			)
			.unwrap();

			for (out, &inplace) in izip!(&out, &evals) {
				unsafe {
					assert_eq!(out.assume_init(), inplace);
				}
			}
		}
	}

	#[test]
	fn test_check_fold_arguments_valid() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should pass as query and output sizes are valid
		let result = check_fold_arguments(&evals, 3, &query, 2, &out);
		assert!(result.is_ok());
	}

	#[test]
	fn test_check_fold_arguments_invalid_query_size() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should fail as log_query_size > log_evals_size
		let result = check_fold_arguments(&evals, 2, &query, 3, &out);
		assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
	}
}