1use std::iter;
8
9use binius_field::{
10 packed::get_packed_slice_checked, Field, PackedExtension, PackedField, PackedSubfield,
11};
12use binius_math::{
13 extrapolate_lines, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
14 MultilinearQueryRef, RowsBatchRef,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{izip, Either, Itertools};
20use stackalloc::stackalloc_with_iter;
21
22use crate::{
23 common::{subcube_vars_for_bits, MAX_SRC_SUBCUBE_LOG_BITS},
24 Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
25};
26
27trait SumcheckMultilinearAccess<P: PackedField> {
28 fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize>;
30
31 #[allow(clippy::too_many_arguments)]
65 fn subcube_evaluations<M: MultilinearPoly<P>>(
66 &self,
67 multilinear: &SumcheckMultilinear<P, M>,
68 subcube_vars: usize,
69 subcube_index: usize,
70 index_vars: usize,
71 tensor_query: MultilinearQueryRef<P>,
72 scratch_space: Option<&mut [P]>,
73 evals_0: &mut [P],
74 evals_1: &mut [P],
75 ) -> Result<(), Error>;
76}
77
78pub(crate) fn calculate_round_evals<FDomain, F, P, M, Evaluator, Composition>(
83 evaluation_order: EvaluationOrder,
84 n_vars: usize,
85 tensor_query: Option<MultilinearQueryRef<P>>,
86 multilinears: &[SumcheckMultilinear<P, M>],
87 evaluators: &[Evaluator],
88 finite_evaluation_points: &[FDomain],
89) -> Result<Vec<RoundEvals<F>>, Error>
90where
91 FDomain: Field,
92 F: Field,
93 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
94 M: MultilinearPoly<P> + Sync,
95 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
96 Composition: CompositionPoly<P>,
97{
98 assert!(n_vars > 0, "Computing round evaluations requires at least a single variable.");
99
100 let empty_query = MultilinearQuery::with_capacity(0);
101 let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref());
102
103 match evaluation_order {
104 EvaluationOrder::LowToHigh => calculate_round_evals_with_access(
105 LowToHighAccess,
106 n_vars,
107 tensor_query,
108 multilinears,
109 evaluators,
110 finite_evaluation_points,
111 ),
112 EvaluationOrder::HighToLow => calculate_round_evals_with_access(
113 HighToLowAccess,
114 n_vars,
115 tensor_query,
116 multilinears,
117 evaluators,
118 finite_evaluation_points,
119 ),
120 }
121}
122
123fn calculate_round_evals_with_access<FDomain, F, P, M, Evaluator, Access, Composition>(
124 access: Access,
125 n_vars: usize,
126 tensor_query: MultilinearQueryRef<P>,
127 multilinears: &[SumcheckMultilinear<P, M>],
128 evaluators: &[Evaluator],
129 nontrivial_evaluation_points: &[FDomain],
130) -> Result<Vec<RoundEvals<F>>, Error>
131where
132 FDomain: Field,
133 F: Field,
134 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
135 M: MultilinearPoly<P> + Sync,
136 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
137 Access: SumcheckMultilinearAccess<P> + Sync,
138 Composition: CompositionPoly<P>,
139{
140 let n_multilinears = multilinears.len();
141 let n_round_evals = evaluators
142 .iter()
143 .map(|evaluator| evaluator.eval_point_indices().len());
144
145 let eval_point_indices = evaluators
147 .iter()
148 .map(|evaluator| evaluator.eval_point_indices())
149 .reduce(|range1, range2| range1.start.min(range2.start)..range1.end.max(range2.end))
150 .unwrap_or(0..0);
151
152 if nontrivial_evaluation_points.len() != eval_point_indices.end.saturating_sub(3) {
154 bail!(Error::IncorrectNontrivialEvalPointsLength);
155 }
156
157 let subcube_vars = subcube_vars_for_bits::<P>(
160 MAX_SRC_SUBCUBE_LOG_BITS,
161 n_vars - 1,
162 tensor_query.n_vars(),
163 n_vars - 1,
164 );
165
166 let subcube_count_by_evaluator = evaluators
167 .iter()
168 .map(|evaluator| {
169 ((1 << (n_vars - 1)) - evaluator.const_eval_suffix()).div_ceil(1 << subcube_vars)
170 })
171 .collect::<Vec<_>>();
172
173 let mut subcube_count_by_multilinear = vec![0; n_multilinears];
174
175 for (&evaluator_subcube_count, evaluator) in izip!(&subcube_count_by_evaluator, evaluators) {
176 let used_vars = evaluator.composition().expression().vars_usage();
177
178 for (multilinear_subcube_count, usage_flag) in
179 izip!(&mut subcube_count_by_multilinear, used_vars)
180 {
181 if usage_flag {
182 *multilinear_subcube_count =
183 (*multilinear_subcube_count).max(evaluator_subcube_count);
184 }
185 }
186 }
187
188 let index_vars = n_vars - 1 - subcube_vars;
189 let packed_accumulators = (0..1 << index_vars)
190 .into_par_iter()
191 .try_fold(
192 || ParFoldStates::new(&access, n_multilinears, n_round_evals.clone(), subcube_vars),
193 |mut par_fold_states, subcube_index| {
194 let ParFoldStates {
195 multilinear_evals,
196 scratch_space,
197 round_evals,
198 } = &mut par_fold_states;
199
200 for (multilinear, evals, &subcube_count) in
201 izip!(multilinears, multilinear_evals.iter_mut(), &subcube_count_by_multilinear)
202 {
203 if subcube_index < subcube_count {
204 access.subcube_evaluations(
205 multilinear,
206 subcube_vars,
207 subcube_index,
208 index_vars,
209 tensor_query,
210 scratch_space.as_deref_mut(),
211 &mut evals.evals_0,
212 &mut evals.evals_1,
213 )?;
214 }
215 }
216
217 for eval_point_index in eval_point_indices.clone() {
219 let is_infinity_point = eval_point_index == 2;
221
222 let evals_z_iter =
231 izip!(multilinear_evals.iter_mut(), &subcube_count_by_multilinear).map(
232 |(evals, &subcube_count)| match eval_point_index {
233 _ if subcube_index >= subcube_count => evals.evals_0.as_slice(),
235 0 => evals.evals_0.as_slice(),
236 1 => evals.evals_1.as_slice(),
237 2 => {
238 izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
240 .for_each(|(eval_z, &eval_0, &eval_1)| {
241 *eval_z = eval_1 - eval_0;
242 });
243
244 evals.evals_z.as_slice()
245 }
246 3.. => {
247 let eval_point =
249 nontrivial_evaluation_points[eval_point_index - 3];
250 let eval_point_broadcast =
251 <PackedSubfield<P, FDomain>>::broadcast(eval_point);
252
253 izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
254 .for_each(|(eval_z, &eval_0, &eval_1)| {
255 *eval_z = P::cast_ext(extrapolate_lines(
260 P::cast_base(eval_0),
261 P::cast_base(eval_1),
262 eval_point_broadcast,
263 ));
264 });
265
266 evals.evals_z.as_slice()
267 }
268 },
269 );
270
271 let row_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
272 stackalloc_with_iter(n_multilinears, evals_z_iter, |evals_z| {
273 let evals_z = RowsBatchRef::new(evals_z, row_len);
274
275 for (evaluator, round_evals, &subcube_count) in
276 izip!(evaluators, round_evals.iter_mut(), &subcube_count_by_evaluator)
277 {
278 let eval_point_indices = evaluator.eval_point_indices();
279 if !eval_point_indices.contains(&eval_point_index)
280 || subcube_index >= subcube_count
281 {
282 continue;
283 }
284
285 round_evals[eval_point_index - eval_point_indices.start] += evaluator
286 .process_subcube_at_eval_point(
287 subcube_vars,
288 subcube_index,
289 is_infinity_point,
290 &evals_z,
291 );
292 }
293 });
294 }
295
296 Ok(par_fold_states)
297 },
298 )
299 .map(|states: Result<ParFoldStates<P>, Error>| -> Result<_, Error> {
300 Ok(states?.round_evals)
301 })
302 .try_reduce(
304 || {
305 evaluators
306 .iter()
307 .map(|evaluator| vec![P::zero(); evaluator.eval_point_indices().len()])
308 .collect()
309 },
310 |lhs, rhs| {
311 let sum = izip!(lhs, rhs)
312 .map(|(mut lhs_vals, rhs_vals)| {
313 for (lhs_val, rhs_val) in lhs_vals.iter_mut().zip(rhs_vals) {
314 *lhs_val += rhs_val;
315 }
316 lhs_vals
317 })
318 .collect();
319 Ok(sum)
320 },
321 )?;
322
323 let round_evals = izip!(packed_accumulators, evaluators, subcube_count_by_evaluator)
324 .map(|(packed_round_evals, evaluator, subcube_count)| {
325 let mut round_evals = packed_round_evals
326 .into_iter()
327 .map(|packed_round_eval| packed_round_eval.iter().take(1 << subcube_vars).sum())
329 .collect::<Vec<F>>();
330
331 let const_eval_suffix = (1 << n_vars) - (subcube_count << subcube_vars);
332 for (eval_point_index, round_eval) in
333 izip!(eval_point_indices.clone(), &mut round_evals)
334 {
335 let is_infinity_point = eval_point_index == 2;
336 *round_eval +=
337 evaluator.process_constant_eval_suffix(const_eval_suffix, is_infinity_point);
338 }
339
340 RoundEvals(round_evals)
341 })
342 .collect();
343
344 Ok(round_evals)
345}
346
347#[derive(Debug)]
349struct MultilinearEvals<P: PackedField> {
350 evals_0: Vec<P>,
351 evals_1: Vec<P>,
352 evals_z: Vec<P>,
353}
354
355impl<P: PackedField> MultilinearEvals<P> {
356 fn new(subcube_vars: usize) -> Self {
357 let len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
358 Self {
359 evals_0: zeroed_vec(len),
360 evals_1: zeroed_vec(len),
361 evals_z: zeroed_vec(len),
362 }
363 }
364}
365
366#[derive(Debug)]
368struct ParFoldStates<P: PackedField> {
369 multilinear_evals: Vec<MultilinearEvals<P>>,
371
372 scratch_space: Option<Vec<P>>,
374
375 round_evals: Vec<Vec<P>>,
380}
381
382impl<P: PackedField> ParFoldStates<P> {
383 fn new(
384 access: &impl SumcheckMultilinearAccess<P>,
385 n_multilinears: usize,
386 n_round_evals: impl Iterator<Item = usize>,
387 subcube_vars: usize,
388 ) -> Self {
389 Self {
390 multilinear_evals: (0..n_multilinears)
391 .map(|_| MultilinearEvals::new(subcube_vars))
392 .collect(),
393 scratch_space: access
394 .scratch_space_len(subcube_vars)
395 .map(|len| zeroed_vec(len)),
396 round_evals: n_round_evals
397 .map(|n_round_evals| zeroed_vec(n_round_evals))
398 .collect(),
399 }
400 }
401}
402
403#[derive(Debug)]
404struct LowToHighAccess;
405
406impl<P: PackedField> SumcheckMultilinearAccess<P> for LowToHighAccess {
407 fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize> {
408 Some(1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH))
410 }
411
412 fn subcube_evaluations<M: MultilinearPoly<P>>(
413 &self,
414 multilinear: &SumcheckMultilinear<P, M>,
415 subcube_vars: usize,
416 subcube_index: usize,
417 _index_vars: usize,
418 tensor_query: MultilinearQueryRef<P>,
419 scratch_space: Option<&mut [P]>,
420 evals_0: &mut [P],
421 evals_1: &mut [P],
422 ) -> Result<(), Error> {
423 let Some(scratch_space) = scratch_space else {
424 bail!(Error::NoScratchSpace);
425 };
426
427 if scratch_space.len() != 1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH)
428 || evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
429 || evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
430 {
431 bail!(Error::IncorrectDestSliceLengths);
432 }
433
434 match multilinear {
435 SumcheckMultilinear::Transparent { multilinear, .. } => {
436 if tensor_query.n_vars() == 0 {
437 multilinear.subcube_evals(subcube_vars + 1, subcube_index, 0, scratch_space)?
438 } else {
439 multilinear.subcube_partial_low_evals(
440 tensor_query,
441 subcube_vars + 1,
442 subcube_index,
443 scratch_space,
444 )?
445 }
446 }
447
448 SumcheckMultilinear::Folded {
449 large_field_folded_evals: evals,
450 } => {
451 if subcube_vars + 1 >= P::LOG_WIDTH {
452 let packed_log_size = subcube_vars + 1 - P::LOG_WIDTH;
453 let offset = subcube_index << packed_log_size;
454 let packed_len = (1 << packed_log_size).min(evals.len().saturating_sub(offset));
455 if packed_len > 0 {
456 scratch_space[..packed_len]
457 .copy_from_slice(&evals[offset..offset + packed_len]);
458 }
459 scratch_space[packed_len..].fill(P::zero());
460 } else {
461 let mut only_packed = P::zero();
462
463 for i in 0..1 << (subcube_vars + 1) {
464 let index = subcube_index << (subcube_vars + 1) | i;
465 only_packed.set(
466 i,
467 get_packed_slice_checked(evals, index).unwrap_or(P::Scalar::ZERO),
468 );
469 }
470
471 *scratch_space.first_mut().expect("non-empty scratch space") = only_packed;
472 }
473 }
474 }
475
476 let zeros = P::default();
479 let interleaved_tuples = if scratch_space.len() == 1 {
480 Either::Left(iter::once((scratch_space.first().expect("len==1"), &zeros)))
481 } else {
482 Either::Right(scratch_space.iter().tuples())
483 };
484
485 for ((&interleaved_0, &interleaved_1), evals_0, evals_1) in
486 izip!(interleaved_tuples, evals_0, evals_1)
487 {
488 let (deinterleaved_0, deinterleaved_1) = if P::LOG_WIDTH > 0 {
489 P::unzip(interleaved_0, interleaved_1, 0)
490 } else {
491 (interleaved_0, interleaved_1)
492 };
493
494 *evals_0 = deinterleaved_0;
495 *evals_1 = deinterleaved_1;
496 }
497
498 Ok(())
499 }
500}
501
502#[derive(Debug)]
503struct HighToLowAccess;
504
505impl<P: PackedField> SumcheckMultilinearAccess<P> for HighToLowAccess {
506 fn scratch_space_len(&self, _subcube_vars: usize) -> Option<usize> {
507 None
508 }
509
510 fn subcube_evaluations<M: MultilinearPoly<P>>(
511 &self,
512 multilinear: &SumcheckMultilinear<P, M>,
513 subcube_vars: usize,
514 subcube_index: usize,
515 index_vars: usize,
516 tensor_query: MultilinearQueryRef<P>,
517 _scratch_space: Option<&mut [P]>,
518 evals_0: &mut [P],
519 evals_1: &mut [P],
520 ) -> Result<(), Error> {
521 if evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
522 || evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
523 {
524 bail!(Error::IncorrectDestSliceLengths);
525 }
526
527 match multilinear {
528 SumcheckMultilinear::Transparent { multilinear, .. } => {
529 if tensor_query.n_vars() == 0 {
530 multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals_0)?;
531 multilinear.subcube_evals(
532 subcube_vars,
533 subcube_index | 1 << index_vars,
534 0,
535 evals_1,
536 )?;
537 } else {
538 multilinear.subcube_partial_high_evals(
539 tensor_query,
540 subcube_vars,
541 subcube_index,
542 evals_0,
543 )?;
544 multilinear.subcube_partial_high_evals(
545 tensor_query,
546 subcube_vars,
547 subcube_index | 1 << index_vars,
548 evals_1,
549 )?;
550 }
551 }
552
553 SumcheckMultilinear::Folded {
554 large_field_folded_evals: evals,
555 } => {
556 if subcube_vars >= P::LOG_WIDTH {
557 let packed_log_size = subcube_vars - P::LOG_WIDTH;
558 let offset_0 = subcube_index << packed_log_size;
559 let offset_1 = offset_0 | 1 << (index_vars + packed_log_size);
560 let packed_len_0 =
561 (1 << packed_log_size).min(evals.len().saturating_sub(offset_0));
562 let packed_len_1 =
563 (1 << packed_log_size).min(evals.len().saturating_sub(offset_1));
564
565 if packed_len_0 > 0 {
566 evals_0[..packed_len_0].copy_from_slice(&evals[offset_0..][..packed_len_0]);
567 }
568
569 if packed_len_1 > 0 {
570 evals_1[..packed_len_1].copy_from_slice(&evals[offset_1..][..packed_len_1]);
571 }
572
573 evals_0[packed_len_0..].fill(P::zero());
574 evals_1[packed_len_1..].fill(P::zero());
575 } else {
576 let mut evals_0_packed = P::zero();
577 let mut evals_1_packed = P::zero();
578
579 for i in 0..1 << subcube_vars {
580 let index_0 = subcube_index << subcube_vars | i;
581 let index_1 = index_0 | 1 << (index_vars + subcube_vars);
582 evals_0_packed.set(
583 i,
584 get_packed_slice_checked(evals, index_0).unwrap_or(P::Scalar::ZERO),
585 );
586 evals_1_packed.set(
587 i,
588 get_packed_slice_checked(evals, index_1).unwrap_or(P::Scalar::ZERO),
589 );
590 }
591
592 *evals_0.first_mut().expect("non-empty evals_0") = evals_0_packed;
593 *evals_1.first_mut().expect("non-empty evals_1") = evals_1_packed;
594 }
595 }
596 }
597
598 Ok(())
599 }
600}