use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
    arch::{ArchOptimal, OptimalUnderlier},
    byte_iteration::{
        can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes,
        ByteIteratorCallback, PackedSlice,
    },
    packed::{
        get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
    },
    underlier::{UnderlierWithBitOps, WithUnderlier},
    AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField,
    Field, PackedField,
};
use binius_utils::bail;
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;

use crate::Error;

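/// Fold the evaluations with the query over the low (right) variables:
/// `out[i] = sum_j query[j] * evals[(i << log_query_size) | j]`.
///
/// Dispatches to specialized implementations for 1-bit evaluations and for
/// linear-interpolation (lerp) queries before falling back to the generic path.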
pub fn fold_right<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    // Fast path: 1-bit evaluations can be folded byte by byte via lookup tables.
    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // A size-2 query whose scalars sum to one is a linear interpolation (lerp).
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
    } else {
        fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

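/// Fold the evaluations with the query over the high (left) variables:
/// `out[i] = sum_j query[j] * evals[i | (j << (log_evals_size - log_query_size))]`.
///
/// The output slice may be uninitialized; all of its elements are written.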
pub fn fold_left<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // The lerp shortcut below converts the query scalar to `P::Scalar`, so it is
    // only taken when `P` and `PE` are the same type.
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
        && TypeId::of::<P>() == TypeId::of::<PE>();

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        // Safety: `P == PE` was just checked, so both slice types have identical layout.
        let out_p =
            unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

        let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
        fold_left_lerp(
            evals,
            1 << log_evals_size,
            P::Scalar::ZERO,
            log_evals_size,
            lerp_query_p,
            out_p,
        )?;
    } else {
        fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

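/// Fold the evaluations with the query over the `log_query_size` variables starting
/// at `start_index`, keeping the `start_index` low variables in place:
/// `out[hi << start_index | lo] =
///     sum_j query[j] * evals[hi << (start_index + log_query_size) | j << start_index | lo]`.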
pub fn fold_middle<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    start_index: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if log_evals_size < log_query_size + start_index {
        bail!(Error::IncorrectStartIndex {
            expected: log_evals_size
        });
    }

    let lower_indices_size = 1 << start_index;
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            let outer_word_start = outer_index << PE::LOG_WIDTH;
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                let outer_scalar_index = outer_word_start + inner_index;
                // Split the output index into the low `start_index` bits and the rest.
                let inner_i = outer_scalar_index % lower_indices_size;
                let inner_j = outer_scalar_index / lower_indices_size;
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            // Evaluation index: high bits, then the query index,
                            // then the low bits.
                            let offset = inner_j << (start_index + log_query_size) | inner_i;
                            let eval_index = offset + (query_index << start_index);
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }
            out_val.write(res);
        });

    Ok(())
}

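/// Verify that `evals`, `query`, and `out` are large enough for the requested fold sizes.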
#[inline]
fn check_fold_arguments<P, PE, POut>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < log_query_size {
        bail!(Error::IncorrectQuerySize {
            expected: log_evals_size
        });
    }

    if P::WIDTH * evals.len() < 1 << log_evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: 1 << log_evals_size
        });
    }

    if PE::WIDTH * query.len() < 1 << log_query_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "query".into(),
            expected: 1 << log_query_size
        });
    }

    if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: 1 << (log_evals_size - log_query_size)
        });
    }

    Ok(())
}

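/// Verify the argument sizes for a right lerp fold, whose output holds half as many
/// scalars as `evals_size`.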
#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if P::WIDTH * evals.len() < evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: evals_size
        });
    }

    if PE::WIDTH * out.len() * 2 < evals_size {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: evals_size.div_ceil(2)
        });
    }

    Ok(())
}

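/// Verify the argument sizes for a left lerp fold. Only the first `non_const_prefix`
/// scalars must be backed by `evals`; the remaining scalars are treated as a constant
/// suffix.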
#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    non_const_prefix: usize,
    log_evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size == 0 {
        bail!(Error::IncorrectQuerySize { expected: 1 });
    }

    if non_const_prefix > 1 << log_evals_size {
        bail!(Error::IncorrectNonzeroScalarPrefix {
            expected: 1 << log_evals_size,
        });
    }

    if P::WIDTH * evals.len() < non_const_prefix {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: non_const_prefix,
        });
    }

    let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

    if PE::WIDTH * out.len() < folded_non_const_prefix {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: folded_non_const_prefix,
        });
    }

    Ok(())
}
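
/// Specialized right fold of 1-bit evaluations for queries of fewer than 8 scalars.
///
/// Precomputes the folded value for each of the `2^(2^LOG_QUERY_SIZE)` possible bit
/// patterns and then processes the evaluations byte by byte. Returns `false` if the
/// inputs do not fit this specialization.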
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
        return false;
    }

    // Cache the folded value for every possible bit pattern of query size.
    let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
        .map(|i| {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << LOG_QUERY_SIZE {
                if i >> j & 1 == 1 {
                    result += get_packed_slice(query, j);
                }
            }
            result
        })
        .collect::<Vec<_>>();

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_table: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
            let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
            let mut current_index = 0;
            for byte in iterator {
                for k in 0..values_in_byte {
                    let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
                    // Safety: the caller guarantees that `out` can hold the folded result.
                    unsafe {
                        set_packed_slice_unchecked(
                            self.out,
                            current_index + k,
                            self.cached_table[index as usize],
                        );
                    }
                }

                current_index += values_in_byte;
            }
        }
    }

    let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
        out,
        cached_table: &cached_table,
    };

    iterate_bytes(evals, &mut callback);

    true
}
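
/// Specialized right fold of 1-bit evaluations for queries of at least 8 scalars.
///
/// Uses partial-sum lookup tables, one per byte of the query, accumulating one output
/// scalar from every `2^(LOG_QUERY_SIZE - 3)` bytes of evaluations. Returns `false`
/// if the inputs do not fit this specialization.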
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE < 3 {
        return false;
    }

    if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
        return false;
    }

    // One partial-sums lookup table per byte of the query.
    let cached_tables =
        create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE));

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_tables: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let log_tables_count = LOG_QUERY_SIZE - 3;
            let tables_count = 1 << log_tables_count;
            let mut current_index = 0;
            let mut current_table = 0;
            let mut current_value = PE::Scalar::ZERO;
            for byte in iterator {
                current_value += self.cached_tables[(current_table << 8) + byte as usize];
                current_table += 1;

                if current_table == tables_count {
                    // Safety: the caller guarantees that `out` can hold the folded result.
                    unsafe {
                        set_packed_slice_unchecked(self.out, current_index, current_value);
                    }
                    current_index += 1;
                    current_table = 0;
                    current_value = PE::Scalar::ZERO;
                }
            }
        }
    }

    let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
        out,
        cached_tables: &cached_tables,
    };

    iterate_bytes(evals, &mut callback);

    true
}
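
/// Attempt a specialized right fold of 1-bit evaluations.
///
/// Returns `true` if one of the byte-iteration specializations handled the fold,
/// `false` if the caller must use the generic fallback.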
fn fold_right_1bit_evals<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
        return false;
    }

    // Dispatch on the query size so that it becomes a compile-time constant inside
    // the specialized implementations.
    match log_query_size {
        0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
        1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
        2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
        3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
        4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
        5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
        6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
        7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
        _ => false,
    }
}
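
/// Right fold with a lerp query:
/// `out[i] = evals[2*i] + (evals[2*i + 1] - evals[2*i]) * lerp_query`.
///
/// `evals_size` is the number of meaningful scalars in `evals`. If it is odd, the
/// missing second element of the last pair is taken to be `suffix_eval`.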
pub fn fold_right_lerp<P, PE>(
    evals: &[P],
    evals_size: usize,
    lerp_query: PE::Scalar,
    suffix_eval: P::Scalar,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

    let folded_evals_size = evals_size >> 1;
    out[..folded_evals_size.div_ceil(PE::WIDTH)]
        .iter_mut()
        .enumerate()
        .for_each(|(i, packed_result_eval)| {
            for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
                let index = (i << PE::LOG_WIDTH) | j;

                // Safety: `(index << 1) | 1` is less than `evals_size`, which was
                // checked against the length of `evals`.
                let (eval0, eval1) = unsafe {
                    (
                        get_packed_slice_unchecked(evals, index << 1),
                        get_packed_slice_unchecked(evals, (index << 1) | 1),
                    )
                };

                let result_eval =
                    PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

                // Safety: `j` is within the packed field width.
                unsafe {
                    packed_result_eval.set_unchecked(j, result_eval);
                }
            }
        });

    if evals_size % 2 == 1 {
        // Odd tail: pair the last evaluation with the constant suffix value.
        let eval0 = get_packed_slice(evals, folded_evals_size << 1);
        set_packed_slice(
            out,
            folded_evals_size,
            PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
        );
    }

    Ok(())
}
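
/// Left fold with a lerp query, halving the highest variable:
/// `out[i] = evals[i] + (evals[i + half] - evals[i]) * lerp_query` with
/// `half = 1 << (log_evals_size - 1)`.
///
/// Scalars at positions `non_const_prefix` and beyond are not read from `evals`; they
/// are assumed to equal `suffix_eval`. Packed values of `out` past the folded
/// non-constant prefix are zero-filled.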
pub fn fold_left_lerp<P>(
    evals: &[P],
    non_const_prefix: usize,
    suffix_eval: P::Scalar,
    log_evals_size: usize,
    lerp_query: P::Scalar,
    out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
    P: PackedField,
{
    check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

    if log_evals_size > P::LOG_WIDTH {
        // Number of packed values whose upper-half counterpart still contains
        // non-constant scalars.
        let pivot = non_const_prefix
            .saturating_sub(1 << (log_evals_size - 1))
            .div_ceil(P::WIDTH);

        let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
        let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

        if pivot > 0 {
            let (evals_0, evals_1) = evals.split_at(packed_len);
            for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
                out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
            }
        }

        // Lerp the remaining non-constant values against the constant suffix.
        let broadcast_suffix_eval = P::broadcast(suffix_eval);
        for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
            out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
        }

        for out in &mut out[upper_bound..] {
            out.write(P::zero());
        }
    } else if non_const_prefix > 0 {
        // All evaluations fit in a single packed value.
        let only_packed = *evals.first().expect("log_evals_size > 0");
        let mut folded = P::zero();

        for i in 0..1 << (log_evals_size - 1) {
            let eval_0 = only_packed.get(i);
            let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
            folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
        }

        out.first_mut().expect("log_evals_size > 0").write(folded);
    }

    Ok(())
}
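
/// In-place variant of [`fold_left_lerp`] that folds the highest variable of `evals`
/// with a lerp query and truncates the vector to the folded non-constant prefix.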
586
587pub fn fold_left_lerp_inplace<P>(
594 evals: &mut Vec<P>,
595 non_const_prefix: usize,
596 suffix_eval: P::Scalar,
597 log_evals_size: usize,
598 lerp_query: P::Scalar,
599) -> Result<(), Error>
600where
601 P: PackedField,
602{
603 check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;
604
605 if log_evals_size > P::LOG_WIDTH {
606 let pivot = non_const_prefix
607 .saturating_sub(1 << (log_evals_size - 1))
608 .div_ceil(P::WIDTH);
609
610 let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
611 let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);
612
613 if pivot > 0 {
614 let (evals_0, evals_1) = evals.split_at_mut(packed_len);
615 for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
616 *eval_0 += (*eval_1 - *eval_0) * lerp_query;
617 }
618 }
619
620 let broadcast_suffix_eval = P::broadcast(suffix_eval);
621 for eval in &mut evals[pivot..upper_bound] {
622 *eval += (broadcast_suffix_eval - *eval) * lerp_query;
623 }
624
625 evals.truncate(upper_bound);
626 } else if non_const_prefix > 0 {
627 let only_packed = evals.first_mut().expect("log_evals_size > 0");
628 let mut folded = P::zero();
629 let half_size = 1 << (log_evals_size - 1);
630
631 for i in 0..half_size {
632 let eval_0 = only_packed.get(i);
633 let eval_1 = only_packed.get(i | half_size);
634 folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
635 }
636
637 *only_packed = folded;
638 }
639
640 Ok(())
641}
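
/// Generic scalar-by-scalar fallback for the right fold.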
fn fold_right_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    for (k, packed_result_eval) in out.iter_mut().enumerate() {
        for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
            let index = (k << PE::LOG_WIDTH) | j;

            let offset = index << log_query_size;

            let mut result_eval = PE::Scalar::ZERO;
            for (t, query_expansion) in PackedField::iter_slice(query)
                .take(1 << log_query_size)
                .enumerate()
            {
                result_eval += query_expansion * get_packed_slice(evals, t + offset);
            }

            // Safety: `j` is within the packed field width.
            unsafe {
                packed_result_eval.set_unchecked(j, result_eval);
            }
        }
    }
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
    TypeId::of::<ArchOptimalType<F>>()
}
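
/// Specialized left fold of 1-bit evaluations by a query over one of the 128-bit
/// fields with an architecture-optimal packed representation.
///
/// Interprets the evaluations as raw bytes and, via a lookup table of bit masks,
/// multiplies the broadcast query value by each evaluation bit. Returns `false` if
/// the inputs do not fit this specialization.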
fn fold_left_1b_128b<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
        return false;
    }

    let log_row_size = log_evals_size - log_query_size;
    if log_row_size < 3 {
        return false;
    }

    if PE::LOG_WIDTH > 3 {
        return false;
    }

    // Safety: `is_sequential_bytes::<P>()` guarantees that `P` is laid out as a
    // sequence of bytes, so viewing the slice as bytes is valid.
    let evals_u8: &[u8] = unsafe {
        std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
    };

    // Run the monomorphized implementation when `PE` is the arch-optimal packed
    // type for the 128-bit field `F`.
    #[inline]
    fn try_run_specialization<PE, F>(
        lookup_table: &[OptimalUnderlier],
        evals_u8: &[u8],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [MaybeUninit<PE>],
    ) -> bool
    where
        PE: PackedField,
        F: ArchOptimal,
    {
        if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
            let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
            let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

            fold_left_1b_128b_impl(
                lookup_table,
                evals_u8,
                log_evals_size,
                query,
                log_query_size,
                out,
            );
            true
        } else {
            false
        }
    }

    let lookup_table = &*LOOKUP_TABLE;
    try_run_specialization::<_, BinaryField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, AESTowerField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, BinaryField128bPolyval>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    )
}
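
/// Cast a mutable slice between two types that are checked at runtime to be identical.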
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: the assertion above guarantees that `T` and `U` are the same type.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}
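
/// Cast a slice between two types that are checked at runtime to be identical.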
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: the assertion above guarantees that `T` and `U` are the same type.
    unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}
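
/// Build the lookup table used by `fold_left_1b_128b_impl`: for every byte of 1-bit
/// evaluations, masks whose 128-bit lanes are all-ones where the corresponding
/// evaluation bit is set and zero elsewhere.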
fn init_lookup_table_width<U>() -> Vec<U>
where
    U: UnderlierWithBitOps + From<u128>,
{
    let items_b128 = U::BITS / u128::BITS as usize;
    assert!(items_b128 <= 8);
    let items_in_byte = 8 / items_b128;

    let mut result = Vec::with_capacity(256 * items_in_byte);
    for i in 0..256 {
        for j in 0..items_in_byte {
            let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
            let mut value = U::ZERO;
            for k in 0..items_b128 {
                if (bits >> k) & 1 == 1 {
                    // Safety: `k < items_b128`, so the 128-bit subvalue index is in range.
                    unsafe {
                        value.set_subvalue(k, u128::ONES);
                    }
                }
            }
            result.push(value);
        }
    }

    result
}

lazy_static! {
    static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}
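
/// Inner loop of `fold_left_1b_128b`: accumulates the masked broadcast query values
/// into the output, one evaluation byte at a time.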
#[inline]
fn fold_left_1b_128b_impl<PE, U>(
    lookup_table: &[U],
    evals: &[u8],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    PE: PackedField + WithUnderlier<Underlier = U>,
    U: UnderlierWithBitOps,
{
    // Treat the output as initialized and zero-fill it before accumulation.
    let out = unsafe { slice_assume_init_mut(out) };
    fill_zeroes(out);

    let items_in_byte = 8 / PE::WIDTH;
    let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
    for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
        let query_val = PE::broadcast(query_val).to_underlier();
        for (byte_index, byte) in row_bytes.iter().enumerate() {
            let mask_offset = *byte as usize * items_in_byte;
            let out_offset = byte_index * items_in_byte;
            for i in 0..items_in_byte {
                // Mask out the broadcast query value wherever the evaluation bit is set.
                let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
                let multiplied = query_val & *mask;
                let out = unsafe { out.get_unchecked_mut(out_offset + i) };
                *out += PE::from_underlier(multiplied);
            }
        }
    }
}
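
/// Generic scalar-by-scalar fallback for the left fold.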
fn fold_left_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            let eval_index = (query_index << new_n_vars)
                                | (outer_index << PE::LOG_WIDTH)
                                | inner_index;
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }

            out_val.write(res);
        });
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b,
        PackedBinaryField16x8b, PackedBinaryField32x1b, PackedBinaryField512x1b,
        PackedBinaryField64x8b,
    };
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    fn fold_right_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result +=
                    get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_right<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
        let mut out = reference_out.clone();

        fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
        fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 0, &query, 0);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 2, &query, 1);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 7, &query, 7);
    }

    #[test]
    fn test_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
            .take(8)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_small_poly() {
        const LOG_EVALS_SIZE: usize = 5;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_many_evals() {
        const LOG_EVALS_SIZE: usize = 13;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_lower_higher_bits() {
        const LOG_EVALS_SIZE: usize = 7;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = [
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
        ];

        let log_query_size = 1;
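        // A query equal to the first basis vector selects the lower 16 bits of each
        // pair of 32-bit words.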
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000001u32,
        )];
        let mut out = vec![
            PackedBinaryField32x1b::zero();
            (1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
        ];
        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        for (i, out_val) in out.iter().enumerate() {
            // Each output word combines the lower halves of two consecutive words.
            let first_lower = (evals[2 * i].0 as u16) as u32;
            let second_lower = (evals[2 * i + 1].0 as u16) as u32;
            let expected_out = first_lower + (second_lower << 16);
            assert!(out_val.0 == expected_out);
        }

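        // A query equal to the second basis vector selects the upper 16 bits instead.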
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000010u32,
        )];

        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        for (i, out_val) in out.iter().enumerate() {
            // Each output word combines the upper halves of two consecutive words.
            let first_higher = evals[2 * i].0 >> 16;
            let second_higher = evals[2 * i + 1].0 >> 16;
            let expected_out = first_higher + (second_higher << 16);
            assert!(out_val.0 == expected_out);
        }
    }

    fn fold_left_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result += get_packed_slice(query, j)
                    * get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_left<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

        let mut out = reference_out.clone();
        out.clear();
        fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 0, &query, 0);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 2, &query, 1);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 7, &query, 7);
    }

    #[test]
    fn test_fold_left_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = vec![PackedBinaryField512x1b::random(&mut rng)];

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    type B128bOptimal = ArchOptimalType<BinaryField128b>;

    #[test]
    fn test_fold_left_1b_128b_optimal() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << (10 - B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_fold_left_128b_128b() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << 10)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
        }
    }

    #[test]
    fn test_fold_left_lerp_inplace_conforms_reference() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();
        let lerp_query = <BinaryField128b as Field>::random(&mut rng);
        let mut query =
            vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
        set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
        set_packed_slice(&mut query, 1, lerp_query);

        for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
            let mut out = vec![
                MaybeUninit::uninit();
                1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
            ];
            fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
            fold_left_lerp_inplace(
                &mut evals,
                1 << log_evals_size,
                Field::ZERO,
                log_evals_size,
                lerp_query,
            )
            .unwrap();

            for (out, &inplace) in izip!(&out, &evals) {
                unsafe {
                    assert_eq!(out.assume_init(), inplace);
                }
            }
        }
    }

    #[test]
    fn test_check_fold_arguments_valid() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        let result = check_fold_arguments(&evals, 3, &query, 2, &out);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_fold_arguments_invalid_query_size() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        let result = check_fold_arguments(&evals, 2, &query, 3, &out);
        assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
    }
}
1289}