binius_math/fold.rs

// Copyright 2024-2025 Irreducible Inc.

use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
	AESTowerField128b, BinaryField1b, BinaryField128b, BinaryField128bPolyval, ExtensionField,
	Field, PackedField,
	arch::{ArchOptimal, OptimalUnderlier},
	byte_iteration::{
		ByteIteratorCallback, can_iterate_bytes, create_partial_sums_lookup_tables,
		is_sequential_bytes, iterate_bytes,
	},
	packed::{
		PackedSlice, get_packed_slice, get_packed_slice_unchecked, set_packed_slice,
		set_packed_slice_unchecked,
	},
	underlier::{UnderlierWithBitOps, WithUnderlier},
};
use binius_utils::{bail, mem::slice_assume_init_mut};
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;

use crate::Error;
pub fn zero_pad<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	log_new_vals: usize,
	start_index: usize,
	nonzero_index: usize,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if start_index > log_evals_size {
		bail!(Error::IncorrectStartIndex {
			expected: log_evals_size
		});
	}

	let extra_vars = log_new_vals - log_evals_size;
	if nonzero_index > 1 << extra_vars {
		bail!(Error::IncorrectNonZeroIndex {
			expected: 1 << extra_vars
		});
	}

	let block_size = 1 << start_index;
	let nb_elts_per_row = 1 << (start_index + extra_vars);
	let length = 1 << start_index;

	let start_index_in_row = nonzero_index * block_size;
	for (outer_index, packed_result_eval) in out.iter_mut().enumerate() {
		// Index of the Scalar at the start of the current Packed element in `out`.
		let outer_word_start = outer_index << PE::LOG_WIDTH;
		for inner_index in 0..min(PE::WIDTH, 1 << log_evals_size) {
			// Index of the current Scalar in `out`.
			let outer_scalar_index = outer_word_start + inner_index;
			// Column index, within the reshaped `evals`, where we start the dot-product with the
			// query elements.
			let inner_col = outer_scalar_index % nb_elts_per_row;
			// Row index.
			let inner_row = outer_scalar_index / nb_elts_per_row;

			if inner_col >= start_index_in_row && inner_col < start_index_in_row + block_size {
				let eval_index = inner_row * length + inner_col - start_index_in_row;
				let result_eval = get_packed_slice(evals, eval_index).into();
				unsafe {
					packed_result_eval.set_unchecked(inner_index, result_eval);
				}
			}
		}
	}
	Ok(())
}

/// Execute the right fold operation.
///
/// Each consecutive group of `1 << log_query_size` scalar values is dot-produced with the
/// corresponding query elements. The result is stored in the `out` slice of packed values.
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
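///
/// A minimal usage sketch (illustrative only; the packed field types and sizes are an arbitrary
/// choice mirroring the unit tests below):
///
/// ```ignore
/// use binius_field::{PackedBinaryField16x8b, PackedBinaryField16x32b, PackedField};
///
/// let evals = vec![PackedBinaryField16x8b::default(); 4]; // 64 scalars => log_evals_size = 6
/// let query = vec![PackedBinaryField16x32b::default(); 1]; // first 4 scalars used => log_query_size = 2
/// // 1 << (6 - 2) = 16 result scalars fit into a single packed element.
/// let mut out = vec![PackedBinaryField16x32b::zero(); 1];
/// fold_right(&evals, 6, &query, 2, &mut out).unwrap();
/// ```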
pub fn fold_right<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	// Try to execute the optimized version for 1-bit values if possible
	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single-variable multilinear queries.
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
	} else {
		fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the left fold operation.
///
/// `evals` is treated as a matrix with `1 << log_query_size` rows, and each column is dot-produced
/// with the corresponding query element. The result is written to the `out` slice of packed
/// values. If the function returns `Ok(())`, then `out` can be safely interpreted as initialized.
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
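///
/// A minimal usage sketch (illustrative only; it follows the pattern used in the unit tests
/// below, where a cleared `Vec` provides the uninitialized output buffer):
///
/// ```ignore
/// use binius_field::{PackedBinaryField16x8b, PackedBinaryField16x32b, PackedField};
///
/// let evals = vec![PackedBinaryField16x8b::default(); 4]; // 64 scalars => log_evals_size = 6
/// let query = vec![PackedBinaryField16x32b::default(); 1]; // first 4 scalars used => log_query_size = 2
/// let mut out = vec![PackedBinaryField16x32b::zero(); 1];
/// out.clear();
/// fold_left(&evals, 6, &query, 2, out.spare_capacity_mut()).unwrap();
/// unsafe { out.set_len(out.capacity()) };
/// ```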
pub fn fold_left<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
		&& fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
	{
		return Ok(());
	}

	// Use linear interpolation for single-variable multilinear queries.
	// Unlike right folds, left folds are often used with packed fields which are not indexable
	// and have quite slow scalar indexing methods. For now, we specialize the practically
	// important case of a single-variable fold when PE == P.
	let is_lerp = log_query_size == 1
		&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
		&& TypeId::of::<P>() == TypeId::of::<PE>();

	if is_lerp {
		let lerp_query = get_packed_slice(query, 1);
		// Safety: P == PE checked above.
		let out_p =
			unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

		let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
		fold_left_lerp(
			evals,
			1 << log_evals_size,
			P::Scalar::ZERO,
			log_evals_size,
			lerp_query_p,
			out_p,
		)?;
	} else {
		fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
	}

	Ok(())
}

/// Execute the middle fold operation.
///
/// Each consecutive group of `1 << start_index` scalar values is considered as a row. Then each
/// column is dot-produced in chunks (of size `1 << log_query_size`) with the `query` slice.
/// The results are written to `out`.
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem.
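///
/// A minimal usage sketch (illustrative only, following `test_lower_higher_bits` below): 128
/// one-bit scalars are reshaped into rows of `1 << 4` scalars, and chunks of rows are combined
/// with the two query scalars.
///
/// ```ignore
/// use binius_field::{PackedBinaryField32x1b, PackedField};
///
/// let evals = vec![PackedBinaryField32x1b::default(); 4]; // 128 scalars => log_evals_size = 7
/// let query = vec![PackedBinaryField32x1b::default(); 1]; // first 2 scalars used => log_query_size = 1
/// let mut out = vec![PackedBinaryField32x1b::zero(); 2]; // 1 << (7 - 1) = 64 result scalars
/// out.clear();
/// fold_middle(&evals, 7, &query, 1, 4, out.spare_capacity_mut()).unwrap();
/// unsafe { out.set_len(out.capacity()) };
/// ```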
pub fn fold_middle<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	start_index: usize,
	out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

	if log_evals_size < log_query_size + start_index {
		bail!(Error::IncorrectStartIndex {
			expected: log_evals_size
		});
	}

	let lower_indices_size = 1 << start_index;
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			// Index of the Scalar at the start of the current Packed element in `out`.
			let outer_word_start = outer_index << PE::LOG_WIDTH;
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				// Index of the current Scalar in `out`.
				let outer_scalar_index = outer_word_start + inner_index;
				// Column index, within the reshaped `evals`, where we start the dot-product with
				// the query elements.
				let inner_i = outer_scalar_index % lower_indices_size;
				// Row index.
				let inner_j = outer_scalar_index / lower_indices_size;
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							// The reshaped `evals` "matrix" is subdivided in chunks of size
							// `1 << (start_index + log_query_size)` for the dot-product.
							let offset = inner_j << (start_index + log_query_size) | inner_i;
							let eval_index = offset + (query_index << start_index);
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}
			out_val.write(res);
		});

	Ok(())
}

#[inline]
fn check_fold_arguments<P, PE, POut>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < log_query_size {
		bail!(Error::IncorrectQuerySize {
			expected: log_evals_size,
			actual: log_query_size,
		});
	}

	if P::WIDTH * evals.len() < 1 << log_evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: 1 << log_evals_size
		});
	}

	if PE::WIDTH * query.len() < 1 << log_query_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "query".into(),
			expected: 1 << log_query_size
		});
	}

	if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: 1 << (log_evals_size - log_query_size)
		});
	}

	Ok(())
}

#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if P::WIDTH * evals.len() < evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: evals_size
		});
	}

	if PE::WIDTH * out.len() * 2 < evals_size {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: evals_size.div_ceil(2)
		});
	}

	Ok(())
}

#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
	evals: &[P],
	non_const_prefix: usize,
	log_evals_size: usize,
	out: &[POut],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size == 0 {
		bail!(Error::IncorrectQuerySize {
			expected: 1,
			actual: log_evals_size,
		});
	}

	if non_const_prefix > 1 << log_evals_size {
		bail!(Error::IncorrectNonzeroScalarPrefix {
			expected: 1 << log_evals_size,
		});
	}

	if P::WIDTH * evals.len() < non_const_prefix {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: non_const_prefix,
		});
	}

	let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

	if PE::WIDTH * out.len() < folded_non_const_prefix {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: folded_non_const_prefix,
		});
	}

	Ok(())
}

/// Optimized version for 1-bit values with small log query size (0-2).
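///
/// The idea: precompute, for every possible group of `1 << LOG_QUERY_SIZE` evaluation bits, the
/// sum of the query scalars selected by the set bits. Each byte of evaluations is then folded
/// with a handful of table lookups instead of per-bit multiplications.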
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
		return false;
	}

	// Cache the table for a single evaluation
	let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
		.map(|i| {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << LOG_QUERY_SIZE {
				if i >> j & 1 == 1 {
					result += get_packed_slice(query, j);
				}
			}
			result
		})
		.collect::<Vec<_>>();

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_table: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
			let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
			let mut current_index = 0;
			for byte in iterator {
				for k in 0..values_in_byte {
					let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
					// Safety: `current_index + k` is within the scalar bounds of `self.out`,
					// which is checked by `check_fold_arguments` in the caller.
					unsafe {
						set_packed_slice_unchecked(
							self.out,
							current_index + k,
							self.cached_table[index as usize],
						);
					}
				}

				current_index += values_in_byte;
			}
		}
	}

	let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
		out,
		cached_table: &cached_table,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Optimized version for 1-bit values with medium log query size (3-7).
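///
/// Uses `create_partial_sums_lookup_tables` to precompute, for every byte of evaluation bits, the
/// corresponding partial sum of query scalars; each folded scalar is then accumulated from
/// `1 << (LOG_QUERY_SIZE - 3)` table lookups, one per byte of the query-sized chunk.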
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
	evals: &[P],
	query: &[PE],
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if LOG_QUERY_SIZE < 3 {
		return false;
	}

	let cached_tables =
		create_partial_sums_lookup_tables(PackedSlice::new_with_len(query, 1 << LOG_QUERY_SIZE));

	struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
		out: &'a mut [PE],
		cached_tables: &'a [PE::Scalar],
	}

	impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
		for Callback<'_, PE, LOG_QUERY_SIZE>
	{
		#[inline(always)]
		fn call(&mut self, iterator: impl Iterator<Item = u8>) {
			let log_tables_count = LOG_QUERY_SIZE - 3;
			let tables_count = 1 << log_tables_count;
			let mut current_index = 0;
			let mut current_table = 0;
			let mut current_value = PE::Scalar::ZERO;
			for byte in iterator {
				current_value += self.cached_tables[(current_table << 8) + byte as usize];
				current_table += 1;

				if current_table == tables_count {
					// Safety: `current_index` is within the scalar bounds of `self.out`,
					// which is checked by `check_fold_arguments` in the caller.
					unsafe {
						set_packed_slice_unchecked(self.out, current_index, current_value);
					}
					current_index += 1;
					current_table = 0;
					current_value = PE::Scalar::ZERO;
				}
			}
		}
	}

	let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
		out,
		cached_tables: &cached_tables,
	};

	iterate_bytes(evals, &mut callback);

	true
}

/// Try to run the optimized version for 1-bit values.
/// Returns true if the optimized calculation was performed; otherwise, returns false and the
/// fallback should be used.
fn fold_right_1bit_evals<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
		return false;
	}

	// We pass `log_query_size` as a const generic parameter because that allows the compiler to
	// produce more efficient code in the tight loops inside the functions.
	match log_query_size {
		0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
		1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
		2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
		3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
		4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
		5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
		6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
		7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
		_ => false,
	}
}

/// Specialized implementation for a single-variable right fold that uses linear interpolation
/// instead of tensor expansion, resulting in a single multiplication instead of two:
///   f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w).
///
/// The same approach may be generalized to higher variable counts, with diminishing returns.
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
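///
/// A minimal usage sketch (illustrative only; `r` is the lerp challenge and the packed type is an
/// arbitrary choice, mirroring how `fold_right` dispatches to this function):
///
/// ```ignore
/// use binius_field::{BinaryField32b, Field, PackedBinaryField16x32b, PackedField};
///
/// let evals = vec![PackedBinaryField16x32b::default(); 4]; // 64 scalars
/// let r = BinaryField32b::ONE;
/// let mut out = vec![PackedBinaryField16x32b::zero(); 2]; // 32 folded scalars
/// fold_right_lerp(&evals, 64, r, BinaryField32b::ZERO, &mut out).unwrap();
/// ```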
pub fn fold_right_lerp<P, PE>(
	evals: &[P],
	evals_size: usize,
	lerp_query: PE::Scalar,
	suffix_eval: P::Scalar,
	out: &mut [PE],
) -> Result<(), Error>
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

	let folded_evals_size = evals_size >> 1;
	out[..folded_evals_size.div_ceil(PE::WIDTH)]
		.iter_mut()
		.enumerate()
		.for_each(|(i, packed_result_eval)| {
			for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
				let index = (i << PE::LOG_WIDTH) | j;

				let (eval0, eval1) = unsafe {
					(
						get_packed_slice_unchecked(evals, index << 1),
						get_packed_slice_unchecked(evals, (index << 1) | 1),
					)
				};

				let result_eval =
					PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

				// Safety: `j` < `PE::WIDTH`
				unsafe {
					packed_result_eval.set_unchecked(j, result_eval);
				}
			}
		});

	if evals_size % 2 == 1 {
		let eval0 = get_packed_slice(evals, folded_evals_size << 1);
		set_packed_slice(
			out,
			folded_evals_size,
			PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
		);
	}

	Ok(())
}

/// Left linear interpolation (lerp, single-variable) fold
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
///
/// Also note that left folds are often intended to be used with non-indexable packed fields that
/// have inefficient scalar access. Fully generic handling of all interesting cases that can
/// leverage spread multiplication requires dynamically checking the `PackedExtension` relations,
/// so for now we just handle the simplest yet important case: a single-variable left fold in a
/// packed field `P` with a lerp query of its scalar (and not a nontrivial extension field!).
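///
/// A minimal usage sketch (illustrative only; mirrors how `fold_left` dispatches to this
/// function, with a fully non-constant prefix and a zero suffix):
///
/// ```ignore
/// use std::mem::MaybeUninit;
///
/// use binius_field::{BinaryField32b, Field, PackedBinaryField16x32b, PackedField};
///
/// let evals = vec![PackedBinaryField16x32b::default(); 4]; // 64 scalars => log_evals_size = 6
/// let r = BinaryField32b::ONE;
/// let mut out = vec![MaybeUninit::uninit(); 2]; // 32 folded scalars
/// fold_left_lerp(&evals, 64, BinaryField32b::ZERO, 6, r, &mut out).unwrap();
/// ```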
pub fn fold_left_lerp<P>(
	evals: &[P],
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
	out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at(packed_len);
			for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
				out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
			out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
		}

		for out in &mut out[upper_bound..] {
			out.write(P::zero());
		}
	} else if non_const_prefix > 0 {
		let only_packed = *evals.first().expect("log_evals_size > 0");
		let mut folded = P::zero();

		for i in 0..1 << (log_evals_size - 1) {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		out.first_mut().expect("log_evals_size > 0").write(folded);
	}

	Ok(())
}

/// In-place left linear interpolation (lerp, single-variable) fold
///
/// Please note that this method is single-threaded. Currently we always have some
/// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to
/// use more efficient optimizations for special cases. If we ever need a parallel version of this
/// function, we can implement it separately.
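///
/// A minimal usage sketch (illustrative only; mirrors the conformance test below — the vector is
/// truncated to the folded length on return):
///
/// ```ignore
/// use binius_field::{BinaryField32b, Field, PackedBinaryField16x32b, PackedField};
///
/// let mut evals = vec![PackedBinaryField16x32b::default(); 4]; // 64 scalars => log_evals_size = 6
/// let r = BinaryField32b::ONE;
/// fold_left_lerp_inplace(&mut evals, 64, BinaryField32b::ZERO, 6, r).unwrap();
/// assert_eq!(evals.len(), 2); // 32 folded scalars remain
/// ```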
pub fn fold_left_lerp_inplace<P>(
	evals: &mut Vec<P>,
	non_const_prefix: usize,
	suffix_eval: P::Scalar,
	log_evals_size: usize,
	lerp_query: P::Scalar,
) -> Result<(), Error>
where
	P: PackedField,
{
	check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;

	if log_evals_size > P::LOG_WIDTH {
		let pivot = non_const_prefix
			.saturating_sub(1 << (log_evals_size - 1))
			.div_ceil(P::WIDTH);

		let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
		let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

		if pivot > 0 {
			let (evals_0, evals_1) = evals.split_at_mut(packed_len);
			for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
				*eval_0 += (*eval_1 - *eval_0) * lerp_query;
			}
		}

		let broadcast_suffix_eval = P::broadcast(suffix_eval);
		for eval in &mut evals[pivot..upper_bound] {
			*eval += (broadcast_suffix_eval - *eval) * lerp_query;
		}

		evals.truncate(upper_bound);
	} else if non_const_prefix > 0 {
		let only_packed = evals.first_mut().expect("log_evals_size > 0");
		let mut folded = P::zero();
		let half_size = 1 << (log_evals_size - 1);

		for i in 0..half_size {
			let eval_0 = only_packed.get(i);
			let eval_1 = only_packed.get(i | half_size);
			folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
		}

		*only_packed = folded;
	}

	Ok(())
}

/// Fallback implementation for the right fold that can be executed for any field types and sizes.
fn fold_right_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [PE],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	for (k, packed_result_eval) in out.iter_mut().enumerate() {
		for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
			let index = (k << PE::LOG_WIDTH) | j;

			let offset = index << log_query_size;

			let mut result_eval = PE::Scalar::ZERO;
			for (t, query_expansion) in PackedField::iter_slice(query)
				.take(1 << log_query_size)
				.enumerate()
			{
				result_eval += query_expansion * get_packed_slice(evals, t + offset);
			}

			// Safety: `j` < `PE::WIDTH`
			unsafe {
				packed_result_eval.set_unchecked(j, result_eval);
			}
		}
	}
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
	TypeId::of::<ArchOptimalType<F>>()
}

/// Use an optimized algorithm for 1-bit evaluations with 128-bit query values packed with the
/// optimal underlier. Returns true if the optimized algorithm was used. In this case `out` can be
/// safely interpreted as initialized.
///
/// We could potentially consider specializing for other query fields, but that would require
/// separate implementations.
fn fold_left_1b_128b<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) -> bool
where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
		return false;
	}

	let log_row_size = log_evals_size - log_query_size;
	if log_row_size < 3 {
		return false;
	}

	if PE::LOG_WIDTH > 3 {
		return false;
	}

	// Safety: `is_sequential_bytes::<P>()` was checked above, so `evals` can be reinterpreted as
	// a byte slice.
	let evals_u8: &[u8] = unsafe {
		std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
	};

	// Try to run the optimized version for the specific 128-bit field type.
	// This is a workaround for the lack of specialization in Rust.
	#[inline]
	fn try_run_specialization<PE, F>(
		lookup_table: &[OptimalUnderlier],
		evals_u8: &[u8],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [MaybeUninit<PE>],
	) -> bool
	where
		PE: PackedField,
		F: ArchOptimal,
	{
		if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
			let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
			let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

			fold_left_1b_128b_impl(
				lookup_table,
				evals_u8,
				log_evals_size,
				query,
				log_query_size,
				out,
			);
			true
		} else {
			false
		}
	}

	let lookup_table = &*LOOKUP_TABLE;
	try_run_specialization::<_, BinaryField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, AESTowerField128b>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	) || try_run_specialization::<_, BinaryField128bPolyval>(
		lookup_table,
		evals_u8,
		log_evals_size,
		query,
		log_query_size,
		out,
	)
}

/// Cast a slice from an unknown type to a known one, assuming that the types are the same.
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is safe because the types are checked to be identical by the assertion
	// above.
	unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}

/// Cast a slice from an unknown type to a known one, assuming that the types are the same.
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
	assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
	// Safety: the cast is safe because the types are checked to be identical by the assertion
	// above.
	unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}

/// Initialize the lookup table `u8 -> [U; 8 / <number of 128-bit elements in U>]` where
/// each bit in `u8` corresponds to a 128-bit element in `U` filled with ones or zeros
/// depending on the bit value. We use the values as bit masks for fast multiplication
/// by packed BinaryField1b values.
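///
/// For example, with a 128-bit underlier (one 128-bit element per `U`), the table has `256 * 8`
/// entries and entry `byte * 8 + bit` is all-ones when `bit` is set in `byte`, and all-zeros
/// otherwise.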
fn init_lookup_table_width<U>() -> Vec<U>
where
	U: UnderlierWithBitOps + From<u128>,
{
	let items_b128 = U::BITS / u128::BITS as usize;
	assert!(items_b128 <= 8);
	let items_in_byte = 8 / items_b128;

	let mut result = Vec::with_capacity(256 * items_in_byte);
	for i in 0..256 {
		for j in 0..items_in_byte {
			let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
			let mut value = U::ZERO;
			for k in 0..items_b128 {
				if (bits >> k) & 1 == 1 {
					unsafe {
						value.set_subvalue(k, u128::ONES);
					}
				}
			}
			result.push(value);
		}
	}

	result
}

lazy_static! {
	static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}

#[inline]
fn fold_left_1b_128b_impl<PE, U>(
	lookup_table: &[U],
	evals: &[u8],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	PE: PackedField + WithUnderlier<Underlier = U>,
	U: UnderlierWithBitOps,
{
	let out = unsafe { slice_assume_init_mut(out) };
	fill_zeroes(out);

	let items_in_byte = 8 / PE::WIDTH;
	let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
	for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
		let query_val = PE::broadcast(query_val).to_underlier();
		for (byte_index, byte) in row_bytes.iter().enumerate() {
			let mask_offset = *byte as usize * items_in_byte;
			let out_offset = byte_index * items_in_byte;
			for i in 0..items_in_byte {
				let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
				let multiplied = query_val & *mask;
				let out = unsafe { out.get_unchecked_mut(out_offset + i) };
				*out += PE::from_underlier(multiplied);
			}
		}
	}
}

fn fold_left_fallback<P, PE>(
	evals: &[P],
	log_evals_size: usize,
	query: &[PE],
	log_query_size: usize,
	out: &mut [MaybeUninit<PE>],
) where
	P: PackedField,
	PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
	let new_n_vars = log_evals_size - log_query_size;

	out.iter_mut()
		.enumerate()
		.for_each(|(outer_index, out_val)| {
			let mut res = PE::default();
			for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
				res.set(
					inner_index,
					PackedField::iter_slice(query)
						.take(1 << log_query_size)
						.enumerate()
						.map(|(query_index, basis_eval)| {
							let eval_index = (query_index << new_n_vars)
								| (outer_index << PE::LOG_WIDTH)
								| inner_index;
							let subpoly_eval_i = get_packed_slice(evals, eval_index);
							basis_eval * subpoly_eval_i
						})
						.sum(),
				);
			}

			out_val.write(res);
		});
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		PackedBinaryField8x1b, PackedBinaryField16x8b, PackedBinaryField16x32b,
		PackedBinaryField32x1b, PackedBinaryField64x8b, PackedBinaryField128x1b,
		PackedBinaryField512x1b, packed::set_packed_slice,
	};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	fn fold_right_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result +=
					get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_right<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
		let mut out = reference_out.clone();

		fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
		fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 0, &query, 0);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 2, &query, 1);
	}

	#[test]
	fn test_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_right(&evals, 7, &query, 7);
	}

	#[test]
	fn test_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
			.take(8)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_small_poly() {
		const LOG_EVALS_SIZE: usize = 5;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_8b_many_evals() {
		const LOG_EVALS_SIZE: usize = 13;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
			.take(1 << 8)
			.collect::<Vec<_>>();

		for log_query_size in 0..8 {
			check_fold_right(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_lower_higher_bits() {
		const LOG_EVALS_SIZE: usize = 7;
		let mut rng = StdRng::seed_from_u64(0);
		// We have four elements of 32 bits each. Overall, that gives us 128 bits.
		let evals = [
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
			PackedBinaryField32x1b::random(&mut rng),
		];

		// We get the lower bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let log_query_size = 1;
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000001u32,
		)];
		let mut out = vec![
			PackedBinaryField32x1b::zero();
			(1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
		];
		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the lower bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_lower = (evals[2 * i].0 as u16) as u32;
			let second_lower = (evals[2 * i + 1].0 as u16) as u32;
			let expected_out = first_lower + (second_lower << 16);
			assert!(out_val.0 == expected_out);
		}

		// We get the higher bits of all elements (`log_query_size` = 1, so only the first
		// two bits of `query` are used).
		let query = [PackedBinaryField32x1b::from(
			0b00000000_00000000_00000000_00000010u32,
		)];

		out.clear();
		fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
			.unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		// Every 16 consecutive bits in `out` should be equal to the higher bits of an element.
		for (i, out_val) in out.iter().enumerate() {
			let first_higher = evals[2 * i].0 >> 16;
			let second_higher = evals[2 * i + 1].0 >> 16;
			let expected_out = first_higher + (second_higher << 16);
			assert!(out_val.0 == expected_out);
		}
	}

	#[test]
	fn test_zeropad_bits() {
		const LOG_EVALS_SIZE: usize = 5;
		let mut rng = StdRng::seed_from_u64(0);
		// We have four elements of 8 bits each. Overall, that gives us 32 bits.
		let evals = [
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
			PackedBinaryField8x1b::random(&mut rng),
		];

		// We expand the evaluations by a factor of `1 << num_extra_vars`, padding each element
		// within its row.
		let num_extra_vars = 2;
		let start_index = 3;

		for nonzero_index in 0..1 << num_extra_vars {
			let mut out =
				vec![
					PackedBinaryField8x1b::zero();
					1usize << (LOG_EVALS_SIZE + num_extra_vars - PackedBinaryField8x1b::LOG_WIDTH)
				];
			zero_pad(
				&evals,
				LOG_EVALS_SIZE,
				LOG_EVALS_SIZE + num_extra_vars,
				start_index,
				nonzero_index,
				&mut out,
			)
			.unwrap();

			// Every `1 << start_index` consecutive bits from the original `evals` are placed at
			// block index `nonzero_index` within a row of size
			// `1 << (start_index + num_extra_vars)`.
			for (i, out_val) in out.iter().enumerate() {
				let expansion = 1 << num_extra_vars;
				if i % expansion == nonzero_index {
					assert!(out_val.0 == evals[i / expansion].0);
				} else {
					assert!(out_val.0 == 0);
				}
			}
		}
	}

	fn fold_left_reference<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
		out: &mut [PE],
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		for i in 0..1 << (log_evals_size - log_query_size) {
			let mut result = PE::Scalar::ZERO;
			for j in 0..1 << log_query_size {
				result += get_packed_slice(query, j)
					* get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
			}

			set_packed_slice(out, i, result);
		}
	}

	fn check_fold_left<P, PE>(
		evals: &[P],
		log_evals_size: usize,
		query: &[PE],
		log_query_size: usize,
	) where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let mut reference_out =
			vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

		let mut out = reference_out.clone();
		out.clear();
		fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
		unsafe {
			out.set_len(out.capacity());
		}

		fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

		for i in 0..1 << (log_evals_size - log_query_size) {
			assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
		}
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_0() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 0, &query, 0);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_1() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 2, &query, 1);
	}

	#[test]
	fn test_fold_left_1b_small_poly_query_log_size_7() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
		let query = vec![PackedBinaryField128x1b::random(&mut rng)];

		check_fold_left(&evals, 7, &query, 7);
	}

	#[test]
	fn test_fold_left_1b_many_evals() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(1);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = vec![PackedBinaryField512x1b::random(&mut rng)];

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	type B128bOptimal = ArchOptimalType<BinaryField128b>;

	#[test]
	fn test_fold_left_1b_128b_optimal() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << (10 - B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(
				&evals,
				LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
				&query,
				log_query_size,
			);
		}
	}

	#[test]
	fn test_fold_left_128b_128b() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE)
			.collect::<Vec<_>>();
		let query = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << 10)
			.collect::<Vec<_>>();

		for log_query_size in 0..10 {
			check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
		}
	}

	#[test]
	fn test_fold_left_lerp_inplace_conforms_reference() {
		const LOG_EVALS_SIZE: usize = 14;
		let mut rng = StdRng::seed_from_u64(0);
		let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
			.take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
			.collect::<Vec<_>>();
		let lerp_query = <BinaryField128b as Field>::random(&mut rng);
		let mut query =
			vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
		set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
		set_packed_slice(&mut query, 1, lerp_query);

		for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
			let mut out = vec![
				MaybeUninit::uninit();
				1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
			];
			fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
			fold_left_lerp_inplace(
				&mut evals,
				1 << log_evals_size,
				Field::ZERO,
				log_evals_size,
				lerp_query,
			)
			.unwrap();

			for (out, &inplace) in izip!(&out, &evals) {
				unsafe {
					assert_eq!(out.assume_init(), inplace);
				}
			}
		}
	}

	#[test]
	fn test_check_fold_arguments_valid() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should pass as query and output sizes are valid
		let result = check_fold_arguments(&evals, 3, &query, 2, &out);
		assert!(result.is_ok());
	}

	#[test]
	fn test_check_fold_arguments_invalid_query_size() {
		let evals = vec![PackedBinaryField128x1b::default(); 8];
		let query = vec![PackedBinaryField128x1b::default(); 4];
		let out = vec![PackedBinaryField128x1b::default(); 4];

		// Should fail as log_query_size > log_evals_size
		let result = check_fold_arguments(&evals, 2, &query, 3, &out);
		assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
	}
}