use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
    arch::{ArchOptimal, OptimalUnderlier},
    byte_iteration::{
        can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes,
        ByteIteratorCallback,
    },
    packed::{
        get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
        PackedSlice,
    },
    underlier::{UnderlierWithBitOps, WithUnderlier},
    AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField,
    Field, PackedField,
};
use binius_utils::bail;
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;

use crate::Error;

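/// Pads the multilinear `evals` in `log_evals_size` variables out to `log_new_vals`
/// variables by inserting `log_new_vals - log_evals_size` extra variables at position
/// `start_index`. Only the block selected by `nonzero_index` receives the original
/// evaluations; positions outside that block are left untouched, so the caller should
/// pass a zero-initialized `out` buffer (as the tests below do).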
pub fn zero_pad<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    log_new_vals: usize,
    start_index: usize,
    nonzero_index: usize,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if start_index > log_evals_size {
        bail!(Error::IncorrectStartIndex {
            expected: log_evals_size
        });
    }

    let extra_vars = log_new_vals - log_evals_size;
    if nonzero_index > 1 << extra_vars {
        bail!(Error::IncorrectNonZeroIndex {
            expected: 1 << extra_vars
        });
    }

    let block_size = 1 << start_index;
    let nb_elts_per_row = 1 << (start_index + extra_vars);

    let start_index_in_row = nonzero_index * block_size;
    for (outer_index, packed_result_eval) in out.iter_mut().enumerate() {
        let outer_word_start = outer_index << PE::LOG_WIDTH;
        for inner_index in 0..min(PE::WIDTH, 1 << log_evals_size) {
            let outer_scalar_index = outer_word_start + inner_index;
            // Position of the output scalar within its padded row, and the row itself.
            let inner_col = outer_scalar_index % nb_elts_per_row;
            let inner_row = outer_scalar_index / nb_elts_per_row;

            if inner_col >= start_index_in_row && inner_col < start_index_in_row + block_size {
                let eval_index = inner_row * block_size + inner_col - start_index_in_row;
                let result_eval = get_packed_slice(evals, eval_index).into();
                unsafe {
                    packed_result_eval.set_unchecked(inner_index, result_eval);
                }
            }
        }
    }
    Ok(())
}

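/// Execute the right fold operation.
///
/// Every consecutive `1 << log_query_size` scalar values are dot-multiplied with the query:
///
/// `out[i] = sum_j query[j] * evals[(i << log_query_size) | j]`
///
/// Dispatches to specialized implementations for 1-bit evaluations and for
/// linear-interpolation (two-element) queries before falling back to the generic loop.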
pub fn fold_right<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    // Use the optimized implementation for 1-bit evaluations when possible.
    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // Use the linear interpolation (lerp) specialization when the query is the
    // two-element expansion `(1 - r, r)` of a single challenge `r`.
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
    } else {
        fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

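/// Execute the left fold operation.
///
/// `evals` is treated as a matrix with `1 << log_query_size` rows, which is multiplied on
/// the left by the query vector:
///
/// `out[i] = sum_j query[j] * evals[i | (j << (log_evals_size - log_query_size))]`
///
/// `out` is an uninitialized buffer; every element in range is written before returning.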
pub fn fold_left<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // The lerp specialization is only available when `P` and `PE` are the same type,
    // since `fold_left_lerp` operates over a single packed field type.
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
        && TypeId::of::<P>() == TypeId::of::<PE>();

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        // Safety: `P` and `PE` are the same type, as checked above.
        let out_p =
            unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

        let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
        fold_left_lerp(
            evals,
            1 << log_evals_size,
            P::Scalar::ZERO,
            log_evals_size,
            lerp_query_p,
            out_p,
        )?;
    } else {
        fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

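/// Fold `log_query_size` variables sitting in the middle of the multilinear, starting at
/// variable `start_index`. The bits of every evaluation index split into a low part
/// (below `start_index`), the folded part, and a high part, and the query is
/// dot-multiplied along the folded part:
///
/// `out[(hi << start_index) | lo] = sum_j query[j] * evals[(hi << (start_index + log_query_size)) | (j << start_index) | lo]`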
pub fn fold_middle<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    start_index: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if log_evals_size < log_query_size + start_index {
        bail!(Error::IncorrectStartIndex {
            expected: log_evals_size
        });
    }

    let lower_indices_size = 1 << start_index;
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            let outer_word_start = outer_index << PE::LOG_WIDTH;
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                let outer_scalar_index = outer_word_start + inner_index;
                // Split the output index into the bits below `start_index` and the
                // bits above the folded variables.
                let inner_i = outer_scalar_index % lower_indices_size;
                let inner_j = outer_scalar_index / lower_indices_size;
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            // The evaluation index interleaves the high bits, the
                            // folded query bits, and the low bits.
                            let offset = inner_j << (start_index + log_query_size) | inner_i;
                            let eval_index = offset + (query_index << start_index);
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }
            out_val.write(res);
        });

    Ok(())
}

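/// Check the sizes of `evals`, `query`, and `out` against `log_evals_size` and
/// `log_query_size`, returning a descriptive error on mismatch.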
#[inline]
fn check_fold_arguments<P, PE, POut>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < log_query_size {
        bail!(Error::IncorrectQuerySize {
            expected: log_evals_size
        });
    }

    if P::WIDTH * evals.len() < 1 << log_evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: 1 << log_evals_size
        });
    }

    if PE::WIDTH * query.len() < 1 << log_query_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "query".into(),
            expected: 1 << log_query_size
        });
    }

    if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: 1 << (log_evals_size - log_query_size)
        });
    }

    Ok(())
}

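/// Argument checks for the right lerp fold, where `evals_size` scalars fold into
/// `evals_size.div_ceil(2)` outputs.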
#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if P::WIDTH * evals.len() < evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: evals_size
        });
    }

    if PE::WIDTH * out.len() * 2 < evals_size {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: evals_size.div_ceil(2)
        });
    }

    Ok(())
}

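/// Argument checks for the left lerp folds, which materialize only the first
/// `non_const_prefix` scalars of the evaluations; the remainder is treated as a
/// constant suffix.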
#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    non_const_prefix: usize,
    log_evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size == 0 {
        bail!(Error::IncorrectQuerySize { expected: 1 });
    }

    if non_const_prefix > 1 << log_evals_size {
        bail!(Error::IncorrectNonzeroScalarPrefix {
            expected: 1 << log_evals_size,
        });
    }

    if P::WIDTH * evals.len() < non_const_prefix {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: non_const_prefix,
        });
    }

    let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

    if PE::WIDTH * out.len() < folded_non_const_prefix {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: folded_non_const_prefix,
        });
    }

    Ok(())
}

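/// Optimized right fold for 1-bit evaluations and queries of up to 4 elements
/// (`LOG_QUERY_SIZE < 3`). Precomputes all `2^(2^LOG_QUERY_SIZE)` partial sums of the
/// query elements so that each group of evaluation bits becomes a single table lookup.
/// Returns `false` when the preconditions do not hold and no folding was performed.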
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
        return false;
    }

    // Cache the sum of query elements for every possible combination of evaluation bits.
    let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
        .map(|i| {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << LOG_QUERY_SIZE {
                if i >> j & 1 == 1 {
                    result += get_packed_slice(query, j);
                }
            }
            result
        })
        .collect::<Vec<_>>();

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_table: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
            let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
            let mut current_index = 0;
            for byte in iterator {
                for k in 0..values_in_byte {
                    let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
                    // Safety: `current_index + k` is within the scalar length of `out`,
                    // which is guaranteed by the size checks in `fold_right`.
                    unsafe {
                        set_packed_slice_unchecked(
                            self.out,
                            current_index + k,
                            self.cached_table[index as usize],
                        );
                    }
                }

                current_index += values_in_byte;
            }
        }
    }

    let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
        out,
        cached_table: &cached_table,
    };

    iterate_bytes(evals, &mut callback);

    true
}

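/// Optimized right fold for 1-bit evaluations and queries of 8 or more elements
/// (`LOG_QUERY_SIZE >= 3`). Uses byte-indexed partial-sum lookup tables: each evaluation
/// byte selects a precomputed sum of query elements, and `2^(LOG_QUERY_SIZE - 3)` byte
/// lookups are accumulated per output scalar. Returns `false` when the preconditions do
/// not hold and no folding was performed.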
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE < 3 {
        return false;
    }

    if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
        return false;
    }

    let cached_tables =
        create_partial_sums_lookup_tables(PackedSlice::new_with_len(query, 1 << LOG_QUERY_SIZE));

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_tables: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let log_tables_count = LOG_QUERY_SIZE - 3;
            let tables_count = 1 << log_tables_count;
            let mut current_index = 0;
            let mut current_table = 0;
            let mut current_value = PE::Scalar::ZERO;
            for byte in iterator {
                current_value += self.cached_tables[(current_table << 8) + byte as usize];
                current_table += 1;

                if current_table == tables_count {
                    // Safety: `current_index` is within the scalar length of `out`,
                    // which is guaranteed by the size checks in `fold_right`.
                    unsafe {
                        set_packed_slice_unchecked(self.out, current_index, current_value);
                    }
                    current_index += 1;
                    current_table = 0;
                    current_value = PE::Scalar::ZERO;
                }
            }
        }
    }

    let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
        out,
        cached_tables: &cached_tables,
    };

    iterate_bytes(evals, &mut callback);

    true
}

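/// Try to use an optimized algorithm for folding 1-bit evaluations, dispatching on the
/// query size. Returns `true` if one of the specializations was applied, `false` if the
/// caller should fall back to the generic implementation.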
fn fold_right_1bit_evals<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
        return false;
    }

    // Pass the query size as a const generic so the compiler can specialize the tight
    // loops inside the implementations.
    match log_query_size {
        0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
        1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
        2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
        3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
        4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
        5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
        6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
        7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
        _ => false,
    }
}

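/// Fold the evaluations with a single variable via linear interpolation:
///
/// `out[i] = evals[2 * i] + lerp_query * (evals[2 * i + 1] - evals[2 * i])`
///
/// Only the first `evals_size` scalars of `evals` are folded; if `evals_size` is odd,
/// the last output interpolates the final evaluation against `suffix_eval`.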
pub fn fold_right_lerp<P, PE>(
    evals: &[P],
    evals_size: usize,
    lerp_query: PE::Scalar,
    suffix_eval: P::Scalar,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

    let folded_evals_size = evals_size >> 1;
    out[..folded_evals_size.div_ceil(PE::WIDTH)]
        .iter_mut()
        .enumerate()
        .for_each(|(i, packed_result_eval)| {
            for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
                let index = (i << PE::LOG_WIDTH) | j;

                let (eval0, eval1) = unsafe {
                    (
                        get_packed_slice_unchecked(evals, index << 1),
                        get_packed_slice_unchecked(evals, (index << 1) | 1),
                    )
                };

                let result_eval =
                    PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

                // Safety: `j` is within the packed width by the loop bound above.
                unsafe {
                    packed_result_eval.set_unchecked(j, result_eval);
                }
            }
        });

    if evals_size % 2 == 1 {
        let eval0 = get_packed_slice(evals, folded_evals_size << 1);
        set_packed_slice(
            out,
            folded_evals_size,
            PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
        );
    }

    Ok(())
}

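/// Left linear interpolation (lerp, single-variable) fold specialization.
///
/// Folds the highest variable of the multilinear:
///
/// `out[i] = evals[i] + lerp_query * (evals[i | half] - evals[i])`, where
/// `half = 1 << (log_evals_size - 1)`.
///
/// Only the first `non_const_prefix` scalars of `evals` are materialized; every scalar at
/// or beyond that index is assumed to equal `suffix_eval`. Outputs past the folded prefix
/// are zero-filled.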
pub fn fold_left_lerp<P>(
    evals: &[P],
    non_const_prefix: usize,
    suffix_eval: P::Scalar,
    log_evals_size: usize,
    lerp_query: P::Scalar,
    out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
    P: PackedField,
{
    check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

    if log_evals_size > P::LOG_WIDTH {
        let pivot = non_const_prefix
            .saturating_sub(1 << (log_evals_size - 1))
            .div_ceil(P::WIDTH);

        let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
        let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

        if pivot > 0 {
            let (evals_0, evals_1) = evals.split_at(packed_len);
            for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
                out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
            }
        }

        let broadcast_suffix_eval = P::broadcast(suffix_eval);
        for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
            out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
        }

        for out in &mut out[upper_bound..] {
            out.write(P::zero());
        }
    } else if non_const_prefix > 0 {
        let only_packed = *evals.first().expect("log_evals_size > 0");
        let mut folded = P::zero();

        for i in 0..1 << (log_evals_size - 1) {
            let eval_0 = only_packed.get(i);
            let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
            folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
        }

        out.first_mut().expect("log_evals_size > 0").write(folded);
    }

    Ok(())
}

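/// In-place variant of `fold_left_lerp`: folds the highest variable into the prefix of
/// `evals` and truncates the vector to the folded packed length.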
pub fn fold_left_lerp_inplace<P>(
    evals: &mut Vec<P>,
    non_const_prefix: usize,
    suffix_eval: P::Scalar,
    log_evals_size: usize,
    lerp_query: P::Scalar,
) -> Result<(), Error>
where
    P: PackedField,
{
    check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;

    if log_evals_size > P::LOG_WIDTH {
        let pivot = non_const_prefix
            .saturating_sub(1 << (log_evals_size - 1))
            .div_ceil(P::WIDTH);

        let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
        let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

        if pivot > 0 {
            let (evals_0, evals_1) = evals.split_at_mut(packed_len);
            for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
                *eval_0 += (*eval_1 - *eval_0) * lerp_query;
            }
        }

        let broadcast_suffix_eval = P::broadcast(suffix_eval);
        for eval in &mut evals[pivot..upper_bound] {
            *eval += (broadcast_suffix_eval - *eval) * lerp_query;
        }

        evals.truncate(upper_bound);
    } else if non_const_prefix > 0 {
        let only_packed = evals.first_mut().expect("log_evals_size > 0");
        let mut folded = P::zero();
        let half_size = 1 << (log_evals_size - 1);

        for i in 0..half_size {
            let eval_0 = only_packed.get(i);
            let eval_1 = only_packed.get(i | half_size);
            folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
        }

        *only_packed = folded;
    }

    Ok(())
}

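/// Generic fallback for the right fold: computes the dot product of each consecutive
/// `1 << log_query_size` evaluations with the query expansion, one output scalar at a time.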
fn fold_right_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    for (k, packed_result_eval) in out.iter_mut().enumerate() {
        for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
            let index = (k << PE::LOG_WIDTH) | j;

            let offset = index << log_query_size;

            let mut result_eval = PE::Scalar::ZERO;
            for (t, query_expansion) in PackedField::iter_slice(query)
                .take(1 << log_query_size)
                .enumerate()
            {
                result_eval += query_expansion * get_packed_slice(evals, t + offset);
            }

            // Safety: `j` is within the packed width by the loop bound above.
            unsafe {
                packed_result_eval.set_unchecked(j, result_eval);
            }
        }
    }
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
    TypeId::of::<ArchOptimalType<F>>()
}

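/// Specialized left fold for 1-bit evaluations and 128-bit query fields.
///
/// Expands each evaluation byte into packed bit masks via a precomputed lookup table and
/// accumulates `query_value & mask` into the output, avoiding per-bit field
/// multiplications. Returns `true` if `PE` matched one of the architecture-optimal
/// packed 128-bit field types and the fold was performed, `false` otherwise.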
fn fold_left_1b_128b<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
        return false;
    }

    let log_row_size = log_evals_size - log_query_size;
    if log_row_size < 3 {
        return false;
    }

    if PE::LOG_WIDTH > 3 {
        return false;
    }

    // Safety: `is_sequential_bytes::<P>()` guarantees that the evaluations are laid out
    // as a contiguous sequence of bytes in memory.
    let evals_u8: &[u8] = unsafe {
        std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
    };

    // Run the lookup-table implementation if `PE` is the architecture-optimal packed
    // type for the 128-bit field `F`; returns `false` on a type mismatch.
    #[inline]
    fn try_run_specialization<PE, F>(
        lookup_table: &[OptimalUnderlier],
        evals_u8: &[u8],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [MaybeUninit<PE>],
    ) -> bool
    where
        PE: PackedField,
        F: ArchOptimal,
    {
        if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
            let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
            let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

            fold_left_1b_128b_impl(
                lookup_table,
                evals_u8,
                log_evals_size,
                query,
                log_query_size,
                out,
            );
            true
        } else {
            false
        }
    }

    let lookup_table = &*LOOKUP_TABLE;
    try_run_specialization::<_, BinaryField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, AESTowerField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, BinaryField128bPolyval>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    )
}

/// Cast a mutable slice to another type, asserting at runtime that the types are identical.
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: `T` and `U` are the same type, as asserted above.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}

/// Cast a slice to another type, asserting at runtime that the types are identical.
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: `T` and `U` are the same type, as asserted above.
    unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}

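/// Initialize the lookup table for underliers at least 128 bits wide. For every possible
/// evaluation byte and every group of 128-bit lanes in the underlier, the table stores a
/// mask whose 128-bit lanes are all-ones exactly where the corresponding evaluation bit
/// is set.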
fn init_lookup_table_width<U>() -> Vec<U>
where
    U: UnderlierWithBitOps + From<u128>,
{
    let items_b128 = U::BITS / u128::BITS as usize;
    assert!(items_b128 <= 8);
    let items_in_byte = 8 / items_b128;

    let mut result = Vec::with_capacity(256 * items_in_byte);
    for i in 0..256 {
        for j in 0..items_in_byte {
            let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
            let mut value = U::ZERO;
            for k in 0..items_b128 {
                if (bits >> k) & 1 == 1 {
                    unsafe {
                        value.set_subvalue(k, u128::ONES);
                    }
                }
            }
            result.push(value);
        }
    }

    result
}

lazy_static! {
    static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}

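/// Inner loop of `fold_left_1b_128b`: for every query element, expand each byte of the
/// corresponding evaluation row into packed masks via the lookup table and accumulate
/// `query_value & mask` into the output.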
#[inline]
fn fold_left_1b_128b_impl<PE, U>(
    lookup_table: &[U],
    evals: &[u8],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    PE: PackedField + WithUnderlier<Underlier = U>,
    U: UnderlierWithBitOps,
{
    let out = unsafe { slice_assume_init_mut(out) };
    fill_zeroes(out);

    let items_in_byte = 8 / PE::WIDTH;
    let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
    for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
        let query_val = PE::broadcast(query_val).to_underlier();
        for (byte_index, byte) in row_bytes.iter().enumerate() {
            let mask_offset = *byte as usize * items_in_byte;
            let out_offset = byte_index * items_in_byte;
            for i in 0..items_in_byte {
                let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
                let multiplied = query_val & *mask;
                let out = unsafe { out.get_unchecked_mut(out_offset + i) };
                *out += PE::from_underlier(multiplied);
            }
        }
    }
}

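/// Generic fallback for the left fold: computes each output scalar as the dot product of
/// the query with the corresponding column of the evaluation matrix.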
fn fold_left_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            let eval_index = (query_index << new_n_vars)
                                | (outer_index << PE::LOG_WIDTH)
                                | inner_index;
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }

            out_val.write(res);
        });
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b,
        PackedBinaryField16x8b, PackedBinaryField32x1b, PackedBinaryField512x1b,
        PackedBinaryField64x8b, PackedBinaryField8x1b,
    };
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    fn fold_right_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result +=
                    get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_right<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
        let mut out = reference_out.clone();

        fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
        fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 0, &query, 0);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 2, &query, 1);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 7, &query, 7);
    }

    #[test]
    fn test_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
            .take(8)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_small_poly() {
        const LOG_EVALS_SIZE: usize = 5;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_many_evals() {
        const LOG_EVALS_SIZE: usize = 13;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_lower_higher_bits() {
        const LOG_EVALS_SIZE: usize = 7;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = [
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
        ];

        // Query [1, 0]: folding at start_index 4 selects the lower 16 bits of each
        // 32-bit block.
        let log_query_size = 1;
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000001u32,
        )];
        let mut out = vec![
            PackedBinaryField32x1b::zero();
            (1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
        ];
        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        for (i, out_val) in out.iter().enumerate() {
            let first_lower = (evals[2 * i].0 as u16) as u32;
            let second_lower = (evals[2 * i + 1].0 as u16) as u32;
            let expected_out = first_lower + (second_lower << 16);
            assert!(out_val.0 == expected_out);
        }

        // Query [0, 1]: folding at start_index 4 selects the higher 16 bits of each
        // 32-bit block.
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000010u32,
        )];

        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        for (i, out_val) in out.iter().enumerate() {
            let first_higher = evals[2 * i].0 >> 16;
            let second_higher = evals[2 * i + 1].0 >> 16;
            let expected_out = first_higher + (second_higher << 16);
            assert!(out_val.0 == expected_out);
        }
    }

    #[test]
    fn test_zeropad_bits() {
        const LOG_EVALS_SIZE: usize = 5;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = [
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
        ];

        let num_extra_vars = 2;
        let start_index = 3;

        for nonzero_index in 0..1 << num_extra_vars {
            let mut out =
                vec![
                    PackedBinaryField8x1b::zero();
                    1usize << (LOG_EVALS_SIZE + num_extra_vars - PackedBinaryField8x1b::LOG_WIDTH)
                ];
            zero_pad(
                &evals,
                LOG_EVALS_SIZE,
                LOG_EVALS_SIZE + num_extra_vars,
                start_index,
                nonzero_index,
                &mut out,
            )
            .unwrap();

            // Every original byte must land in the block selected by `nonzero_index`;
            // all other bytes must stay zero.
            for (i, out_val) in out.iter().enumerate() {
                let expansion = 1 << num_extra_vars;
                if i % expansion == nonzero_index {
                    assert!(out_val.0 == evals[i / expansion].0);
                } else {
                    assert!(out_val.0 == 0);
                }
            }
        }
    }

    fn fold_left_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result += get_packed_slice(query, j)
                    * get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_left<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

        let mut out = reference_out.clone();
        out.clear();
        fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 0, &query, 0);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 2, &query, 1);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 7, &query, 7);
    }

    #[test]
    fn test_fold_left_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = vec![PackedBinaryField512x1b::random(&mut rng)];

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    type B128bOptimal = ArchOptimalType<BinaryField128b>;

    #[test]
    fn test_fold_left_1b_128b_optimal() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << (10 - B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_fold_left_128b_128b() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << 10)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
        }
    }

    #[test]
    fn test_fold_left_lerp_inplace_conforms_reference() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();
        let lerp_query = <BinaryField128b as Field>::random(&mut rng);
        let mut query =
            vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
        set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
        set_packed_slice(&mut query, 1, lerp_query);

        for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
            let mut out = vec![
                MaybeUninit::uninit();
                1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
            ];
            fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
            fold_left_lerp_inplace(
                &mut evals,
                1 << log_evals_size,
                Field::ZERO,
                log_evals_size,
                lerp_query,
            )
            .unwrap();

            for (out, &inplace) in izip!(&out, &evals) {
                unsafe {
                    assert_eq!(out.assume_init(), inplace);
                }
            }
        }
    }

    #[test]
    fn test_check_fold_arguments_valid() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        // log_evals_size = 3, log_query_size = 2: all buffers are large enough.
        let result = check_fold_arguments(&evals, 3, &query, 2, &out);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_fold_arguments_invalid_query_size() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        // The query must not have more variables than the evaluations.
        let result = check_fold_arguments(&evals, 2, &query, 3, &out);
        assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
    }
}