binius_hal/
sumcheck_round_calculation.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! Functions that calculate the sumcheck round evaluations.
4//!
5//! This is one of the core computational tasks in the sumcheck proving algorithm.
6
7use std::iter;
8
9use binius_field::{
10	packed::get_packed_slice_checked, Field, PackedExtension, PackedField, PackedSubfield,
11};
12use binius_math::{
13	extrapolate_lines, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
14	MultilinearQueryRef, RowsBatchRef,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{izip, Either, Itertools};
20use stackalloc::stackalloc_with_iter;
21
22use crate::{
23	common::{subcube_vars_for_bits, MAX_SRC_SUBCUBE_LOG_BITS},
24	Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
25};
26
27trait SumcheckMultilinearAccess<P: PackedField> {
28	/// The size of `Vec<P>` scratchspace used by [`subcube_evaluations`], if any.
29	fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize>;
30
31	/// A way to obtain multilinear evaluations during sumcheck.
32	///
33	/// Sumcheck is conducted over boolean hypercubes represented in little endian form
34	/// (faster strides correspond to lower variable indexes). Evaluation order may proceed both
35	/// from lowest variable to highest, as well as in reverse. The $n$-dimensional evaluation
36	/// hypercube can be split into subcubes of `subcube_vars` by substituting higher variables
37	/// for the little endian binary representation of some index.
38	///
39	/// Assume $r$ round challenges have been sampled already. A sumcheck multilinear then
40	/// is either an $n + r$-variate transparent where $r$ of its variables (lowest or highest)
41	/// are to be projected onto round challenges, or $n$-variate folded multilinear where this
42	/// projection has already taken place. It is further split into two $n-1$ variate subcubes
43	/// by substituting 0 and 1 for its lowest or highest variable, depending on evaluation order.
44	///
45	/// Assuming `subcube_vars + index_vars = n-1` holds, we substitute the binary representation
46	/// of `subcube_index` into higher indexed variables. Note that this sub-subcube ordering
47	/// _does not_ depend on evaluation order.
48	///
49	/// Indexed subcube evaluations are written into `evals_0` and `evals_1` slices, where scalar
50	/// order corresponds to the lower `P::LOG_WIDTH` variables of the `subcube_vars`-variate hypercube.
51	///
52	/// The method can potentially require a `&mut [P]` scratch space, whose length is given by a query
53	/// to [`scratch_space_len`] and should be uniquely determined by `subcube_vars`.
54	///
55	/// ## Arguments
56	///
57	/// * `subcube_vars`  - the number of variables in the sub-subcube to evaluate over
58	/// * `subcube_index` - the index of the subcube within the $n-1$-variate hypercube
59	/// * `index_vars`    - number of bits in the `subcube_index`
60	/// * `tensor_query`  - multilinear query of pre-switchover challenges (empty if all folded)
61	/// * `scratch_space` - optional scratch space
62	/// * `evals_0`       - `subcube_vars`-variate hypercube with current variables substituted for 0
63	/// * `evals_1`       - `subcube_vars`-variate hypercube with current variables substituted for 1
64	#[allow(clippy::too_many_arguments)]
65	fn subcube_evaluations<M: MultilinearPoly<P>>(
66		&self,
67		multilinear: &SumcheckMultilinear<P, M>,
68		subcube_vars: usize,
69		subcube_index: usize,
70		index_vars: usize,
71		tensor_query: MultilinearQueryRef<P>,
72		scratch_space: Option<&mut [P]>,
73		evals_0: &mut [P],
74		evals_1: &mut [P],
75	) -> Result<(), Error>;
76}
77
78/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
79///
80/// See [`calculate_first_round_evals`] for an optimized version of this method
81/// that works over small fields in the first round.
82pub(crate) fn calculate_round_evals<FDomain, F, P, M, Evaluator, Composition>(
83	evaluation_order: EvaluationOrder,
84	n_vars: usize,
85	tensor_query: Option<MultilinearQueryRef<P>>,
86	multilinears: &[SumcheckMultilinear<P, M>],
87	evaluators: &[Evaluator],
88	finite_evaluation_points: &[FDomain],
89) -> Result<Vec<RoundEvals<F>>, Error>
90where
91	FDomain: Field,
92	F: Field,
93	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
94	M: MultilinearPoly<P> + Sync,
95	Evaluator: SumcheckEvaluator<P, Composition> + Sync,
96	Composition: CompositionPoly<P>,
97{
98	assert!(n_vars > 0, "Computing round evaluations requires at least a single variable.");
99
100	let empty_query = MultilinearQuery::with_capacity(0);
101	let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref());
102
103	match evaluation_order {
104		EvaluationOrder::LowToHigh => calculate_round_evals_with_access(
105			LowToHighAccess,
106			n_vars,
107			tensor_query,
108			multilinears,
109			evaluators,
110			finite_evaluation_points,
111		),
112		EvaluationOrder::HighToLow => calculate_round_evals_with_access(
113			HighToLowAccess,
114			n_vars,
115			tensor_query,
116			multilinears,
117			evaluators,
118			finite_evaluation_points,
119		),
120	}
121}
122
123fn calculate_round_evals_with_access<FDomain, F, P, M, Evaluator, Access, Composition>(
124	access: Access,
125	n_vars: usize,
126	tensor_query: MultilinearQueryRef<P>,
127	multilinears: &[SumcheckMultilinear<P, M>],
128	evaluators: &[Evaluator],
129	nontrivial_evaluation_points: &[FDomain],
130) -> Result<Vec<RoundEvals<F>>, Error>
131where
132	FDomain: Field,
133	F: Field,
134	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
135	M: MultilinearPoly<P> + Sync,
136	Evaluator: SumcheckEvaluator<P, Composition> + Sync,
137	Access: SumcheckMultilinearAccess<P> + Sync,
138	Composition: CompositionPoly<P>,
139{
140	let n_multilinears = multilinears.len();
141	let n_round_evals = evaluators
142		.iter()
143		.map(|evaluator| evaluator.eval_point_indices().len());
144
145	// Compute the union of all evaluation point index ranges.
146	let eval_point_indices = evaluators
147		.iter()
148		.map(|evaluator| evaluator.eval_point_indices())
149		.reduce(|range1, range2| range1.start.min(range2.start)..range1.end.max(range2.end))
150		.unwrap_or(0..0);
151
152	// Check that finite evaluation points  are of correct length (accounted for 0, 1 & infinity point).
153	if nontrivial_evaluation_points.len() != eval_point_indices.end.saturating_sub(3) {
154		bail!(Error::IncorrectNontrivialEvalPointsLength);
155	}
156
157	// Here we assume that at least one multilinear would be "full"
158	// REVIEW: come up with a better heuristic
159	let subcube_vars = subcube_vars_for_bits::<P>(
160		MAX_SRC_SUBCUBE_LOG_BITS,
161		n_vars - 1,
162		tensor_query.n_vars(),
163		n_vars - 1,
164	);
165
166	let subcube_count_by_evaluator = evaluators
167		.iter()
168		.map(|evaluator| {
169			((1 << (n_vars - 1)) - evaluator.const_eval_suffix()).div_ceil(1 << subcube_vars)
170		})
171		.collect::<Vec<_>>();
172
173	let mut subcube_count_by_multilinear = vec![0; n_multilinears];
174
175	for (&evaluator_subcube_count, evaluator) in izip!(&subcube_count_by_evaluator, evaluators) {
176		let used_vars = evaluator.composition().expression().vars_usage();
177
178		for (multilinear_subcube_count, usage_flag) in
179			izip!(&mut subcube_count_by_multilinear, used_vars)
180		{
181			if usage_flag {
182				*multilinear_subcube_count =
183					(*multilinear_subcube_count).max(evaluator_subcube_count);
184			}
185		}
186	}
187
188	let index_vars = n_vars - 1 - subcube_vars;
189	let packed_accumulators = (0..1 << index_vars)
190		.into_par_iter()
191		.try_fold(
192			|| ParFoldStates::new(&access, n_multilinears, n_round_evals.clone(), subcube_vars),
193			|mut par_fold_states, subcube_index| {
194				let ParFoldStates {
195					multilinear_evals,
196					scratch_space,
197					round_evals,
198				} = &mut par_fold_states;
199
200				for (multilinear, evals, &subcube_count) in
201					izip!(multilinears, multilinear_evals.iter_mut(), &subcube_count_by_multilinear)
202				{
203					if subcube_index < subcube_count {
204						access.subcube_evaluations(
205							multilinear,
206							subcube_vars,
207							subcube_index,
208							index_vars,
209							tensor_query,
210							scratch_space.as_deref_mut(),
211							&mut evals.evals_0,
212							&mut evals.evals_1,
213						)?;
214					}
215				}
216
217				// Proceed by evaluation point first to share interpolation work between evaluators.
218				for eval_point_index in eval_point_indices.clone() {
219					// Infinity point requires special evaluation rules
220					let is_infinity_point = eval_point_index == 2;
221
222					// Multilinears are evaluated at a point t via linear interpolation:
223					//   f(z, xs) = f(0, xs) + z * (f(1, xs) - f(0, xs))
224					// The first three points are treated specially:
225					//   index 0 - z = 0   => f(z, xs) = f(0, xs)
226					//   index 1 - z = 1   => f(z, xs) = f(1, xs)
227					//   index 2 = z = inf => f(inf, xs) = high (f(0, xs) + z * (f(1, xs) - f(0, xs))) =
228					//                                   = f(1, xs) - f(0, xs)
229					//   index 3 and above - remaining finite evaluation points
230					let evals_z_iter =
231						izip!(multilinear_evals.iter_mut(), &subcube_count_by_multilinear).map(
232							|(evals, &subcube_count)| match eval_point_index {
233								// This multilinear is not accessed, return arbitrary slice
234								_ if subcube_index >= subcube_count => evals.evals_0.as_slice(),
235								0 => evals.evals_0.as_slice(),
236								1 => evals.evals_1.as_slice(),
237								2 => {
238									// infinity point
239									izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
240										.for_each(|(eval_z, &eval_0, &eval_1)| {
241											*eval_z = eval_1 - eval_0;
242										});
243
244									evals.evals_z.as_slice()
245								}
246								3.. => {
247									// Account for the gap occupied by the 0, 1 & infinity point
248									let eval_point =
249										nontrivial_evaluation_points[eval_point_index - 3];
250									let eval_point_broadcast =
251										<PackedSubfield<P, FDomain>>::broadcast(eval_point);
252
253									izip!(&mut evals.evals_z, &evals.evals_0, &evals.evals_1)
254										.for_each(|(eval_z, &eval_0, &eval_1)| {
255											// This is logically the same as calling
256											// `binius_math::univariate::extrapolate_line`, except that we do
257											// not repeat the broadcast of the subfield element to a packed
258											// subfield.
259											*eval_z = P::cast_ext(extrapolate_lines(
260												P::cast_base(eval_0),
261												P::cast_base(eval_1),
262												eval_point_broadcast,
263											));
264										});
265
266									evals.evals_z.as_slice()
267								}
268							},
269						);
270
271					let row_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
272					stackalloc_with_iter(n_multilinears, evals_z_iter, |evals_z| {
273						let evals_z = RowsBatchRef::new(evals_z, row_len);
274
275						for (evaluator, round_evals, &subcube_count) in
276							izip!(evaluators, round_evals.iter_mut(), &subcube_count_by_evaluator)
277						{
278							let eval_point_indices = evaluator.eval_point_indices();
279							if !eval_point_indices.contains(&eval_point_index)
280								|| subcube_index >= subcube_count
281							{
282								continue;
283							}
284
285							round_evals[eval_point_index - eval_point_indices.start] += evaluator
286								.process_subcube_at_eval_point(
287									subcube_vars,
288									subcube_index,
289									is_infinity_point,
290									&evals_z,
291								);
292						}
293					});
294				}
295
296				Ok(par_fold_states)
297			},
298		)
299		.map(|states: Result<ParFoldStates<P>, Error>| -> Result<_, Error> {
300			Ok(states?.round_evals)
301		})
302		// Simply sum up the fold partitions.
303		.try_reduce(
304			|| {
305				evaluators
306					.iter()
307					.map(|evaluator| vec![P::zero(); evaluator.eval_point_indices().len()])
308					.collect()
309			},
310			|lhs, rhs| {
311				let sum = izip!(lhs, rhs)
312					.map(|(mut lhs_vals, rhs_vals)| {
313						for (lhs_val, rhs_val) in lhs_vals.iter_mut().zip(rhs_vals) {
314							*lhs_val += rhs_val;
315						}
316						lhs_vals
317					})
318					.collect();
319				Ok(sum)
320			},
321		)?;
322
323	let round_evals = izip!(packed_accumulators, evaluators, subcube_count_by_evaluator)
324		.map(|(packed_round_evals, evaluator, subcube_count)| {
325			let mut round_evals = packed_round_evals
326				.into_iter()
327				// Truncate subcubes smaller than packing width.
328				.map(|packed_round_eval| packed_round_eval.iter().take(1 << subcube_vars).sum())
329				.collect::<Vec<F>>();
330
331			let const_eval_suffix = (1 << n_vars) - (subcube_count << subcube_vars);
332			for (eval_point_index, round_eval) in
333				izip!(eval_point_indices.clone(), &mut round_evals)
334			{
335				let is_infinity_point = eval_point_index == 2;
336				*round_eval +=
337					evaluator.process_constant_eval_suffix(const_eval_suffix, is_infinity_point);
338			}
339
340			RoundEvals(round_evals)
341		})
342		.collect();
343
344	Ok(round_evals)
345}
346
347// Evals of a single multilinear over a subcube, at 0/1 and some interpolated point.
348#[derive(Debug)]
349struct MultilinearEvals<P: PackedField> {
350	evals_0: Vec<P>,
351	evals_1: Vec<P>,
352	evals_z: Vec<P>,
353}
354
355impl<P: PackedField> MultilinearEvals<P> {
356	fn new(subcube_vars: usize) -> Self {
357		let len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
358		Self {
359			evals_0: zeroed_vec(len),
360			evals_1: zeroed_vec(len),
361			evals_z: zeroed_vec(len),
362		}
363	}
364}
365
366/// Parallel fold state, consisting of scratch area and result accumulator.
367#[derive(Debug)]
368struct ParFoldStates<P: PackedField> {
369	// Evaluations at 0, 1 and domain points, per MLE. Scratch space.
370	multilinear_evals: Vec<MultilinearEvals<P>>,
371
372	// Additional scratch space.
373	scratch_space: Option<Vec<P>>,
374
375	// Accumulated sums of evaluations over univariate domain.
376	//
377	// Each element of the outer vector corresponds to one composite polynomial. Each element of
378	// an inner vector contains the evaluations at different points.
379	round_evals: Vec<Vec<P>>,
380}
381
382impl<P: PackedField> ParFoldStates<P> {
383	fn new(
384		access: &impl SumcheckMultilinearAccess<P>,
385		n_multilinears: usize,
386		n_round_evals: impl Iterator<Item = usize>,
387		subcube_vars: usize,
388	) -> Self {
389		Self {
390			multilinear_evals: (0..n_multilinears)
391				.map(|_| MultilinearEvals::new(subcube_vars))
392				.collect(),
393			scratch_space: access
394				.scratch_space_len(subcube_vars)
395				.map(|len| zeroed_vec(len)),
396			round_evals: n_round_evals
397				.map(|n_round_evals| zeroed_vec(n_round_evals))
398				.collect(),
399		}
400	}
401}
402
403#[derive(Debug)]
404struct LowToHighAccess;
405
406impl<P: PackedField> SumcheckMultilinearAccess<P> for LowToHighAccess {
407	fn scratch_space_len(&self, subcube_vars: usize) -> Option<usize> {
408		// We need to sample evaluations at both 0 & 1 prior to deinterleaving, thus +1.
409		Some(1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH))
410	}
411
412	fn subcube_evaluations<M: MultilinearPoly<P>>(
413		&self,
414		multilinear: &SumcheckMultilinear<P, M>,
415		subcube_vars: usize,
416		subcube_index: usize,
417		_index_vars: usize,
418		tensor_query: MultilinearQueryRef<P>,
419		scratch_space: Option<&mut [P]>,
420		evals_0: &mut [P],
421		evals_1: &mut [P],
422	) -> Result<(), Error> {
423		let Some(scratch_space) = scratch_space else {
424			bail!(Error::NoScratchSpace);
425		};
426
427		if scratch_space.len() != 1 << (subcube_vars + 1).saturating_sub(P::LOG_WIDTH)
428			|| evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
429			|| evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
430		{
431			bail!(Error::IncorrectDestSliceLengths);
432		}
433
434		match multilinear {
435			SumcheckMultilinear::Transparent { multilinear, .. } => {
436				if tensor_query.n_vars() == 0 {
437					multilinear.subcube_evals(subcube_vars + 1, subcube_index, 0, scratch_space)?
438				} else {
439					multilinear.subcube_partial_low_evals(
440						tensor_query,
441						subcube_vars + 1,
442						subcube_index,
443						scratch_space,
444					)?
445				}
446			}
447
448			SumcheckMultilinear::Folded {
449				large_field_folded_evals: evals,
450			} => {
451				if subcube_vars + 1 >= P::LOG_WIDTH {
452					let packed_log_size = subcube_vars + 1 - P::LOG_WIDTH;
453					let offset = subcube_index << packed_log_size;
454					let packed_len = (1 << packed_log_size).min(evals.len().saturating_sub(offset));
455					if packed_len > 0 {
456						scratch_space[..packed_len]
457							.copy_from_slice(&evals[offset..offset + packed_len]);
458					}
459					scratch_space[packed_len..].fill(P::zero());
460				} else {
461					let mut only_packed = P::zero();
462
463					for i in 0..1 << (subcube_vars + 1) {
464						let index = subcube_index << (subcube_vars + 1) | i;
465						only_packed.set(
466							i,
467							get_packed_slice_checked(evals, index).unwrap_or(P::Scalar::ZERO),
468						);
469					}
470
471					*scratch_space.first_mut().expect("non-empty scratch space") = only_packed;
472				}
473			}
474		}
475
476		// Evaluations at 0 & 1 are interleaved (the substituted variable is the lowest one), need
477		// to deinterleave them first. This requires scratch space to enable simple linear time algorithm.
478		let zeros = P::default();
479		let interleaved_tuples = if scratch_space.len() == 1 {
480			Either::Left(iter::once((scratch_space.first().expect("len==1"), &zeros)))
481		} else {
482			Either::Right(scratch_space.iter().tuples())
483		};
484
485		for ((&interleaved_0, &interleaved_1), evals_0, evals_1) in
486			izip!(interleaved_tuples, evals_0, evals_1)
487		{
488			let (deinterleaved_0, deinterleaved_1) = if P::LOG_WIDTH > 0 {
489				P::unzip(interleaved_0, interleaved_1, 0)
490			} else {
491				(interleaved_0, interleaved_1)
492			};
493
494			*evals_0 = deinterleaved_0;
495			*evals_1 = deinterleaved_1;
496		}
497
498		Ok(())
499	}
500}
501
502#[derive(Debug)]
503struct HighToLowAccess;
504
505impl<P: PackedField> SumcheckMultilinearAccess<P> for HighToLowAccess {
506	fn scratch_space_len(&self, _subcube_vars: usize) -> Option<usize> {
507		None
508	}
509
510	fn subcube_evaluations<M: MultilinearPoly<P>>(
511		&self,
512		multilinear: &SumcheckMultilinear<P, M>,
513		subcube_vars: usize,
514		subcube_index: usize,
515		index_vars: usize,
516		tensor_query: MultilinearQueryRef<P>,
517		_scratch_space: Option<&mut [P]>,
518		evals_0: &mut [P],
519		evals_1: &mut [P],
520	) -> Result<(), Error> {
521		if evals_0.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
522			|| evals_1.len() != 1 << subcube_vars.saturating_sub(P::LOG_WIDTH)
523		{
524			bail!(Error::IncorrectDestSliceLengths);
525		}
526
527		match multilinear {
528			SumcheckMultilinear::Transparent { multilinear, .. } => {
529				if tensor_query.n_vars() == 0 {
530					multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals_0)?;
531					multilinear.subcube_evals(
532						subcube_vars,
533						subcube_index | 1 << index_vars,
534						0,
535						evals_1,
536					)?;
537				} else {
538					multilinear.subcube_partial_high_evals(
539						tensor_query,
540						subcube_vars,
541						subcube_index,
542						evals_0,
543					)?;
544					multilinear.subcube_partial_high_evals(
545						tensor_query,
546						subcube_vars,
547						subcube_index | 1 << index_vars,
548						evals_1,
549					)?;
550				}
551			}
552
553			SumcheckMultilinear::Folded {
554				large_field_folded_evals: evals,
555			} => {
556				if subcube_vars >= P::LOG_WIDTH {
557					let packed_log_size = subcube_vars - P::LOG_WIDTH;
558					let offset_0 = subcube_index << packed_log_size;
559					let offset_1 = offset_0 | 1 << (index_vars + packed_log_size);
560					let packed_len_0 =
561						(1 << packed_log_size).min(evals.len().saturating_sub(offset_0));
562					let packed_len_1 =
563						(1 << packed_log_size).min(evals.len().saturating_sub(offset_1));
564
565					if packed_len_0 > 0 {
566						evals_0[..packed_len_0].copy_from_slice(&evals[offset_0..][..packed_len_0]);
567					}
568
569					if packed_len_1 > 0 {
570						evals_1[..packed_len_1].copy_from_slice(&evals[offset_1..][..packed_len_1]);
571					}
572
573					evals_0[packed_len_0..].fill(P::zero());
574					evals_1[packed_len_1..].fill(P::zero());
575				} else {
576					let mut evals_0_packed = P::zero();
577					let mut evals_1_packed = P::zero();
578
579					for i in 0..1 << subcube_vars {
580						let index_0 = subcube_index << subcube_vars | i;
581						let index_1 = index_0 | 1 << (index_vars + subcube_vars);
582						evals_0_packed.set(
583							i,
584							get_packed_slice_checked(evals, index_0).unwrap_or(P::Scalar::ZERO),
585						);
586						evals_1_packed.set(
587							i,
588							get_packed_slice_checked(evals, index_1).unwrap_or(P::Scalar::ZERO),
589						);
590					}
591
592					*evals_0.first_mut().expect("non-empty evals_0") = evals_0_packed;
593					*evals_1.first_mut().expect("non-empty evals_1") = evals_1_packed;
594				}
595			}
596		}
597
598		Ok(())
599	}
600}