binius_hal/
sumcheck_round_calculation.rs

// Copyright 2024-2025 Irreducible Inc.

//! Functions that calculate the sumcheck round evaluations.
//!
//! This is one of the core computational tasks in the sumcheck proving algorithm.

use std::iter;

use binius_field::{
	Field, PackedExtension, PackedField, PackedSubfield, packed::get_packed_slice_checked,
};
use binius_math::{
	CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery, MultilinearQueryRef,
	RowsBatchRef, extrapolate_lines,
};
use binius_maybe_rayon::prelude::*;
use binius_utils::bail;
use bytemuck::zeroed_vec;
use itertools::{Either, Itertools, izip};
use stackalloc::stackalloc_with_iter;

use crate::{
	Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
	common::{MAX_SRC_SUBCUBE_LOG_BITS, subcube_vars_for_bits},
};

trait SumcheckMultilinearAccess<P: PackedField> {
	/// The size of the `Vec<P>` scratch space used by [`Self::subcube_evaluations`], if any.
	fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize>;

	/// A way to obtain multilinear evaluations during sumcheck.
	///
	/// Sumcheck is conducted over boolean hypercubes represented in little-endian form
	/// (faster strides correspond to lower variable indices). Evaluation order may proceed
	/// either from the lowest variable to the highest or in reverse. The $n$-dimensional
	/// evaluation hypercube can be split into subcubes of `subcube_vars` variables by
	/// substituting the higher variables with the little-endian binary representation of some
	/// index.
	///
	/// Assume $r$ round challenges have been sampled already. A sumcheck multilinear then
	/// is either an $n + r$-variate transparent where $r$ of its variables (lowest or highest)
	/// are to be projected onto round challenges, or an $n$-variate folded multilinear where
	/// this projection has already taken place. It is further split into two $n-1$-variate
	/// subcubes by substituting 0 and 1 for its lowest or highest variable, depending on
	/// evaluation order.
	///
	/// Assuming `subcube_vars + index_vars = n-1` holds, we substitute the binary representation
	/// of `subcube_index` into the higher-indexed variables. Note that this sub-subcube ordering
	/// _does not_ depend on evaluation order.
	///
	/// Indexed subcube evaluations are written into the `evals_0` and `evals_1` slices, where
	/// scalar order corresponds to the lower `P::LOG_WIDTH` variables of the
	/// `subcube_vars`-variate hypercube.
	///
	/// The method can potentially require a `&mut [P]` scratch space, whose length is given by a
	/// query to [`Self::scratch_space_len`] and should be uniquely determined by `subcube_vars`.
	///
	/// ## Arguments
	///
	/// * `subcube_vars`  - the number of variables in the sub-subcube to evaluate over
	/// * `subcube_index` - the index of the subcube within the $n-1$-variate hypercube
	/// * `index_vars`    - the number of bits in `subcube_index`
	/// * `tensor_query`  - multilinear query of pre-switchover challenges (empty if all folded)
	/// * `scratch_space` - optional scratch space
	/// * `evals_0`       - `subcube_vars`-variate hypercube with the current variable
	///   substituted for 0
	/// * `evals_1`       - `subcube_vars`-variate hypercube with the current variable
	///   substituted for 1
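	///
	/// ## Example
	///
	/// For illustration (hypothetical parameters): with an $n-1 = 5$-variate hypercube,
	/// `subcube_vars = 3` and `index_vars = 2`, the subcube with `subcube_index = 0b10`
	/// consists of the $2^3 = 8$ points whose two highest variables are assigned `(0, 1)`
	/// (the little-endian bits of the index). `evals_0` and `evals_1` then each hold these
	/// 8 scalars packed into `1 << 3_usize.saturating_sub(P::LOG_WIDTH)` elements of `P`.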
	#[allow(clippy::too_many_arguments)]
	fn subcube_evaluations<M: MultilinearPoly<P>>(
		&self,
		multilinear: &SumcheckMultilinear<P, M>,
		subcube_vars: usize,
		subcube_index: usize,
		index_vars: usize,
		tensor_query: MultilinearQueryRef<P>,
		scratch_space: Option<&mut [P]>,
		evals_0: &mut [P],
		evals_1: &mut [P],
	) -> Result<(), Error>;
}

/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
///
/// See [`calculate_first_round_evals`] for an optimized version of this method
/// that works over small fields in the first round.
pub(crate) fn calculate_round_evals<FDomain, F, P, M, Evaluator, Composition>(
	evaluation_order: EvaluationOrder,
	n_vars: usize,
	tensor_query: Option<MultilinearQueryRef<P>>,
	multilinears: &[SumcheckMultilinear<P, M>],
	evaluators: &[Evaluator],
	finite_evaluation_points: &[FDomain],
) -> Result<Vec<RoundEvals<F>>, Error>
where
	FDomain: Field,
	F: Field,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	M: MultilinearPoly<P> + Sync,
	Evaluator: SumcheckEvaluator<P, Composition> + Sync,
	Composition: CompositionPoly<P>,
{
	assert!(n_vars > 0, "Computing round evaluations requires at least a single variable.");

	let empty_query = MultilinearQuery::with_capacity(0);
	let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref());

	match evaluation_order {
		EvaluationOrder::LowToHigh => calculate_round_evals_with_access(
			LowToHighAccess,
			n_vars,
			tensor_query,
			multilinears,
			evaluators,
			finite_evaluation_points,
		),
		EvaluationOrder::HighToLow => calculate_round_evals_with_access(
			HighToLowAccess,
			n_vars,
			tensor_query,
			multilinears,
			evaluators,
			finite_evaluation_points,
		),
	}
}

fn calculate_round_evals_with_access<FDomain, F, P, M, Evaluator, Access, Composition>(
	access: Access,
	n_vars: usize,
	tensor_query: MultilinearQueryRef<P>,
	multilinears: &[SumcheckMultilinear<P, M>],
	evaluators: &[Evaluator],
	nontrivial_evaluation_points: &[FDomain],
) -> Result<Vec<RoundEvals<F>>, Error>
where
	FDomain: Field,
	F: Field,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	M: MultilinearPoly<P> + Sync,
	Evaluator: SumcheckEvaluator<P, Composition> + Sync,
	Access: SumcheckMultilinearAccess<P> + Sync,
	Composition: CompositionPoly<P>,
{
	let n_multilinears = multilinears.len();
	let n_round_evals = evaluators
		.iter()
		.map(|evaluator| evaluator.eval_point_indices().len());

	// Compute the union of all evaluation point index ranges.
	let eval_point_indices = evaluators
		.iter()
		.map(|evaluator| evaluator.eval_point_indices())
		.reduce(|range1, range2| range1.start.min(range2.start)..range1.end.max(range2.end))
		.unwrap_or(0..0);
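	// E.g. (illustrative values): the ranges 0..4 and 2..6 reduce to the covering range 0..6.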

	// Check that the nontrivial evaluation points are of the correct length (the 0, 1 & infinity
	// points are accounted for separately).
	if nontrivial_evaluation_points.len() != eval_point_indices.end.saturating_sub(3) {
		bail!(Error::IncorrectNontrivialEvalPointsLength);
	}

	// Here we assume that at least one multilinear is "full".
	// REVIEW: come up with a better heuristic
	let subcube_vars = subcube_vars_for_bits::<P>(
		MAX_SRC_SUBCUBE_LOG_BITS,
		n_vars - 1,
		tensor_query.n_vars(),
		n_vars - 1,
	);

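	// Evaluators may declare a constant evaluation suffix of the hypercube; its contribution is
	// added separately via `process_constant_eval_suffix` below, so only the remaining prefix is
	// split into subcubes and enumerated here.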
	let subcube_count_by_evaluator = evaluators
		.iter()
		.map(|evaluator| {
			((1 << (n_vars - 1)) - evaluator.const_eval_suffix()).div_ceil(1 << subcube_vars)
		})
		.collect::<Vec<_>>();

	let mut subcube_count_by_multilinear = vec![0; n_multilinears];

	for (&evaluator_subcube_count, evaluator) in izip!(&subcube_count_by_evaluator, evaluators) {
		let used_vars = evaluator.composition().expression().vars_usage();

		for (multilinear_subcube_count, usage_flag) in
			izip!(&mut subcube_count_by_multilinear, used_vars)
		{
			if usage_flag {
				*multilinear_subcube_count =
					(*multilinear_subcube_count).max(evaluator_subcube_count);
			}
		}
	}

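	// Each parallel task processes one `subcube_vars`-variate subcube; the remaining `index_vars`
	// variables enumerate these subcubes within the $n-1$-variate hypercube.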
	let index_vars = n_vars - 1 - subcube_vars;
	let packed_accumulators = (0..1 << index_vars)
		.into_par_iter()
		.try_fold(
			|| ParFoldStates::new(&access, n_multilinears, n_round_evals.clone(), subcube_vars),
			|mut par_fold_states, subcube_index| {
				let ParFoldStates {
					multilinear_evals,
					scratch_space,
					round_evals,
				} = &mut par_fold_states;

				for (multilinear, evals, &subcube_count) in
					izip!(multilinears, multilinear_evals.iter_mut(), &subcube_count_by_multilinear)
				{
					if subcube_index < subcube_count {
						access.subcube_evaluations(
							multilinear,
							subcube_vars,
							subcube_index,
							index_vars,
							tensor_query,
							scratch_space.as_deref_mut(),
							&mut evals.evals_0,
							&mut evals.evals_1,
						)?;
					}
				}

				// Proceed by evaluation point first to share interpolation work between evaluators.
				for eval_point_index in eval_point_indices.clone() {
					// The infinity point requires special evaluation rules.
					let is_infinity_point = eval_point_index == 2;

					// Multilinears are evaluated at a point z via linear interpolation:
					//   f(z, xs) = f(0, xs) + z * (f(1, xs) - f(0, xs))
					// The first three points are treated specially:
					//   index 0 - z = 0   => f(z, xs) = f(0, xs)
					//   index 1 - z = 1   => f(z, xs) = f(1, xs)
					//   index 2 - z = inf => f(inf, xs) = leading coefficient in z of
					//                        f(0, xs) + z * (f(1, xs) - f(0, xs)),
					//                        which equals f(1, xs) - f(0, xs)
					//   index 3 and above - remaining finite evaluation points
					let evals_z_iter =
						izip!(multilinear_evals.iter_mut(), &subcube_count_by_multilinear).map(
							|(evals, &subcube_count)| match eval_point_index {
								// This multilinear is not accessed, return an arbitrary slice.
								_ if subcube_index >= subcube_count => evals.evals_0.as_slice(),
								0 => evals.evals_0.as_slice(),
								1 => evals.evals_1.as_slice(),
								2 => {
									// infinity point
									izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
										.for_each(|(eval_z, &eval_0, &eval_1)| {
											*eval_z = eval_1 - eval_0;
										});

									evals.evals_z.as_slice()
								}
								3.. => {
									// Account for the gap occupied by the 0, 1 & infinity points.
									let eval_point =
										nontrivial_evaluation_points[eval_point_index - 3];
									let eval_point_broadcast =
										<PackedSubfield<P, FDomain>>::broadcast(eval_point);

									izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
										.for_each(|(eval_z, &eval_0, &eval_1)| {
											// This is logically the same as calling
											// `binius_math::univariate::extrapolate_line`, except
											// that we do not repeat the broadcast of the
											// subfield element to a packed subfield.
											*eval_z = P::cast_ext(extrapolate_lines(
												P::cast_base(eval_0),
												P::cast_base(eval_1),
												eval_point_broadcast,
											));
										});

									evals.evals_z.as_slice()
								}
							},
						);

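					// Materialize one evaluation row per multilinear on the stack (via
					// `stackalloc_with_iter`) to avoid a heap allocation per (subcube, point)
					// pair; `row_len` is the packed length of a `subcube_vars`-variate subcube.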
					let row_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					stackalloc_with_iter(n_multilinears, evals_z_iter, |evals_z| {
						let evals_z = RowsBatchRef::new(evals_z, row_len);

						for (evaluator, round_evals, &subcube_count) in
							izip!(evaluators, round_evals.iter_mut(), &subcube_count_by_evaluator)
						{
							let eval_point_indices = evaluator.eval_point_indices();
							if !eval_point_indices.contains(&eval_point_index)
								|| subcube_index >= subcube_count
							{
								continue;
							}

							round_evals[eval_point_index - eval_point_indices.start] += evaluator
								.process_subcube_at_eval_point(
									subcube_vars,
									subcube_index,
									is_infinity_point,
									&evals_z,
								);
						}
					});
				}

				Ok(par_fold_states)
			},
		)
		.map(|states: Result<ParFoldStates<P>, Error>| -> Result<_, Error> {
			Ok(states?.round_evals)
		})
		// Simply sum up the fold partitions.
		.try_reduce(
			|| {
				evaluators
					.iter()
					.map(|evaluator| vec![P::zero(); evaluator.eval_point_indices().len()])
					.collect()
			},
			|lhs, rhs| {
				let sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						for (lhs_val, rhs_val) in lhs_vals.iter_mut().zip(rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();
				Ok(sum)
			},
		)?;

	let round_evals = izip!(packed_accumulators, evaluators, subcube_count_by_evaluator)
		.map(|(packed_round_evals, evaluator, subcube_count)| {
			let mut round_evals = packed_round_evals
				.into_iter()
				// Truncate subcubes smaller than the packing width.
				.map(|packed_round_eval| packed_round_eval.iter().take(1 << subcube_vars).sum())
				.collect::<Vec<F>>();

			let const_eval_suffix = (1 << n_vars) - (subcube_count << subcube_vars);
			for (eval_point_index, round_eval) in
				izip!(eval_point_indices.clone(), &mut round_evals)
			{
				let is_infinity_point = eval_point_index == 2;
				*round_eval +=
					evaluator.process_constant_eval_suffix(const_eval_suffix, is_infinity_point);
			}

			RoundEvals(round_evals)
		})
		.collect();

	Ok(round_evals)
}

// Evals of a single multilinear over a subcube, at 0/1 and some interpolated point.
#[derive(Debug)]
struct MultilinearEvals<P: PackedField> {
	evals_0: Vec<P>,
	evals_1: Vec<P>,
	evals_z: Vec<P>,
}

impl<P: PackedField> MultilinearEvals<P> {
	fn new(subcube_vars: usize) -> Self {
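		// Packed length of a `subcube_vars`-variate subcube: at least one element of `P`, even
		// when `subcube_vars < P::LOG_WIDTH`.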
		let len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
		Self {
			evals_0: zeroed_vec(len),
			evals_1: zeroed_vec(len),
			evals_z: zeroed_vec(len),
		}
	}
}

/// Parallel fold state, consisting of scratch area and result accumulator.
#[derive(Debug)]
struct ParFoldStates<P: PackedField> {
	// Evaluations at 0, 1 and domain points, per MLE. Scratch space.
	multilinear_evals: Vec<MultilinearEvals<P>>,

	// Additional scratch space.
	scratch_space: Option<Vec<P>>,

	// Accumulated sums of evaluations over univariate domain.
	//
	// Each element of the outer vector corresponds to one composite polynomial. Each element of
	// an inner vector contains the evaluations at different points.
	round_evals: Vec<Vec<P>>,
}

impl<P: PackedField> ParFoldStates<P> {
	fn new(
		access: &impl SumcheckMultilinearAccess<P>,
		n_multilinears: usize,
		n_round_evals: impl Iterator<Item = usize>,
		subcube_vars: usize,
	) -> Self {
		Self {
			multilinear_evals: (0..n_multilinears)
				.map(|_| MultilinearEvals::new(subcube_vars))
				.collect(),
			scratch_space: access
				.scratch_space_len(subcube_vars)
				.map(|len| zeroed_vec(len)),
			round_evals: n_round_evals
				.map(|n_round_evals| zeroed_vec(n_round_evals))
				.collect(),
		}
	}
}

#[derive(Debug)]
struct LowToHighAccess;

impl<P: PackedField> SumcheckMultilinearAccess<P> for LowToHighAccess {
	fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize> {
		// We need to sample evaluations at both 0 & 1 prior to deinterleaving, thus +1.
		Some(1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH))
	}

	fn subcube_evaluations<M: MultilinearPoly<P>>(
		&self,
		multilinear: &SumcheckMultilinear<P, M>,
		subcube_vars: usize,
		subcube_index: usize,
		_index_vars: usize,
		tensor_query: MultilinearQueryRef<P>,
		scratch_space: Option<&mut [P]>,
		evals_0: &mut [P],
		evals_1: &mut [P],
	) -> Result<(), Error> {
		let Some(scratch_space) = scratch_space else {
			bail!(Error::NoScratchSpace);
		};

		if scratch_space.len() != 1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH)
			|| evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
			|| evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
		{
			bail!(Error::IncorrectDestSliceLengths);
		}

		match multilinear {
			SumcheckMultilinear::Transparent { multilinear, .. } => {
				if tensor_query.n_vars() == 0 {
					multilinear.subcube_evals(subcube_vars + 1, subcube_index, 0, scratch_space)?
				} else {
					multilinear.subcube_partial_low_evals(
						tensor_query,
						subcube_vars + 1,
						subcube_index,
						scratch_space,
					)?
				}
			}

			SumcheckMultilinear::Folded {
				large_field_folded_evals: evals,
				suffix_eval,
			} => {
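				// Copy this subcube's (still interleaved) evaluations from the folded evals;
				// positions past the end of the folded evals are filled with the constant
				// `suffix_eval`.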
				if subcube_vars + 1 >= P::LOG_WIDTH {
					let packed_log_size = subcube_vars + 1 - P::LOG_WIDTH;
					let offset = subcube_index << packed_log_size;
					let packed_len = (1 << packed_log_size).min(evals.len().saturating_sub(offset));
					if packed_len > 0 {
						scratch_space[..packed_len]
							.copy_from_slice(&evals[offset..offset + packed_len]);
					}
					scratch_space[packed_len..].fill(P::broadcast(*suffix_eval));
				} else {
					let mut only_packed = P::zero();

					for i in 0..1 << (subcube_vars + 1) {
						let index = subcube_index << (subcube_vars + 1) | i;
						only_packed
							.set(i, get_packed_slice_checked(evals, index).unwrap_or(*suffix_eval));
					}

					*scratch_space.first_mut().expect("non-empty scratch space") = only_packed;
				}
			}
		}

		// Evaluations at 0 & 1 are interleaved (the substituted variable is the lowest one), so
		// they need to be deinterleaved first. The scratch space enables a simple linear-time
		// algorithm.
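		// Illustratively (hypothetical `P::WIDTH == 4`, scalars labeled by hypercube point):
		//   interleaved:    [f(0, x0), f(1, x0), f(0, x1), f(1, x1)], [f(0, x2), f(1, x2), ..]
		//   deinterleaved:  evals_0 = [f(0, x0), f(0, x1), ..], evals_1 = [f(1, x0), f(1, x1), ..]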
		let zeros = P::default();
		let interleaved_tuples = if scratch_space.len() == 1 {
			Either::Left(iter::once((scratch_space.first().expect("len==1"), &zeros)))
		} else {
			Either::Right(scratch_space.iter().tuples())
		};

		for ((&interleaved_0, &interleaved_1), evals_0, evals_1) in
			izip!(interleaved_tuples, evals_0, evals_1)
		{
			let (deinterleaved_0, deinterleaved_1) = if P::LOG_WIDTH > 0 {
				P::unzip(interleaved_0, interleaved_1, 0)
			} else {
				(interleaved_0, interleaved_1)
			};

			*evals_0 = deinterleaved_0;
			*evals_1 = deinterleaved_1;
		}

		Ok(())
	}
}

#[derive(Debug)]
struct HighToLowAccess;

impl<P: PackedField> SumcheckMultilinearAccess<P> for HighToLowAccess {
	fn scratch_space_len(&self, _subcube_vars: usize) -> Option<usize> {
		None
	}

	fn subcube_evaluations<M: MultilinearPoly<P>>(
		&self,
		multilinear: &SumcheckMultilinear<P, M>,
		subcube_vars: usize,
		subcube_index: usize,
		index_vars: usize,
		tensor_query: MultilinearQueryRef<P>,
		_scratch_space: Option<&mut [P]>,
		evals_0: &mut [P],
		evals_1: &mut [P],
	) -> Result<(), Error> {
		if evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
			|| evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
		{
			bail!(Error::IncorrectDestSliceLengths);
		}

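		// In high-to-low order the variable being substituted is the highest one, so evaluations
		// at 1 live in the upper half of the hypercube: they are selected by setting bit
		// `index_vars` of the subcube index (or the equivalent bit of the offset for folded
		// evals).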
		match multilinear {
			SumcheckMultilinear::Transparent { multilinear, .. } => {
				if tensor_query.n_vars() == 0 {
					multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals_0)?;
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index | 1 << index_vars,
						0,
						evals_1,
					)?;
				} else {
					multilinear.subcube_partial_high_evals(
						tensor_query,
						subcube_vars,
						subcube_index,
						evals_0,
					)?;
					multilinear.subcube_partial_high_evals(
						tensor_query,
						subcube_vars,
						subcube_index | 1 << index_vars,
						evals_1,
					)?;
				}
			}

			SumcheckMultilinear::Folded {
				large_field_folded_evals: evals,
				suffix_eval,
			} => {
				if subcube_vars >= P::LOG_WIDTH {
					let packed_log_size = subcube_vars - P::LOG_WIDTH;
					let offset_0 = subcube_index << packed_log_size;
					let offset_1 = offset_0 | 1 << (index_vars + packed_log_size);
					let packed_len_0 =
						(1 << packed_log_size).min(evals.len().saturating_sub(offset_0));
					let packed_len_1 =
						(1 << packed_log_size).min(evals.len().saturating_sub(offset_1));

					if packed_len_0 > 0 {
						evals_0[..packed_len_0].copy_from_slice(&evals[offset_0..][..packed_len_0]);
					}

					if packed_len_1 > 0 {
						evals_1[..packed_len_1].copy_from_slice(&evals[offset_1..][..packed_len_1]);
					}

					evals_0[packed_len_0..].fill(P::broadcast(*suffix_eval));
					evals_1[packed_len_1..].fill(P::broadcast(*suffix_eval));
				} else {
					let mut evals_0_packed = P::zero();
					let mut evals_1_packed = P::zero();

					for i in 0..1 << subcube_vars {
						let index_0 = subcube_index << subcube_vars | i;
						let index_1 = index_0 | 1 << (index_vars + subcube_vars);
						evals_0_packed.set(
							i,
							get_packed_slice_checked(evals, index_0).unwrap_or(*suffix_eval),
						);
						evals_1_packed.set(
							i,
							get_packed_slice_checked(evals, index_1).unwrap_or(*suffix_eval),
						);
					}

					*evals_0.first_mut().expect("non-empty evals_0") = evals_0_packed;
					*evals_1.first_mut().expect("non-empty evals_1") = evals_1_packed;
				}
			}
		}

		Ok(())
	}
}