1use std::iter;
8
9use binius_field::{
10 Field, PackedExtension, PackedField, PackedSubfield, packed::get_packed_slice_checked,
11};
12use binius_math::{
13 CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery, MultilinearQueryRef,
14 RowsBatchRef, extrapolate_lines,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{Either, Itertools, izip};
20use stackalloc::stackalloc_with_iter;
21
22use crate::{
23 Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
24 common::{MAX_SRC_SUBCUBE_LOG_BITS, subcube_vars_for_bits},
25};
26
27trait SumcheckMultilinearAccess<P: PackedField> {
28 fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize>;
30
31 #[allow(clippy::too_many_arguments)]
68 fn subcube_evaluations<M: MultilinearPoly<P>>(
69 &self,
70 multilinear: &SumcheckMultilinear<P, M>,
71 subcube_vars: usize,
72 subcube_index: usize,
73 index_vars: usize,
74 tensor_query: MultilinearQueryRef<P>,
75 scratch_space: Option<&mut [P]>,
76 evals_0: &mut [P],
77 evals_1: &mut [P],
78 ) -> Result<(), Error>;
79}
80
81pub(crate) fn calculate_round_evals<FDomain, F, P, M, Evaluator, Composition>(
86 evaluation_order: EvaluationOrder,
87 n_vars: usize,
88 tensor_query: Option<MultilinearQueryRef<P>>,
89 multilinears: &[SumcheckMultilinear<P, M>],
90 evaluators: &[Evaluator],
91 finite_evaluation_points: &[FDomain],
92) -> Result<Vec<RoundEvals<F>>, Error>
93where
94 FDomain: Field,
95 F: Field,
96 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
97 M: MultilinearPoly<P> + Sync,
98 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
99 Composition: CompositionPoly<P>,
100{
101 assert!(n_vars > 0, "Computing round evaluations requires at least a single variable.");
102
103 let empty_query = MultilinearQuery::with_capacity(0);
104 let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref());
105
106 match evaluation_order {
107 EvaluationOrder::LowToHigh => calculate_round_evals_with_access(
108 LowToHighAccess,
109 n_vars,
110 tensor_query,
111 multilinears,
112 evaluators,
113 finite_evaluation_points,
114 ),
115 EvaluationOrder::HighToLow => calculate_round_evals_with_access(
116 HighToLowAccess,
117 n_vars,
118 tensor_query,
119 multilinears,
120 evaluators,
121 finite_evaluation_points,
122 ),
123 }
124}
125
126fn calculate_round_evals_with_access<FDomain, F, P, M, Evaluator, Access, Composition>(
127 access: Access,
128 n_vars: usize,
129 tensor_query: MultilinearQueryRef<P>,
130 multilinears: &[SumcheckMultilinear<P, M>],
131 evaluators: &[Evaluator],
132 nontrivial_evaluation_points: &[FDomain],
133) -> Result<Vec<RoundEvals<F>>, Error>
134where
135 FDomain: Field,
136 F: Field,
137 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
138 M: MultilinearPoly<P> + Sync,
139 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
140 Access: SumcheckMultilinearAccess<P> + Sync,
141 Composition: CompositionPoly<P>,
142{
143 let n_multilinears = multilinears.len();
144 let n_round_evals = evaluators
145 .iter()
146 .map(|evaluator| evaluator.eval_point_indices().len());
147
148 let eval_point_indices = evaluators
150 .iter()
151 .map(|evaluator| evaluator.eval_point_indices())
152 .reduce(|range1, range2| range1.start.min(range2.start)..range1.end.max(range2.end))
153 .unwrap_or(0..0);
154
155 if nontrivial_evaluation_points.len() != eval_point_indices.end.saturating_sub(3) {
158 bail!(Error::IncorrectNontrivialEvalPointsLength);
159 }
160
161 let subcube_vars = subcube_vars_for_bits::<P>(
164 MAX_SRC_SUBCUBE_LOG_BITS,
165 n_vars - 1,
166 tensor_query.n_vars(),
167 n_vars - 1,
168 );
169
170 let subcube_count_by_evaluator = evaluators
171 .iter()
172 .map(|evaluator| {
173 ((1 << (n_vars - 1)) - evaluator.const_eval_suffix()).div_ceil(1 << subcube_vars)
174 })
175 .collect::<Vec<_>>();
176
177 let mut subcube_count_by_multilinear = vec![0; n_multilinears];
178
179 for (&evaluator_subcube_count, evaluator) in izip!(&subcube_count_by_evaluator, evaluators) {
180 let used_vars = evaluator.composition().expression().vars_usage();
181
182 for (multilinear_subcube_count, usage_flag) in
183 izip!(&mut subcube_count_by_multilinear, used_vars)
184 {
185 if usage_flag {
186 *multilinear_subcube_count =
187 (*multilinear_subcube_count).max(evaluator_subcube_count);
188 }
189 }
190 }
191
192 let index_vars = n_vars - 1 - subcube_vars;
193 let packed_accumulators = (0..1 << index_vars)
194 .into_par_iter()
195 .try_fold(
196 || ParFoldStates::new(&access, n_multilinears, n_round_evals.clone(), subcube_vars),
197 |mut par_fold_states, subcube_index| {
198 let ParFoldStates {
199 multilinear_evals,
200 scratch_space,
201 round_evals,
202 } = &mut par_fold_states;
203
204 for (multilinear, evals, &subcube_count) in
205 izip!(multilinears, multilinear_evals.iter_mut(), &subcube_count_by_multilinear)
206 {
207 if subcube_index < subcube_count {
208 access.subcube_evaluations(
209 multilinear,
210 subcube_vars,
211 subcube_index,
212 index_vars,
213 tensor_query,
214 scratch_space.as_deref_mut(),
215 &mut evals.evals_0,
216 &mut evals.evals_1,
217 )?;
218 }
219 }
220
221 for eval_point_index in eval_point_indices.clone() {
223 let is_infinity_point = eval_point_index == 2;
225
226 let evals_z_iter =
235 izip!(multilinear_evals.iter_mut(), &subcube_count_by_multilinear).map(
236 |(evals, &subcube_count)| match eval_point_index {
237 _ if subcube_index >= subcube_count => evals.evals_0.as_slice(),
239 0 => evals.evals_0.as_slice(),
240 1 => evals.evals_1.as_slice(),
241 2 => {
242 izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
244 .for_each(|(eval_z, &eval_0, &eval_1)| {
245 *eval_z = eval_1 - eval_0;
246 });
247
248 evals.evals_z.as_slice()
249 }
250 3.. => {
251 let eval_point =
253 nontrivial_evaluation_points[eval_point_index - 3];
254 let eval_point_broadcast =
255 <PackedSubfield<P, FDomain>>::broadcast(eval_point);
256
257 izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
258 .for_each(|(eval_z, &eval_0, &eval_1)| {
259 *eval_z = P::cast_ext(extrapolate_lines(
264 P::cast_base(eval_0),
265 P::cast_base(eval_1),
266 eval_point_broadcast,
267 ));
268 });
269
270 evals.evals_z.as_slice()
271 }
272 },
273 );
274
275 let row_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
276 stackalloc_with_iter(n_multilinears, evals_z_iter, |evals_z| {
277 let evals_z = RowsBatchRef::new(evals_z, row_len);
278
279 for (evaluator, round_evals, &subcube_count) in
280 izip!(evaluators, round_evals.iter_mut(), &subcube_count_by_evaluator)
281 {
282 let eval_point_indices = evaluator.eval_point_indices();
283 if !eval_point_indices.contains(&eval_point_index)
284 || subcube_index >= subcube_count
285 {
286 continue;
287 }
288
289 round_evals[eval_point_index - eval_point_indices.start] += evaluator
290 .process_subcube_at_eval_point(
291 subcube_vars,
292 subcube_index,
293 is_infinity_point,
294 &evals_z,
295 );
296 }
297 });
298 }
299
300 Ok(par_fold_states)
301 },
302 )
303 .map(|states: Result<ParFoldStates<P>, Error>| -> Result<_, Error> {
304 Ok(states?.round_evals)
305 })
306 .try_reduce(
308 || {
309 evaluators
310 .iter()
311 .map(|evaluator| vec![P::zero(); evaluator.eval_point_indices().len()])
312 .collect()
313 },
314 |lhs, rhs| {
315 let sum = izip!(lhs, rhs)
316 .map(|(mut lhs_vals, rhs_vals)| {
317 for (lhs_val, rhs_val) in lhs_vals.iter_mut().zip(rhs_vals) {
318 *lhs_val += rhs_val;
319 }
320 lhs_vals
321 })
322 .collect();
323 Ok(sum)
324 },
325 )?;
326
327 let round_evals = izip!(packed_accumulators, evaluators, subcube_count_by_evaluator)
328 .map(|(packed_round_evals, evaluator, subcube_count)| {
329 let mut round_evals = packed_round_evals
330 .into_iter()
331 .map(|packed_round_eval| packed_round_eval.iter().take(1 << subcube_vars).sum())
333 .collect::<Vec<F>>();
334
335 let const_eval_suffix = (1 << n_vars) - (subcube_count << subcube_vars);
336 for (eval_point_index, round_eval) in
337 izip!(eval_point_indices.clone(), &mut round_evals)
338 {
339 let is_infinity_point = eval_point_index == 2;
340 *round_eval +=
341 evaluator.process_constant_eval_suffix(const_eval_suffix, is_infinity_point);
342 }
343
344 RoundEvals(round_evals)
345 })
346 .collect();
347
348 Ok(round_evals)
349}
350
351#[derive(Debug)]
353struct MultilinearEvals<P: PackedField> {
354 evals_0: Vec<P>,
355 evals_1: Vec<P>,
356 evals_z: Vec<P>,
357}
358
359impl<P: PackedField> MultilinearEvals<P> {
360 fn new(subcube_vars: usize) -> Self {
361 let len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
362 Self {
363 evals_0: zeroed_vec(len),
364 evals_1: zeroed_vec(len),
365 evals_z: zeroed_vec(len),
366 }
367 }
368}
369
370#[derive(Debug)]
372struct ParFoldStates<P: PackedField> {
373 multilinear_evals: Vec<MultilinearEvals<P>>,
375
376 scratch_space: Option<Vec<P>>,
378
379 round_evals: Vec<Vec<P>>,
384}
385
386impl<P: PackedField> ParFoldStates<P> {
387 fn new(
388 access: &impl SumcheckMultilinearAccess<P>,
389 n_multilinears: usize,
390 n_round_evals: impl Iterator<Item = usize>,
391 subcube_vars: usize,
392 ) -> Self {
393 Self {
394 multilinear_evals: (0..n_multilinears)
395 .map(|_| MultilinearEvals::new(subcube_vars))
396 .collect(),
397 scratch_space: access
398 .scratch_space_len(subcube_vars)
399 .map(|len| zeroed_vec(len)),
400 round_evals: n_round_evals
401 .map(|n_round_evals| zeroed_vec(n_round_evals))
402 .collect(),
403 }
404 }
405}
406
407#[derive(Debug)]
408struct LowToHighAccess;
409
410impl<P: PackedField> SumcheckMultilinearAccess<P> for LowToHighAccess {
411 fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize> {
412 Some(1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH))
414 }
415
416 fn subcube_evaluations<M: MultilinearPoly<P>>(
417 &self,
418 multilinear: &SumcheckMultilinear<P, M>,
419 subcube_vars: usize,
420 subcube_index: usize,
421 _index_vars: usize,
422 tensor_query: MultilinearQueryRef<P>,
423 scratch_space: Option<&mut [P]>,
424 evals_0: &mut [P],
425 evals_1: &mut [P],
426 ) -> Result<(), Error> {
427 let Some(scratch_space) = scratch_space else {
428 bail!(Error::NoScratchSpace);
429 };
430
431 if scratch_space.len() != 1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH)
432 || evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
433 || evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
434 {
435 bail!(Error::IncorrectDestSliceLengths);
436 }
437
438 match multilinear {
439 SumcheckMultilinear::Transparent { multilinear, .. } => {
440 if tensor_query.n_vars() == 0 {
441 multilinear.subcube_evals(subcube_vars + 1, subcube_index, 0, scratch_space)?
442 } else {
443 multilinear.subcube_partial_low_evals(
444 tensor_query,
445 subcube_vars + 1,
446 subcube_index,
447 scratch_space,
448 )?
449 }
450 }
451
452 SumcheckMultilinear::Folded {
453 large_field_folded_evals: evals,
454 suffix_eval,
455 } => {
456 if subcube_vars + 1 >= P::LOG_WIDTH {
457 let packed_log_size = subcube_vars + 1 - P::LOG_WIDTH;
458 let offset = subcube_index << packed_log_size;
459 let packed_len = (1 << packed_log_size).min(evals.len().saturating_sub(offset));
460 if packed_len > 0 {
461 scratch_space[..packed_len]
462 .copy_from_slice(&evals[offset..offset + packed_len]);
463 }
464 scratch_space[packed_len..].fill(P::broadcast(*suffix_eval));
465 } else {
466 let mut only_packed = P::zero();
467
468 for i in 0..1 << (subcube_vars + 1) {
469 let index = subcube_index << (subcube_vars + 1) | i;
470 only_packed
471 .set(i, get_packed_slice_checked(evals, index).unwrap_or(*suffix_eval));
472 }
473
474 *scratch_space.first_mut().expect("non-empty scratch space") = only_packed;
475 }
476 }
477 }
478
479 let zeros = P::default();
483 let interleaved_tuples = if scratch_space.len() == 1 {
484 Either::Left(iter::once((scratch_space.first().expect("len==1"), &zeros)))
485 } else {
486 Either::Right(scratch_space.iter().tuples())
487 };
488
489 for ((&interleaved_0, &interleaved_1), evals_0, evals_1) in
490 izip!(interleaved_tuples, evals_0, evals_1)
491 {
492 let (deinterleaved_0, deinterleaved_1) = if P::LOG_WIDTH > 0 {
493 P::unzip(interleaved_0, interleaved_1, 0)
494 } else {
495 (interleaved_0, interleaved_1)
496 };
497
498 *evals_0 = deinterleaved_0;
499 *evals_1 = deinterleaved_1;
500 }
501
502 Ok(())
503 }
504}
505
506#[derive(Debug)]
507struct HighToLowAccess;
508
509impl<P: PackedField> SumcheckMultilinearAccess<P> for HighToLowAccess {
510 fn scratch_space_len(&self, _subcube_vars: usize) -> Option<usize> {
511 None
512 }
513
514 fn subcube_evaluations<M: MultilinearPoly<P>>(
515 &self,
516 multilinear: &SumcheckMultilinear<P, M>,
517 subcube_vars: usize,
518 subcube_index: usize,
519 index_vars: usize,
520 tensor_query: MultilinearQueryRef<P>,
521 _scratch_space: Option<&mut [P]>,
522 evals_0: &mut [P],
523 evals_1: &mut [P],
524 ) -> Result<(), Error> {
525 if evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
526 || evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
527 {
528 bail!(Error::IncorrectDestSliceLengths);
529 }
530
531 match multilinear {
532 SumcheckMultilinear::Transparent { multilinear, .. } => {
533 if tensor_query.n_vars() == 0 {
534 multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals_0)?;
535 multilinear.subcube_evals(
536 subcube_vars,
537 subcube_index | 1 << index_vars,
538 0,
539 evals_1,
540 )?;
541 } else {
542 multilinear.subcube_partial_high_evals(
543 tensor_query,
544 subcube_vars,
545 subcube_index,
546 evals_0,
547 )?;
548 multilinear.subcube_partial_high_evals(
549 tensor_query,
550 subcube_vars,
551 subcube_index | 1 << index_vars,
552 evals_1,
553 )?;
554 }
555 }
556
557 SumcheckMultilinear::Folded {
558 large_field_folded_evals: evals,
559 suffix_eval,
560 } => {
561 if subcube_vars >= P::LOG_WIDTH {
562 let packed_log_size = subcube_vars - P::LOG_WIDTH;
563 let offset_0 = subcube_index << packed_log_size;
564 let offset_1 = offset_0 | 1 << (index_vars + packed_log_size);
565 let packed_len_0 =
566 (1 << packed_log_size).min(evals.len().saturating_sub(offset_0));
567 let packed_len_1 =
568 (1 << packed_log_size).min(evals.len().saturating_sub(offset_1));
569
570 if packed_len_0 > 0 {
571 evals_0[..packed_len_0].copy_from_slice(&evals[offset_0..][..packed_len_0]);
572 }
573
574 if packed_len_1 > 0 {
575 evals_1[..packed_len_1].copy_from_slice(&evals[offset_1..][..packed_len_1]);
576 }
577
578 evals_0[packed_len_0..].fill(P::broadcast(*suffix_eval));
579 evals_1[packed_len_1..].fill(P::broadcast(*suffix_eval));
580 } else {
581 let mut evals_0_packed = P::zero();
582 let mut evals_1_packed = P::zero();
583
584 for i in 0..1 << subcube_vars {
585 let index_0 = subcube_index << subcube_vars | i;
586 let index_1 = index_0 | 1 << (index_vars + subcube_vars);
587 evals_0_packed.set(
588 i,
589 get_packed_slice_checked(evals, index_0).unwrap_or(*suffix_eval),
590 );
591 evals_1_packed.set(
592 i,
593 get_packed_slice_checked(evals, index_1).unwrap_or(*suffix_eval),
594 );
595 }
596
597 *evals_0.first_mut().expect("non-empty evals_0") = evals_0_packed;
598 *evals_1.first_mut().expect("non-empty evals_1") = evals_1_packed;
599 }
600 }
601 }
602
603 Ok(())
604 }
605}