binius_math/
fold.rs

// Copyright 2024-2025 Irreducible Inc.

use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
	arch::{ArchOptimal, OptimalUnderlier},
	byte_iteration::{
		can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes,
		ByteIteratorCallback,
	},
	packed::{
		get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
		PackedSlice,
	},
	underlier::{UnderlierWithBitOps, WithUnderlier},
	AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField,
	Field, PackedField,
};
use binius_utils::bail;
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;

use crate::Error;
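/// Zero-pad a multilinear's evaluations to a larger number of variables.
///
/// `evals` is treated as a sequence of rows of `1 << start_index` scalars each. Every row is
/// expanded to `1 << (start_index + log_new_vals - log_evals_size)` scalars, with the original
/// row written at block position `nonzero_index` inside the expanded row. Other positions of
/// `out` are left untouched, so the caller is expected to pass a zero-initialized buffer.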
pub fn zero_pad<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	log_new_vals: usize,
	start_index: usize,
	nonzero_index: usize,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if start_index > log_evals_size {
		bail!(Error::IncorrectStartIndex {
			expected: log_evals_size
		});
	}

	let extra_vars = log_new_vals - log_evals_size;
	if nonzero_index > 1 << extra_vars {
		bail!(Error::IncorrectNonZeroIndex {
			expected: 1 << extra_vars
		});
	}

	let block_size = 1 << start_index;
	let nb_elts_per_row = 1 << (start_index + extra_vars);
	let length = 1 << start_index;

	let start_index_in_row = nonzero_index * block_size;
	for (outer_index, packed_result_eval) in out.iter_mut().enumerate() {
		// Index of the Scalar at the start of the current Packed element in `out`.
		let outer_word_start = outer_index << PE::LOG_WIDTH;
		for inner_index in 0..min(PE::WIDTH, 1 << log_evals_size) {
			// Index of the current Scalar in `out`.
			let outer_scalar_index = outer_word_start + inner_index;
			// Column index within the expanded row (of `1 << (start_index + extra_vars)` scalars).
			let inner_col = outer_scalar_index % nb_elts_per_row;
			// Row index.
			let inner_row = outer_scalar_index / nb_elts_per_row;

			if inner_col >= start_index_in_row && inner_col < start_index_in_row + block_size {
				let eval_index = inner_row * length + inner_col - start_index_in_row;
				let result_eval = get_packed_slice(evals, eval_index).into();
				unsafe {
					packed_result_eval.set_unchecked(inner_index, result_eval);
				}
			}
		}
	}
	Ok(())
}

/// Execute the right fold operation.
///
/// Every consecutive chunk of `1 << log_query_size` scalar values is dot-producted with the
/// corresponding query elements. The result is stored in the `out` slice of packed values.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
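///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doctest); the concrete
/// packed field choices below are assumptions, not requirements.
///
/// ```ignore
/// use binius_field::{PackedBinaryField128x1b, PackedBinaryField16x8b, PackedField};
///
/// // 256 one-bit evaluations folded with a query of 1 << 2 = 4 scalars.
/// let evals = vec![PackedBinaryField128x1b::zero(); 2];
/// let query = vec![PackedBinaryField16x8b::zero(); 1];
/// // The folded polynomial has 256 / 4 = 64 scalars, i.e. 4 packed elements.
/// let mut out = vec![PackedBinaryField16x8b::zero(); 4];
/// fold_right(&evals, 8, &query, 2, &mut out).unwrap();
/// // Scalar `i` of `out` is the dot product of `query` with scalars `4 * i..4 * i + 4` of `evals`.
/// ```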
pub fn fold_right<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	// Try to execute the optimized version for 1-bit values if possible
	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single variable multilinear queries.
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
	} else {
		fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the left fold operation.
///
/// `evals` is treated as a matrix with `1 << log_query_size` rows and each column is dot-producted
/// with the corresponding query elements. The result is written to the `out` slice of packed
/// values. If the function returns `Ok(())`, then `out` can be safely interpreted as initialized.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
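///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doctest); the concrete
/// packed field choices below are assumptions, not requirements.
///
/// ```ignore
/// use std::mem::MaybeUninit;
///
/// use binius_field::{PackedBinaryField128x1b, PackedBinaryField16x8b, PackedField};
///
/// // 256 one-bit evaluations folded with a query of 1 << 2 = 4 scalars.
/// let evals = vec![PackedBinaryField128x1b::zero(); 2];
/// let query = vec![PackedBinaryField16x8b::zero(); 1];
/// let mut out = vec![MaybeUninit::<PackedBinaryField16x8b>::uninit(); 4];
/// fold_left(&evals, 8, &query, 2, &mut out).unwrap();
/// // On success every element of `out` has been written and may be assumed initialized.
/// ```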
pub fn fold_left<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single variable multilinear queries.
	// Unlike right folds, left folds are often used with packed fields that are not indexable
	// and have quite slow scalar indexing methods. For now, we specialize the practically
	// important case of a single-variable fold when PE == P.
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
		&& TypeId::of::<P>() == TypeId::of::<PE>();

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		// Safety: P == PE checked above.
		let out_p =
			unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

		let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
		fold_left_lerp(
			evals,
			1 << log_evals_size,
			P::Scalar::ZERO,
			log_evals_size,
			lerp_query_p,
			out_p,
		)?;
	} else {
		fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the middle fold operation.
///
/// Every consecutive chunk of `1 << start_index` scalar values is considered a row. Then each
/// column is dot-producted in chunks (of size `1 << log_query_size`) with the `query` slice.
/// The results are written to `out`.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
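///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doctest); the concrete
/// packed field choice below is an assumption, not a requirement.
///
/// ```ignore
/// use std::mem::MaybeUninit;
///
/// use binius_field::{PackedBinaryField32x1b, PackedField};
///
/// // 128 scalars viewed as rows of 1 << 4 = 16; fold out the single variable at index
/// // `start_index`, leaving 64 scalars.
/// let evals = vec![PackedBinaryField32x1b::zero(); 4];
/// let query = vec![PackedBinaryField32x1b::zero(); 1];
/// let mut out = vec![MaybeUninit::<PackedBinaryField32x1b>::uninit(); 2];
/// fold_middle(&evals, 7, &query, 1, 4, &mut out).unwrap();
/// ```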
pub fn fold_middle<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	start_index: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if log_evals_size < log_query_size + start_index {
		bail!(Error::IncorrectStartIndex {
			expected: log_evals_size
		});
	}

	let lower_indices_size = 1 << start_index;
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			// Index of the Scalar at the start of the current Packed element in `out`.
			let outer_word_start = outer_index << PE::LOG_WIDTH;
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				// Index of the current Scalar in `out`.
				let outer_scalar_index = outer_word_start + inner_index;
				// Column index, within the reshaped `evals`, where we start the dot-product with
				// the query elements.
				let inner_i = outer_scalar_index % lower_indices_size;
				// Row index.
				let inner_j = outer_scalar_index / lower_indices_size;
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							// The reshaped `evals` "matrix" is subdivided in chunks of size
							// `1 << (start_index + log_query_size)` for the dot-product.
							let offset = inner_j << (start_index + log_query_size) | inner_i;
							let eval_index = offset + (query_index << start_index);
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}
			out_val.write(res);
		});

	Ok(())
}

#[inline]
fn check_fold_arguments<P, PE, POut>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < log_query_size {
		bail!(Error::IncorrectQuerySize {
			expected: log_evals_size
		});
	}

	if P::WIDTH * evals.len() < 1 << log_evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: 1 << log_evals_size
		});
	}

	if PE::WIDTH * query.len() < 1 << log_query_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "query".into(),
			expected: 1 << log_query_size
		});
	}

	if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: 1 << (log_evals_size - log_query_size)
		});
	}

	Ok(())
}

#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if P::WIDTH * evals.len() < evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: evals_size
		});
	}

	if PE::WIDTH * out.len() * 2 < evals_size {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: evals_size.div_ceil(2)
		});
	}

	Ok(())
}

#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	non_const_prefix: usize,
	log_evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size == 0 {
		bail!(Error::IncorrectQuerySize { expected: 1 });
	}

	if non_const_prefix > 1 << log_evals_size {
		bail!(Error::IncorrectNonzeroScalarPrefix {
			expected: 1 << log_evals_size,
		});
	}

	if P::WIDTH * evals.len() < non_const_prefix {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: non_const_prefix,
		});
	}

	let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

	if PE::WIDTH * out.len() < folded_non_const_prefix {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: folded_non_const_prefix,
		});
	}

	Ok(())
}

/// Optimized version for 1-bit values with log query size 0-2
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
		return false;
	}

	// Precompute, for every pattern of `1 << LOG_QUERY_SIZE` evaluation bits, its dot product
	// with the query.
	let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
		.map(|i| {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << LOG_QUERY_SIZE {
				if i >> j & 1 == 1 {
					result += get_packed_slice(query, j);
				}
			}
			result
		})
		.collect::<Vec<_>>();

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_table: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
			let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
			let mut current_index = 0;
			for byte in iterator {
				for k in 0..values_in_byte {
					let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
					// Safety: `current_index + k` is within the scalar bounds of `self.out`
					unsafe {
						set_packed_slice_unchecked(
							self.out,
							current_index + k,
							self.cached_table[index as usize],
						);
					}
				}

				current_index += values_in_byte;
			}
		}
	}

	let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
		out,
		cached_table: &cached_table,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Optimized version for 1-bit values with medium log query size (3-7)
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE < 3 {
		return false;
	}

	if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
		return false;
	}

	let cached_tables =
		create_partial_sums_lookup_tables(PackedSlice::new_with_len(query, 1 << LOG_QUERY_SIZE));

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_tables: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let log_tables_count = LOG_QUERY_SIZE - 3;
			let tables_count = 1 << log_tables_count;
			let mut current_index = 0;
			let mut current_table = 0;
			let mut current_value = PE::Scalar::ZERO;
			for byte in iterator {
				current_value += self.cached_tables[(current_table << 8) + byte as usize];
				current_table += 1;

				if current_table == tables_count {
					// Safety: `current_index` is within the scalar bounds of `self.out`
					unsafe {
						set_packed_slice_unchecked(self.out, current_index, current_value);
					}
					current_index += 1;
					current_table = 0;
					current_value = PE::Scalar::ZERO;
				}
			}
		}
	}

	let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
		out,
		cached_tables: &cached_tables,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Try to run the optimized version for 1-bit values.
/// Returns true if the optimized calculation was performed.
/// Otherwise returns false and the fallback should be used.
fn fold_right_1bit_evals<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
		return false;
	}

	// We pass `log_query_size` as a const generic parameter because that allows the compiler to
	// produce more efficient code in the tight loops inside the functions.
	match log_query_size {
		0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
		1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
		2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
		3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
		4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
		5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
		6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
		7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
		_ => false,
	}
}

/// Specialized implementation for a single variable right fold using linear interpolation
/// instead of tensor expansion, resulting in a single multiplication instead of two:
///   f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w).
///
/// The same approach may be generalized to higher variable counts, with diminishing returns.
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
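///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doctest); the concrete
/// packed field choice below is an assumption, not a requirement.
///
/// ```ignore
/// use binius_field::{BinaryField32b, Field, PackedBinaryField16x32b, PackedField};
///
/// let evals = vec![PackedBinaryField16x32b::zero(); 2]; // 32 scalars
/// let r = BinaryField32b::ONE;
/// let mut out = vec![PackedBinaryField16x32b::zero(); 1]; // 16 folded scalars
/// fold_right_lerp(&evals, 32, r, BinaryField32b::ZERO, &mut out).unwrap();
/// // Scalar `i` of `out` equals evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i]).
/// ```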
pub fn fold_right_lerp<P, PE>(
	evals: &[P],
	evals_size: usize,
	lerp_query: PE::Scalar,
	suffix_eval: P::Scalar,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

	let folded_evals_size = evals_size >> 1;
	out[..folded_evals_size.div_ceil(PE::WIDTH)]
		.iter_mut()
		.enumerate()
		.for_each(|(i, packed_result_eval)| {
			for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
				let index = (i << PE::LOG_WIDTH) | j;

				let (eval0, eval1) = unsafe {
					(
						get_packed_slice_unchecked(evals, index << 1),
						get_packed_slice_unchecked(evals, (index << 1) | 1),
					)
				};

				let result_eval =
					PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

				// Safety: `j` < `PE::WIDTH`
				unsafe {
					packed_result_eval.set_unchecked(j, result_eval);
				}
			}
		});

	if evals_size % 2 == 1 {
		let eval0 = get_packed_slice(evals, folded_evals_size << 1);
		set_packed_slice(
			out,
			folded_evals_size,
			PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
		);
	}

	Ok(())
}

/// Left linear interpolation (lerp, single variable) fold
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
///
/// Also note that left folds are often used with non-indexable packed fields that have
/// inefficient scalar access. Fully generic handling of all interesting cases that could leverage
/// spread multiplication would require dynamically checking the `PackedExtension` relations, so
/// for now we just handle the simplest yet important case: a single-variable left fold in a
/// packed field `P` with a lerp query of its own scalar (not a nontrivial extension field!).
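///
/// Concretely, with `n = log_evals_size` and `r = lerp_query`, scalar `i` of the output is
/// `evals[i] + r * (evals[i + 2^(n-1)] - evals[i])`, where scalars of `evals` at positions at or
/// beyond `non_const_prefix` are taken to be `suffix_eval`; packed outputs that depend only on
/// such constant scalars are zero-filled.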
pub fn fold_left_lerp<P>(
	evals: &[P],
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
	out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at(packed_len);
			for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
				out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
			out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
		}

		for out in &mut out[upper_bound..] {
			out.write(P::zero());
		}
	} else if non_const_prefix > 0 {
		let only_packed = *evals.first().expect("log_evals_size > 0");
		let mut folded = P::zero();

		for i in 0..1 << (log_evals_size - 1) {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		out.first_mut().expect("log_evals_size > 0").write(folded);
	}

	Ok(())
}

/// Inplace left linear interpolation (lerp, single variable) fold
///
/// Please note that this method is single threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
pub fn fold_left_lerp_inplace<P>(
	evals: &mut Vec<P>,
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at_mut(packed_len);
			for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
				*eval_0 += (*eval_1 - *eval_0) * lerp_query;
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for eval in &mut evals[pivot..upper_bound] {
			*eval += (broadcast_suffix_eval - *eval) * lerp_query;
		}

		evals.truncate(upper_bound);
	} else if non_const_prefix > 0 {
		let only_packed = evals.first_mut().expect("log_evals_size > 0");
		let mut folded = P::zero();
		let half_size = 1 << (log_evals_size - 1);

		for i in 0..half_size {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | half_size);
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		*only_packed = folded;
	}

	Ok(())
}

/// Fallback implementation for fold that can be executed for any field types and sizes.
fn fold_right_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	for (k, packed_result_eval) in out.iter_mut().enumerate() {
		for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
			let index = (k << PE::LOG_WIDTH) | j;

			let offset = index << log_query_size;

			let mut result_eval = PE::Scalar::ZERO;
			for (t, query_expansion) in PackedField::iter_slice(query)
				.take(1 << log_query_size)
				.enumerate()
			{
				result_eval += query_expansion * get_packed_slice(evals, t + offset);
			}

			// Safety: `j` < `PE::WIDTH`
			unsafe {
				packed_result_eval.set_unchecked(j, result_eval);
			}
		}
	}
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
	TypeId::of::<ArchOptimalType<F>>()
}

/// Use an optimized algorithm for 1-bit evaluations with 128-bit query values packed with the
/// optimal underlier. Returns true if the optimized algorithm was used. In this case `out` can be
/// safely interpreted as initialized.
///
/// We could potentially think of specializing for other query fields, but that would require
/// separate implementations.
fn fold_left_1b_128b<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
		return false;
	}

	let log_row_size = log_evals_size - log_query_size;
	if log_row_size < 3 {
		return false;
	}

	if PE::LOG_WIDTH > 3 {
		return false;
	}

	// Safety: the cast is valid because `is_sequential_bytes::<P>()` above guarantees that the
	// packed values are laid out as a plain sequence of bytes.
	let evals_u8: &[u8] = unsafe {
		std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
	};

	// Try to run the optimized version for the specific 128-bit field type.
	// This is a workaround for the lack of specialization in Rust.
	#[inline]
	fn try_run_specialization<PE, F>(
		lookup_table: &[OptimalUnderlier],
		evals_u8: &[u8],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [MaybeUninit<PE>],
	) -> bool
	where
		PE: PackedField,
		F: ArchOptimal,
	{
		if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
			let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
			let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

			fold_left_1b_128b_impl(
				lookup_table,
				evals_u8,
				log_evals_size,
				query,
				log_query_size,
				out,
			);
			true
		} else {
			false
		}
	}

	let lookup_table = &*LOOKUP_TABLE;
	try_run_specialization::<_, BinaryField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, AESTowerField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, BinaryField128bPolyval>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	)
}

/// Cast a mutable slice from an unknown type to a known one, assuming the types are the same.
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is valid because the assert above guarantees that `T` and `U` are the
	// same type.
	unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}

/// Cast a slice from an unknown type to a known one, assuming the types are the same.
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is valid because the assert above guarantees that `T` and `U` are the
	// same type.
	unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}

/// Initialize the lookup table u8 -> [U; 8 / <number of 128-bit elements in U>] where
/// each bit in `u8` corresponds to a 128-bit element in `U` filled with ones or zeros
/// depending on the bit value. We use the values as bit masks for fast multiplication
/// by packed BinaryField1b values.
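///
/// As an illustration, with a 256-bit underlier `U` we have `items_b128 = 2` and
/// `items_in_byte = 4`, so each byte maps to four masks of two 128-bit lanes each: the `j`-th
/// mask is built from bits `2j..2j + 2` of the byte, and bit `k` of that group fills 128-bit
/// lane `k` with ones. For example, the byte `0b0000_0110` yields the masks
/// `[lane 1 all-ones, lane 0 all-ones, zero, zero]`.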
fn init_lookup_table_width<U>() -> Vec<U>
where
	U: UnderlierWithBitOps + From<u128>,
{
	let items_b128 = U::BITS / u128::BITS as usize;
	assert!(items_b128 <= 8);
	let items_in_byte = 8 / items_b128;

	let mut result = Vec::with_capacity(256 * items_in_byte);
	for i in 0..256 {
		for j in 0..items_in_byte {
			let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
			let mut value = U::ZERO;
			for k in 0..items_b128 {
				if (bits >> k) & 1 == 1 {
					unsafe {
						value.set_subvalue(k, u128::ONES);
					}
				}
			}
			result.push(value);
		}
	}

	result
}

lazy_static! {
	static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}

#[inline]
fn fold_left_1b_128b_impl<PE, U>(
	lookup_table: &[U],
	evals: &[u8],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	PE: PackedField + WithUnderlier<Underlier = U>,
	U: UnderlierWithBitOps,
{
	let out = unsafe { slice_assume_init_mut(out) };
	fill_zeroes(out);

	let items_in_byte = 8 / PE::WIDTH;
	let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
	for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
		let query_val = PE::broadcast(query_val).to_underlier();
		for (byte_index, byte) in row_bytes.iter().enumerate() {
			let mask_offset = *byte as usize * items_in_byte;
			let out_offset = byte_index * items_in_byte;
			for i in 0..items_in_byte {
				let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
				let multiplied = query_val & *mask;
				let out = unsafe { out.get_unchecked_mut(out_offset + i) };
				*out += PE::from_underlier(multiplied);
			}
		}
	}
}

fn fold_left_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							let eval_index = (query_index << new_n_vars)
								| (outer_index << PE::LOG_WIDTH)
								| inner_index;
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}

			out_val.write(res);
		});
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b,
		PackedBinaryField16x8b, PackedBinaryField32x1b, PackedBinaryField512x1b,
		PackedBinaryField64x8b, PackedBinaryField8x1b,
	};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	fn fold_right_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result +=
					get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_right<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
		let mut out = reference_out.clone();

		fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
		fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 0, &query, 0);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 2, &query, 1);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 7, &query, 7);
	}

	#[test]
	fn test_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
			.take(8)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_small_poly() {
		const LOG_EVALS_SIZE: usize = 5;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_many_evals() {
		const LOG_EVALS_SIZE: usize = 13;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_lower_higher_bits() {
		const LOG_EVALS_SIZE: usize = 7;
		let mut rng = StdRng::seed_from_u64(0);
		// We have four elements of 32 bits each. Overall, that gives us 128 bits.
		let evals = [
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
		];

		// We get the lower bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let log_query_size = 1;
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000001u32,
		)];
		let mut out = vec![
			PackedBinaryField32x1b::zero();
			(1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
		];
		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the lower bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_lower = (evals[2 * i].0 as u16) as u32;
			let second_lower = (evals[2 * i + 1].0 as u16) as u32;
			let expected_out = first_lower + (second_lower << 16);
			assert!(out_val.0 == expected_out);
		}

		// We get the higher bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000010u32,
		)];

		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the higher bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_higher = evals[2 * i].0 >> 16;
			let second_higher = evals[2 * i + 1].0 >> 16;
			let expected_out = first_higher + (second_higher << 16);
			assert!(out_val.0 == expected_out);
		}
	}

	#[test]
	fn test_zeropad_bits() {
		const LOG_EVALS_SIZE: usize = 5;
		let mut rng = StdRng::seed_from_u64(0);
		// We have four elements of 8 bits each. Overall, that gives us 32 bits.
		let evals = [
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
		];

		// We expand each 8-bit element to 32 bits (two extra variables), trying every block
		// position in turn.
		let num_extra_vars = 2;
		let start_index = 3;

		for nonzero_index in 0..1 << num_extra_vars {
			let mut out =
				vec![
					PackedBinaryField8x1b::zero();
					1usize << (LOG_EVALS_SIZE + num_extra_vars - PackedBinaryField8x1b::LOG_WIDTH)
				];
			zero_pad(
				&evals,
				LOG_EVALS_SIZE,
				LOG_EVALS_SIZE + num_extra_vars,
				start_index,
				nonzero_index,
				&mut out,
			)
			.unwrap();

			// Every `1 << start_index` consecutive bits from the original `evals` are placed at
			// block index `nonzero_index` within an expanded row of size
			// `1 << (start_index + num_extra_vars)`.
			for (i, out_val) in out.iter().enumerate() {
				let expansion = 1 << num_extra_vars;
				if i % expansion == nonzero_index {
					assert!(out_val.0 == evals[i / expansion].0);
				} else {
					assert!(out_val.0 == 0);
				}
			}
		}
	}

	fn fold_left_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result += get_packed_slice(query, j)
					* get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_left<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

		let mut out = reference_out.clone();
		out.clear();
		fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 0, &query, 0);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 2, &query, 1);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 7, &query, 7);
	}

	#[test]
	fn test_fold_left_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = vec![PackedBinaryField512x1b::random(&mut rng)];

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	type B128bOptimal = ArchOptimalType<BinaryField128b>;

	#[test]
	fn test_fold_left_1b_128b_optimal() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << (10 - B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_fold_left_128b_128b() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << 10)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
		}
	}

	#[test]
	fn test_fold_left_lerp_inplace_conforms_reference() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();
		let lerp_query = <BinaryField128b as Field>::random(&mut rng);
		let mut query =
			vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
		set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
		set_packed_slice(&mut query, 1, lerp_query);

		for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
			let mut out = vec![
				MaybeUninit::uninit();
				1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
			];
			fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
			fold_left_lerp_inplace(
				&mut evals,
				1 << log_evals_size,
				Field::ZERO,
				log_evals_size,
				lerp_query,
			)
			.unwrap();

			for (out, &inplace) in izip!(&out, &evals) {
				unsafe {
					assert_eq!(out.assume_init(), inplace);
				}
			}
		}
	}

	#[test]
	fn test_check_fold_arguments_valid() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should pass as query and output sizes are valid
		let result = check_fold_arguments(&evals, 3, &query, 2, &out);
		assert!(result.is_ok());
	}

	#[test]
	fn test_check_fold_arguments_invalid_query_size() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should fail as log_query_size > log_evals_size
		let result = check_fold_arguments(&evals, 2, &query, 3, &out);
		assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
	}
}