use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};

use binius_field::{
    AESTowerField128b, BinaryField1b, BinaryField128b, BinaryField128bPolyval, ExtensionField,
    Field, PackedField,
    arch::{ArchOptimal, OptimalUnderlier},
    byte_iteration::{
        ByteIteratorCallback, can_iterate_bytes, create_partial_sums_lookup_tables,
        is_sequential_bytes, iterate_bytes,
    },
    packed::{
        PackedSlice, get_packed_slice, get_packed_slice_unchecked, set_packed_slice,
        set_packed_slice_unchecked,
    },
    underlier::{UnderlierWithBitOps, WithUnderlier},
};
use binius_utils::{bail, mem::slice_assume_init_mut};
use bytemuck::fill_zeroes;
use itertools::izip;
use lazy_static::lazy_static;

use crate::Error;

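/// Pad a multilinear with zeros: embed `evals`, a multilinear over `log_evals_size` variables,
/// into `out`, a multilinear over `log_new_vals` variables, by inserting
/// `log_new_vals - log_evals_size` extra variables at position `start_index`. Only the block
/// where the extra variables evaluate to `nonzero_index` is written with the original
/// evaluations; all other positions of `out` are left untouched, so the caller should pass a
/// zero-initialized buffer.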
pub fn zero_pad<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    log_new_vals: usize,
    start_index: usize,
    nonzero_index: usize,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if start_index > log_evals_size {
        bail!(Error::IncorrectStartIndex {
            expected: log_evals_size
        });
    }

    let extra_vars = log_new_vals - log_evals_size;
    if nonzero_index > 1 << extra_vars {
        bail!(Error::IncorrectNonZeroIndex {
            expected: 1 << extra_vars
        });
    }

    // Each row of `1 << start_index` evaluations is embedded into a padded row of
    // `1 << (start_index + extra_vars)` evaluations in `out`.
    let block_size = 1 << start_index;
    let nb_elts_per_row = 1 << (start_index + extra_vars);
    let length = 1 << start_index;

    let start_index_in_row = nonzero_index * block_size;
    for (outer_index, packed_result_eval) in out.iter_mut().enumerate() {
        // Index of the scalar at the start of the current packed element of `out`.
        let outer_word_start = outer_index << PE::LOG_WIDTH;
        for inner_index in 0..min(PE::WIDTH, 1 << log_evals_size) {
            let outer_scalar_index = outer_word_start + inner_index;
            // Column within the padded row, and the index of that row.
            let inner_col = outer_scalar_index % nb_elts_per_row;
            let inner_row = outer_scalar_index / nb_elts_per_row;

            if inner_col >= start_index_in_row && inner_col < start_index_in_row + block_size {
                let eval_index = inner_row * length + inner_col - start_index_in_row;
                let result_eval = get_packed_slice(evals, eval_index).into();
                unsafe {
                    packed_result_eval.set_unchecked(inner_index, result_eval);
                }
            }
        }
    }
    Ok(())
}

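/// Execute the right fold operation.
///
/// Every consecutive `1 << log_query_size` scalar values of `evals` are dot-multiplied with the
/// query coefficients:
///
/// `out[i] = sum_j evals[(i << log_query_size) | j] * query[j]`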
pub fn fold_right<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    // Use the bytewise specialization for 1-bit evaluations whenever possible.
    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // A query of the form `(1 - r, r)` is a linear interpolation and has a cheaper path.
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        fold_right_lerp(evals, 1 << log_evals_size, lerp_query, P::Scalar::ZERO, out)?;
    } else {
        fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

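/// Execute the left fold operation.
///
/// The query is folded against the high `log_query_size` variables of `evals`:
///
/// `out[i] = sum_j evals[i | (j << (log_evals_size - log_query_size))] * query[j]`
///
/// `out` may be uninitialized; every element of it is written.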
pub fn fold_left<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
        && fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
    {
        return Ok(());
    }

    // A lerp query `(1 - r, r)` has a cheaper path, but only when `P` and `PE` coincide.
    let is_lerp = log_query_size == 1
        && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE
        && TypeId::of::<P>() == TypeId::of::<PE>();

    if is_lerp {
        let lerp_query = get_packed_slice(query, 1);
        // Safety: `P` and `PE` are verified above to be the same type, so the slices have
        // identical layouts.
        let out_p =
            unsafe { std::mem::transmute::<&mut [MaybeUninit<PE>], &mut [MaybeUninit<P>]>(out) };

        let lerp_query_p = lerp_query.try_into().ok().expect("P == PE");
        fold_left_lerp(
            evals,
            1 << log_evals_size,
            P::Scalar::ZERO,
            log_evals_size,
            lerp_query_p,
            out_p,
        )?;
    } else {
        fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
    }

    Ok(())
}

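/// Execute the middle fold operation.
///
/// The `log_query_size` variables starting at `start_index` are folded with the query while the
/// lower and upper variables stay in place:
///
/// `out[(j << start_index) | i] =
///     sum_k query[k] * evals[(j << (start_index + log_query_size)) | (k << start_index) | i]`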
pub fn fold_middle<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    start_index: usize,
    out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;

    if log_evals_size < log_query_size + start_index {
        bail!(Error::IncorrectStartIndex {
            expected: log_evals_size
        });
    }

    let lower_indices_size = 1 << start_index;
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            // Index of the scalar at the start of the current packed element of `out`.
            let outer_word_start = outer_index << PE::LOG_WIDTH;
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                let outer_scalar_index = outer_word_start + inner_index;
                // Split the output index into the lower bits `inner_i` (below `start_index`)
                // and the upper bits `inner_j`.
                let inner_i = outer_scalar_index % lower_indices_size;
                let inner_j = outer_scalar_index / lower_indices_size;
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            // The folded variables occupy the bits between `start_index` and
                            // `start_index + log_query_size` of the evaluation index.
                            let offset = inner_j << (start_index + log_query_size) | inner_i;
                            let eval_index = offset + (query_index << start_index);
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }
            out_val.write(res);
        });

    Ok(())
}

#[inline]
fn check_fold_arguments<P, PE, POut>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < log_query_size {
        bail!(Error::IncorrectQuerySize {
            expected: log_evals_size,
            actual: log_query_size,
        });
    }

    if P::WIDTH * evals.len() < 1 << log_evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: 1 << log_evals_size
        });
    }

    if PE::WIDTH * query.len() < 1 << log_query_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "query".into(),
            expected: 1 << log_query_size
        });
    }

    if PE::WIDTH * out.len() < 1 << (log_evals_size - log_query_size) {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: 1 << (log_evals_size - log_query_size)
        });
    }

    Ok(())
}

#[inline]
fn check_right_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if P::WIDTH * evals.len() < evals_size {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: evals_size
        });
    }

    if PE::WIDTH * out.len() * 2 < evals_size {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: evals_size.div_ceil(2)
        });
    }

    Ok(())
}

#[inline]
fn check_left_lerp_fold_arguments<P, PE, POut>(
    evals: &[P],
    non_const_prefix: usize,
    log_evals_size: usize,
    out: &[POut],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size == 0 {
        bail!(Error::IncorrectQuerySize {
            expected: 1,
            actual: log_evals_size,
        });
    }

    if non_const_prefix > 1 << log_evals_size {
        bail!(Error::IncorrectNonzeroScalarPrefix {
            expected: 1 << log_evals_size,
        });
    }

    if P::WIDTH * evals.len() < non_const_prefix {
        bail!(Error::IncorrectArgumentLength {
            arg: "evals".into(),
            expected: non_const_prefix,
        });
    }

    let folded_non_const_prefix = non_const_prefix.min(1 << (log_evals_size - 1));

    if PE::WIDTH * out.len() < folded_non_const_prefix {
        bail!(Error::IncorrectOutputPolynomialSize {
            expected: folded_non_const_prefix,
        });
    }

    Ok(())
}

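/// Right fold for 1-bit evaluations with a small query (`LOG_QUERY_SIZE < 3`).
///
/// Precomputes the folded value for every possible bit pattern of a query-sized group (a table
/// of `1 << (1 << LOG_QUERY_SIZE)` entries), then walks the evaluations byte by byte, using each
/// group of `1 << LOG_QUERY_SIZE` bits as a table index. Returns `false` if this strategy does
/// not apply.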
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE >= 3 || (P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH) {
        return false;
    }

    // Cache the folded value for every subset of the query elements.
    let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
        .map(|i| {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << LOG_QUERY_SIZE {
                if i >> j & 1 == 1 {
                    result += get_packed_slice(query, j);
                }
            }
            result
        })
        .collect::<Vec<_>>();

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_table: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
            let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
            let mut current_index = 0;
            for byte in iterator {
                for k in 0..values_in_byte {
                    let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
                    // Safety: the scalar index is within the bounds validated by
                    // `check_fold_arguments`.
                    unsafe {
                        set_packed_slice_unchecked(
                            self.out,
                            current_index + k,
                            self.cached_table[index as usize],
                        );
                    }
                }

                current_index += values_in_byte;
            }
        }
    }

    let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
        out,
        cached_table: &cached_table,
    };

    iterate_bytes(evals, &mut callback);

    true
}

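/// Right fold for 1-bit evaluations with a medium query (`3 <= LOG_QUERY_SIZE <= 7`).
///
/// Uses partial-sums lookup tables: the query expansion is split into
/// `1 << (LOG_QUERY_SIZE - 3)` groups of eight elements with a 256-entry table of subset sums
/// per group, so each evaluation byte costs a single table lookup. Returns `false` if this
/// strategy does not apply.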
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
    evals: &[P],
    query: &[PE],
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if LOG_QUERY_SIZE < 3 {
        return false;
    }

    let cached_tables =
        create_partial_sums_lookup_tables(PackedSlice::new_with_len(query, 1 << LOG_QUERY_SIZE));

    struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
        out: &'a mut [PE],
        cached_tables: &'a [PE::Scalar],
    }

    impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
        for Callback<'_, PE, LOG_QUERY_SIZE>
    {
        #[inline(always)]
        fn call(&mut self, iterator: impl Iterator<Item = u8>) {
            let log_tables_count = LOG_QUERY_SIZE - 3;
            let tables_count = 1 << log_tables_count;
            let mut current_index = 0;
            let mut current_table = 0;
            let mut current_value = PE::Scalar::ZERO;
            for byte in iterator {
                current_value += self.cached_tables[(current_table << 8) + byte as usize];
                current_table += 1;

                if current_table == tables_count {
                    // Safety: the scalar index is within the bounds validated by
                    // `check_fold_arguments`.
                    unsafe {
                        set_packed_slice_unchecked(self.out, current_index, current_value);
                    }
                    current_index += 1;
                    current_table = 0;
                    current_value = PE::Scalar::ZERO;
                }
            }
        }
    }

    let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
        out,
        cached_tables: &cached_tables,
    };

    iterate_bytes(evals, &mut callback);

    true
}

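/// Try to execute the right fold over 1-bit evaluations with a bytewise specialization.
///
/// Returns `true` if the fold was performed, `false` if the preconditions were not met and the
/// caller should fall back to a generic implementation.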
fn fold_right_1bit_evals<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !can_iterate_bytes::<P>() {
        return false;
    }

    // Dispatch to the small- or medium-query strategy with a const query size.
    match log_query_size {
        0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
        1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
        2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
        3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
        4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
        5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
        6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
        7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
        _ => false,
    }
}

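/// Right fold with a single lerp query:
///
/// `out[i] = evals[2 * i] + (evals[2 * i + 1] - evals[2 * i]) * lerp_query`
///
/// Only the first `evals_size` scalars of `evals` are folded. If `evals_size` is odd, the
/// missing right operand of the last pair is taken to be `suffix_eval`, which lets callers fold
/// a prefix of a multilinear whose suffix is constant.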
pub fn fold_right_lerp<P, PE>(
    evals: &[P],
    evals_size: usize,
    lerp_query: PE::Scalar,
    suffix_eval: P::Scalar,
    out: &mut [PE],
) -> Result<(), Error>
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    check_right_lerp_fold_arguments::<_, PE, _>(evals, evals_size, out)?;

    let folded_evals_size = evals_size >> 1;
    out[..folded_evals_size.div_ceil(PE::WIDTH)]
        .iter_mut()
        .enumerate()
        .for_each(|(i, packed_result_eval)| {
            for j in 0..min(PE::WIDTH, folded_evals_size - (i << PE::LOG_WIDTH)) {
                let index = (i << PE::LOG_WIDTH) | j;

                // Safety: `(index << 1) | 1 < evals_size`, which the argument check
                // guarantees to be within `evals`.
                let (eval0, eval1) = unsafe {
                    (
                        get_packed_slice_unchecked(evals, index << 1),
                        get_packed_slice_unchecked(evals, (index << 1) | 1),
                    )
                };

                let result_eval =
                    PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);

                unsafe {
                    packed_result_eval.set_unchecked(j, result_eval);
                }
            }
        });

    // An odd-sized fold pairs the last evaluation with the constant suffix value.
    if evals_size % 2 == 1 {
        let eval0 = get_packed_slice(evals, folded_evals_size << 1);
        set_packed_slice(
            out,
            folded_evals_size,
            PE::Scalar::from(suffix_eval - eval0) * lerp_query + PE::Scalar::from(eval0),
        );
    }

    Ok(())
}

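/// Left fold with a single lerp query, interpolating between the lower and upper halves of the
/// hypercube:
///
/// `out[i] = evals[i] + (evals[i | half] - evals[i]) * lerp_query`, with
/// `half = 1 << (log_evals_size - 1)`.
///
/// Only the first `non_const_prefix` scalars of `evals` are meaningful; past that, the
/// multilinear is treated as the constant `suffix_eval`. Output elements beyond the folded
/// non-constant prefix are written as zeros; the folded suffix value is `suffix_eval` itself
/// (lerping a constant with itself is the identity), which callers can track separately.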
pub fn fold_left_lerp<P>(
    evals: &[P],
    non_const_prefix: usize,
    suffix_eval: P::Scalar,
    log_evals_size: usize,
    lerp_query: P::Scalar,
    out: &mut [MaybeUninit<P>],
) -> Result<(), Error>
where
    P: PackedField,
{
    check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, out)?;

    if log_evals_size > P::LOG_WIDTH {
        // Packed elements where both lerp operands come from the non-constant prefix.
        let pivot = non_const_prefix
            .saturating_sub(1 << (log_evals_size - 1))
            .div_ceil(P::WIDTH);

        let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
        let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

        if pivot > 0 {
            let (evals_0, evals_1) = evals.split_at(packed_len);
            for (out, eval_0, eval_1) in izip!(&mut out[..pivot], evals_0, evals_1) {
                out.write(*eval_0 + (*eval_1 - *eval_0) * lerp_query);
            }
        }

        // Lerp the tail of the non-constant prefix against the constant suffix.
        let broadcast_suffix_eval = P::broadcast(suffix_eval);
        for (out, eval) in izip!(&mut out[pivot..upper_bound], &evals[pivot..]) {
            out.write(*eval + (broadcast_suffix_eval - *eval) * lerp_query);
        }

        for out in &mut out[upper_bound..] {
            out.write(P::zero());
        }
    } else if non_const_prefix > 0 {
        let only_packed = *evals.first().expect("log_evals_size > 0");
        let mut folded = P::zero();

        for i in 0..1 << (log_evals_size - 1) {
            let eval_0 = only_packed.get(i);
            let eval_1 = only_packed.get(i | 1 << (log_evals_size - 1));
            folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
        }

        out.first_mut().expect("log_evals_size > 0").write(folded);
    }

    Ok(())
}

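/// In-place variant of [`fold_left_lerp`]: folds `evals` with a single lerp query and truncates
/// the vector to the folded non-constant prefix.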
pub fn fold_left_lerp_inplace<P>(
    evals: &mut Vec<P>,
    non_const_prefix: usize,
    suffix_eval: P::Scalar,
    log_evals_size: usize,
    lerp_query: P::Scalar,
) -> Result<(), Error>
where
    P: PackedField,
{
    check_left_lerp_fold_arguments::<_, P, _>(evals, non_const_prefix, log_evals_size, evals)?;

    if log_evals_size > P::LOG_WIDTH {
        // Packed elements where both lerp operands come from the non-constant prefix.
        let pivot = non_const_prefix
            .saturating_sub(1 << (log_evals_size - 1))
            .div_ceil(P::WIDTH);

        let packed_len = 1 << (log_evals_size - 1 - P::LOG_WIDTH);
        let upper_bound = non_const_prefix.div_ceil(P::WIDTH).min(packed_len);

        if pivot > 0 {
            let (evals_0, evals_1) = evals.split_at_mut(packed_len);
            for (eval_0, eval_1) in izip!(&mut evals_0[..pivot], evals_1) {
                *eval_0 += (*eval_1 - *eval_0) * lerp_query;
            }
        }

        // Lerp the tail of the non-constant prefix against the constant suffix.
        let broadcast_suffix_eval = P::broadcast(suffix_eval);
        for eval in &mut evals[pivot..upper_bound] {
            *eval += (broadcast_suffix_eval - *eval) * lerp_query;
        }

        evals.truncate(upper_bound);
    } else if non_const_prefix > 0 {
        let only_packed = evals.first_mut().expect("log_evals_size > 0");
        let mut folded = P::zero();
        let half_size = 1 << (log_evals_size - 1);

        for i in 0..half_size {
            let eval_0 = only_packed.get(i);
            let eval_1 = only_packed.get(i | half_size);
            folded.set(i, eval_0 + lerp_query * (eval_1 - eval_0));
        }

        *only_packed = folded;
    }

    Ok(())
}

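/// Portable scalar fallback for the right fold, used when no specialization applies.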
fn fold_right_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [PE],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    for (k, packed_result_eval) in out.iter_mut().enumerate() {
        for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
            let index = (k << PE::LOG_WIDTH) | j;

            let offset = index << log_query_size;

            let mut result_eval = PE::Scalar::ZERO;
            for (t, query_expansion) in PackedField::iter_slice(query)
                .take(1 << log_query_size)
                .enumerate()
            {
                result_eval += query_expansion * get_packed_slice(evals, t + offset);
            }

            unsafe {
                packed_result_eval.set_unchecked(j, result_eval);
            }
        }
    }
}

type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;

#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
    TypeId::of::<ArchOptimalType<F>>()
}

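/// Left fold specialized for 1-bit evaluations and 128-bit extension fields.
///
/// The evaluations are reinterpreted as bytes and folded via the precomputed mask lookup table,
/// dispatching on whether `PE` is the architecture-optimal packed type for one of the supported
/// 128-bit fields. Returns `true` if the fold was performed, `false` otherwise.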
fn fold_left_1b_128b<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) -> bool
where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
        return false;
    }

    let log_row_size = log_evals_size - log_query_size;
    if log_row_size < 3 {
        return false;
    }

    if PE::LOG_WIDTH > 3 {
        return false;
    }

    // Safety: `is_sequential_bytes::<P>()` guarantees that the evaluations are laid out as a
    // plain sequence of bytes, so viewing them as a byte slice is sound.
    let evals_u8: &[u8] = unsafe {
        std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
    };

    // Run the implementation if `PE` is the architecture-optimal packed type for the field `F`.
    #[inline]
    fn try_run_specialization<PE, F>(
        lookup_table: &[OptimalUnderlier],
        evals_u8: &[u8],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [MaybeUninit<PE>],
    ) -> bool
    where
        PE: PackedField,
        F: ArchOptimal,
    {
        if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
            let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
            let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);

            fold_left_1b_128b_impl(
                lookup_table,
                evals_u8,
                log_evals_size,
                query,
                log_query_size,
                out,
            );
            true
        } else {
            false
        }
    }

    let lookup_table = &*LOOKUP_TABLE;
    try_run_specialization::<_, BinaryField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, AESTowerField128b>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    ) || try_run_specialization::<_, BinaryField128bPolyval>(
        lookup_table,
        evals_u8,
        log_evals_size,
        query,
        log_query_size,
        out,
    )
}

/// Cast a mutable slice to another element type that is checked at runtime to be the same type.
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: the assert above guarantees that `T` and `U` are the same type.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}

/// Cast a slice to another element type that is checked at runtime to be the same type.
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
    assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
    // Safety: the assert above guarantees that `T` and `U` are the same type.
    unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}

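/// Initialize the mask lookup table used by [`fold_left_1b_128b_impl`].
///
/// For each possible evaluation byte, the table stores `8 / (U::BITS / 128)` underlier masks:
/// bit `k` of the corresponding group of bits in the byte selects an all-ones 128-bit subvalue
/// in the mask.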
fn init_lookup_table_width<U>() -> Vec<U>
where
    U: UnderlierWithBitOps + From<u128>,
{
    let items_b128 = U::BITS / u128::BITS as usize;
    assert!(items_b128 <= 8);
    let items_in_byte = 8 / items_b128;

    let mut result = Vec::with_capacity(256 * items_in_byte);
    for i in 0..256 {
        for j in 0..items_in_byte {
            // Take the group of `items_b128` bits of the byte that drives this mask.
            let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
            let mut value = U::ZERO;
            for k in 0..items_b128 {
                if (bits >> k) & 1 == 1 {
                    unsafe {
                        value.set_subvalue(k, u128::ONES);
                    }
                }
            }
            result.push(value);
        }
    }

    result
}

lazy_static! {
    static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}

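/// Core loop of the 1-bit × 128-bit left fold: for every query scalar, broadcast it across a
/// packed element, then mask it into the output accumulators one evaluation byte at a time
/// using the precomputed `lookup_table`.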
#[inline]
fn fold_left_1b_128b_impl<PE, U>(
    lookup_table: &[U],
    evals: &[u8],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    PE: PackedField + WithUnderlier<Underlier = U>,
    U: UnderlierWithBitOps,
{
    // Safety: every element is overwritten by `fill_zeroes` before being read.
    let out = unsafe { slice_assume_init_mut(out) };
    fill_zeroes(out);

    let items_in_byte = 8 / PE::WIDTH;
    let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
    for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
        let query_val = PE::broadcast(query_val).to_underlier();
        for (byte_index, byte) in row_bytes.iter().enumerate() {
            let mask_offset = *byte as usize * items_in_byte;
            let out_offset = byte_index * items_in_byte;
            for i in 0..items_in_byte {
                let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
                let multiplied = query_val & *mask;
                let out = unsafe { out.get_unchecked_mut(out_offset + i) };
                *out += PE::from_underlier(multiplied);
            }
        }
    }
}

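/// Portable scalar fallback for the left fold, used when no specialization applies.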
fn fold_left_fallback<P, PE>(
    evals: &[P],
    log_evals_size: usize,
    query: &[PE],
    log_query_size: usize,
    out: &mut [MaybeUninit<PE>],
) where
    P: PackedField,
    PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
    let new_n_vars = log_evals_size - log_query_size;

    out.iter_mut()
        .enumerate()
        .for_each(|(outer_index, out_val)| {
            let mut res = PE::default();
            for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
                res.set(
                    inner_index,
                    PackedField::iter_slice(query)
                        .take(1 << log_query_size)
                        .enumerate()
                        .map(|(query_index, basis_eval)| {
                            let eval_index = (query_index << new_n_vars)
                                | (outer_index << PE::LOG_WIDTH)
                                | inner_index;
                            let subpoly_eval_i = get_packed_slice(evals, eval_index);
                            basis_eval * subpoly_eval_i
                        })
                        .sum(),
                );
            }

            out_val.write(res);
        });
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        PackedBinaryField8x1b, PackedBinaryField16x8b, PackedBinaryField16x32b,
        PackedBinaryField32x1b, PackedBinaryField64x8b, PackedBinaryField128x1b,
        PackedBinaryField512x1b, packed::set_packed_slice,
    };
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    fn fold_right_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result += get_packed_slice(query, j)
                    * get_packed_slice(evals, (i << log_query_size) | j);
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_right<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
        let mut out = reference_out.clone();

        fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
        fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 0, &query, 0);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 2, &query, 1);
    }

    #[test]
    fn test_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_right(&evals, 7, &query, 7);
    }

    #[test]
    fn test_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
            .take(8)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_small_poly() {
        const LOG_EVALS_SIZE: usize = 5;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_8b_many_evals() {
        const LOG_EVALS_SIZE: usize = 13;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
            .take(1 << 8)
            .collect::<Vec<_>>();

        for log_query_size in 0..8 {
            check_fold_right(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_lower_higher_bits() {
        const LOG_EVALS_SIZE: usize = 7;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = [
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
            PackedBinaryField32x1b::random(&mut rng),
        ];

        // A query that selects the lower of the two folded variables.
        let log_query_size = 1;
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000001u32,
        )];
        let mut out = vec![
            PackedBinaryField32x1b::zero();
            (1 << (LOG_EVALS_SIZE - log_query_size)) / PackedBinaryField32x1b::WIDTH
        ];
        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        // The fold must keep the lower 16 bits of each pair of evaluation words.
        for (i, out_val) in out.iter().enumerate() {
            let first_lower = (evals[2 * i].0 as u16) as u32;
            let second_lower = (evals[2 * i + 1].0 as u16) as u32;
            let expected_out = first_lower + (second_lower << 16);
            assert!(out_val.0 == expected_out);
        }

        // A query that selects the higher of the two folded variables.
        let query = [PackedBinaryField32x1b::from(
            0b00000000_00000000_00000000_00000010u32,
        )];

        out.clear();
        fold_middle(&evals, LOG_EVALS_SIZE, &query, log_query_size, 4, out.spare_capacity_mut())
            .unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        // The fold must keep the upper 16 bits of each pair of evaluation words.
        for (i, out_val) in out.iter().enumerate() {
            let first_higher = evals[2 * i].0 >> 16;
            let second_higher = evals[2 * i + 1].0 >> 16;
            let expected_out = first_higher + (second_higher << 16);
            assert!(out_val.0 == expected_out);
        }
    }

    #[test]
    fn test_zeropad_bits() {
        const LOG_EVALS_SIZE: usize = 5;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = [
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
            PackedBinaryField8x1b::random(&mut rng),
        ];

        let num_extra_vars = 2;
        let start_index = 3;

        for nonzero_index in 0..1 << num_extra_vars {
            let mut out =
                vec![
                    PackedBinaryField8x1b::zero();
                    1usize << (LOG_EVALS_SIZE + num_extra_vars - PackedBinaryField8x1b::LOG_WIDTH)
                ];
            zero_pad(
                &evals,
                LOG_EVALS_SIZE,
                LOG_EVALS_SIZE + num_extra_vars,
                start_index,
                nonzero_index,
                &mut out,
            )
            .unwrap();

            // Only the block selected by `nonzero_index` should contain the original bytes;
            // every other byte must stay zero.
            for (i, out_val) in out.iter().enumerate() {
                let expansion = 1 << num_extra_vars;
                if i % expansion == nonzero_index {
                    assert!(out_val.0 == evals[i / expansion].0);
                } else {
                    assert!(out_val.0 == 0);
                }
            }
        }
    }

    fn fold_left_reference<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
        out: &mut [PE],
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        for i in 0..1 << (log_evals_size - log_query_size) {
            let mut result = PE::Scalar::ZERO;
            for j in 0..1 << log_query_size {
                result += get_packed_slice(query, j)
                    * get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
            }

            set_packed_slice(out, i, result);
        }
    }

    fn check_fold_left<P, PE>(
        evals: &[P],
        log_evals_size: usize,
        query: &[PE],
        log_query_size: usize,
    ) where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let mut reference_out =
            vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];

        let mut out = reference_out.clone();
        out.clear();
        fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
        unsafe {
            out.set_len(out.capacity());
        }

        fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);

        for i in 0..1 << (log_evals_size - log_query_size) {
            assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
        }
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_0() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 0, &query, 0);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_1() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 2, &query, 1);
    }

    #[test]
    fn test_fold_left_1b_small_poly_query_log_size_7() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
        let query = vec![PackedBinaryField128x1b::random(&mut rng)];

        check_fold_left(&evals, 7, &query, 7);
    }

    #[test]
    fn test_fold_left_1b_many_evals() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(1);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = vec![PackedBinaryField512x1b::random(&mut rng)];

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    type B128bOptimal = ArchOptimalType<BinaryField128b>;

    #[test]
    fn test_fold_left_1b_128b_optimal() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << (10 - B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(
                &evals,
                LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
                &query,
                log_query_size,
            );
        }
    }

    #[test]
    fn test_fold_left_128b_128b() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE)
            .collect::<Vec<_>>();
        let query = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << 10)
            .collect::<Vec<_>>();

        for log_query_size in 0..10 {
            check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
        }
    }

    #[test]
    fn test_fold_left_lerp_inplace_conforms_reference() {
        const LOG_EVALS_SIZE: usize = 14;
        let mut rng = StdRng::seed_from_u64(0);
        let mut evals = repeat_with(|| B128bOptimal::random(&mut rng))
            .take(1 << LOG_EVALS_SIZE.saturating_sub(B128bOptimal::LOG_WIDTH))
            .collect::<Vec<_>>();
        let lerp_query = <BinaryField128b as Field>::random(&mut rng);
        let mut query =
            vec![B128bOptimal::default(); 1 << 1usize.saturating_sub(B128bOptimal::LOG_WIDTH)];
        set_packed_slice(&mut query, 0, BinaryField128b::ONE - lerp_query);
        set_packed_slice(&mut query, 1, lerp_query);

        for log_evals_size in (1..=LOG_EVALS_SIZE).rev() {
            let mut out = vec![
                MaybeUninit::uninit();
                1 << log_evals_size.saturating_sub(B128bOptimal::LOG_WIDTH + 1)
            ];
            fold_left(&evals, log_evals_size, &query, 1, &mut out).unwrap();
            fold_left_lerp_inplace(
                &mut evals,
                1 << log_evals_size,
                Field::ZERO,
                log_evals_size,
                lerp_query,
            )
            .unwrap();

            for (out, &inplace) in izip!(&out, &evals) {
                unsafe {
                    assert_eq!(out.assume_init(), inplace);
                }
            }
        }
    }

    #[test]
    fn test_check_fold_arguments_valid() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        let result = check_fold_arguments(&evals, 3, &query, 2, &out);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_fold_arguments_invalid_query_size() {
        let evals = vec![PackedBinaryField128x1b::default(); 8];
        let query = vec![PackedBinaryField128x1b::default(); 4];
        let out = vec![PackedBinaryField128x1b::default(); 4];

        let result = check_fold_arguments(&evals, 2, &query, 3, &out);
        assert!(matches!(result, Err(Error::IncorrectQuerySize { .. })));
    }
}