binius_core/protocols/sumcheck/prove/prover_state.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	iter,
	sync::atomic::{AtomicBool, Ordering},
};

use binius_field::{util::powers, Field, PackedExtension, PackedField};
use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
use binius_math::{
	evaluate_univariate, fold_left_lerp_inplace, fold_right_lerp, CompositionPoly, EvaluationOrder,
	MultilinearPoly, MultilinearQuery,
};
use binius_maybe_rayon::prelude::*;
use binius_utils::bail;
use bytemuck::zeroed_vec;
use getset::CopyGetters;
use itertools::izip;
use tracing::instrument;

use crate::{
	polynomial::Error as PolynomialError,
	protocols::sumcheck::{
		common::{determine_switchovers, equal_n_vars_check, RoundCoeffs},
		error::Error,
	},
};

pub trait SumcheckInterpolator<F: Field> {
	/// Given evaluations of the round polynomial, interpolate and return its monomial coefficients.
	///
	/// ## Arguments
	///
	/// * `last_sum`: the claimed sum of the round polynomial for this round
	/// * `round_evals`: the computed evaluations of the round polynomial
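	///
	/// Note: by the usual sumcheck round invariant, `R(0) + R(1)` equals `last_sum`;
	/// implementations may use this to recover evaluations omitted from `round_evals`.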
	fn round_evals_to_coeffs(
		&self,
		last_sum: F,
		round_evals: Vec<F>,
	) -> Result<Vec<F>, PolynomialError>;
}

#[derive(Debug)]
enum ProverStateCoeffsOrSums<F: Field> {
	Coeffs(Vec<RoundCoeffs<F>>),
	Sums(Vec<F>),
}

/// The stored state of a sumcheck prover, which encapsulates common implementation logic.
///
/// We expect that CPU sumcheck provers will internally maintain a [`ProverState`] instance and
/// customize the sumcheck logic through different [`SumcheckEvaluator`] implementations passed to
/// the common state object.
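///
/// A typical round loop looks roughly like the sketch below; the `evaluators`, `interpolators`,
/// `batch_coeff` and transcript interaction are illustrative assumptions, not part of this module.
///
/// ```ignore
/// for _ in 0..state.n_vars() {
///     let evals = state.calculate_round_evals(&evaluators)?;
///     let coeffs = state.calculate_round_coeffs_from_evals(&interpolators, batch_coeff, evals)?;
///     // ... write `coeffs` to the transcript and sample `challenge` from the verifier ...
///     state.fold(challenge)?;
/// }
/// let multilinear_evals = state.finish()?;
/// ```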
#[derive(Debug, CopyGetters)]
pub struct ProverState<'a, FDomain, P, M, Backend>
where
	FDomain: Field,
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	/// The number of variables in the folded multilinears. This value decrements each round the
	/// state is folded.
	#[getset(get_copy = "pub")]
	n_vars: usize,
	#[getset(get_copy = "pub")]
	evaluation_order: EvaluationOrder,
	multilinears: Vec<SumcheckMultilinear<P, M>>,
	nontrivial_evaluation_points: Vec<FDomain>,
	challenges: Vec<P::Scalar>,
	tensor_query: Option<MultilinearQuery<P>>,
	last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
	backend: &'a Backend,
}

impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
where
	FDomain: Field,
	F: Field,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	pub fn new(
		evaluation_order: EvaluationOrder,
		multilinears: Vec<M>,
		claimed_sums: Vec<F>,
		nontrivial_evaluation_points: Vec<FDomain>,
		switchover_fn: impl Fn(usize) -> usize,
		backend: &'a Backend,
	) -> Result<Self, Error> {
		let switchover_rounds = determine_switchovers(&multilinears, switchover_fn);
		Self::new_with_switchover_rounds(
			evaluation_order,
			multilinears,
			&switchover_rounds,
			claimed_sums,
			nontrivial_evaluation_points,
			backend,
		)
	}

	#[instrument(
		skip_all,
		level = "debug",
		name = "ProverState::new_with_switchover_rounds"
	)]
	pub fn new_with_switchover_rounds(
		evaluation_order: EvaluationOrder,
		multilinears: Vec<M>,
		switchover_rounds: &[usize],
		claimed_sums: Vec<F>,
		nontrivial_evaluation_points: Vec<FDomain>,
		backend: &'a Backend,
	) -> Result<Self, Error> {
		let n_vars = equal_n_vars_check(&multilinears)?;

		if multilinears.len() != switchover_rounds.len() {
			bail!(Error::MultilinearSwitchoverSizeMismatch);
		}

		let max_switchover_round = switchover_rounds.iter().copied().max().unwrap_or_default();

		let multilinears = iter::zip(multilinears, switchover_rounds)
			.map(|(multilinear, &switchover_round)| SumcheckMultilinear::Transparent {
				multilinear,
				switchover_round,
			})
			.collect();

		let tensor_query = MultilinearQuery::with_capacity(max_switchover_round + 1);

		Ok(Self {
			n_vars,
			evaluation_order,
			multilinears,
			nontrivial_evaluation_points,
			challenges: Vec::new(),
			tensor_query: Some(tensor_query),
			last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
			backend,
		})
	}

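	/// Bind one variable of every multilinear to the verifier `challenge` and decrement `n_vars`.
	///
	/// Transparent multilinears that have reached their switchover round are materialized into
	/// folded large-field evaluations; already-folded multilinears are lerp-folded directly.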
	#[instrument(skip_all, name = "ProverState::fold", level = "debug")]
	pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
		if self.n_vars == 0 {
			bail!(Error::ExpectedFinish);
		}

		// Update the stored multilinear sums.
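		// By the sumcheck round invariant, the claim for the next round is the current round
		// polynomial evaluated at the verifier challenge: `sum' = R(challenge)`.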
		match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
				let new_sums = round_coeffs
					.par_iter()
					.map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
					.collect();
				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
			}
			ProverStateCoeffsOrSums::Sums(_) => {
				bail!(Error::ExpectedExecution);
			}
		}

		// Update the tensor query.
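		// `challenges` stays ordered by variable index: low-to-high binds variables from the
		// bottom up (append), while high-to-low binds from the top down (prepend).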
		match self.evaluation_order {
			EvaluationOrder::LowToHigh => self.challenges.push(challenge),
			EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
		}

		if let Some(tensor_query) = self.tensor_query.take() {
			self.tensor_query = match self.evaluation_order {
				EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
				// REVIEW: not spending effort to come up with an inplace update method here, as the
				//         future of switchover is somewhat unclear in light of univariate skip, and
				//         switchover tensors are small-ish anyway.
				EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
			}
		}

		// Use Relaxed ordering for writes and the read, because:
		// * all writes can only move this value in one direction: false -> true
		// * the barrier at the end of a rayon "parallel for" is a big enough synchronization point
		//   to be Relaxed about the memory ordering of accesses to this atomic.
		let any_transparent_left = AtomicBool::new(false);
		self.multilinears
			.par_iter_mut()
			.try_for_each(|multilinear| {
				match multilinear {
					SumcheckMultilinear::Transparent {
						multilinear: inner_multilinear,
						ref mut switchover_round,
					} => {
						if *switchover_round == 0 {
							let tensor_query = self.tensor_query.as_ref()
							.expect(
								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
							);

							// At switchover we partially evaluate the multilinear at an expanded tensor query.
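							// The tensor query is the tensor product expansion of the challenges
							// bound so far; partially evaluating the multilinear at it specializes
							// all of those variables in a single pass.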
							let large_field_folded_evals = match self.evaluation_order {
								EvaluationOrder::LowToHigh => inner_multilinear
									.evaluate_partial_low(tensor_query.to_ref())?
									.into_evals(),
								EvaluationOrder::HighToLow => inner_multilinear
									.evaluate_partial_high(tensor_query.to_ref())?
									.into_evals(),
							};

							*multilinear = SumcheckMultilinear::Folded {
								large_field_folded_evals,
							};
						} else {
							*switchover_round -= 1;
							any_transparent_left.store(true, Ordering::Relaxed);
						}
					}
					SumcheckMultilinear::Folded {
						ref mut large_field_folded_evals,
					} => {
						// Post-switchover, we perform single variable folding (linear interpolation).
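						// Binding a variable X of f to `challenge` uses the lerp identity:
						//   f' = (1 - challenge) * f|_{X=0} + challenge * f|_{X=1}
						//      = f|_{X=0} + challenge * (f|_{X=1} - f|_{X=0})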

						match self.evaluation_order {
							// Lerp folding in low-to-high evaluation order can be made inplace, but not
							// easily so if multithreading is desired.
							EvaluationOrder::LowToHigh => {
								let mut new_large_field_folded_evals =
									zeroed_vec(1 << self.n_vars.saturating_sub(1 + P::LOG_WIDTH));

								fold_right_lerp(
									&*large_field_folded_evals,
									self.n_vars,
									challenge,
									&mut new_large_field_folded_evals,
								)?;

								*large_field_folded_evals = new_large_field_folded_evals;
							}

							// High-to-low evaluation order allows trivial inplace multithreaded folding.
							EvaluationOrder::HighToLow => {
								// REVIEW: note that this method is currently _not_ multithreaded, as
								//         traces are usually sufficiently wide
								fold_left_lerp_inplace(
									large_field_folded_evals,
									self.n_vars,
									challenge,
								)?;
							}
						}
					}
				};
				Ok::<(), Error>(())
			})?;

		if !any_transparent_left.load(Ordering::Relaxed) {
			self.tensor_query = None;
		}

		self.n_vars -= 1;
		Ok(())
	}

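	/// Consume the state after the final round and return the evaluation of each multilinear at
	/// the point assembled from all sampled challenges.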
	pub fn finish(self) -> Result<Vec<F>, Error> {
		match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(_) => {
				bail!(Error::ExpectedFold);
			}
			ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
				0 => {}
				_ => bail!(Error::ExpectedExecution),
			},
		};

		self.multilinears
			.into_iter()
			.map(|multilinear| {
				match multilinear {
					SumcheckMultilinear::Transparent {
						multilinear: inner_multilinear,
						..
					} => {
						let tensor_query = self.tensor_query.as_ref()
							.expect(
								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
							);
						inner_multilinear.evaluate(tensor_query.to_ref())
					}
					SumcheckMultilinear::Folded {
						large_field_folded_evals,
					} => Ok(large_field_folded_evals
						.first()
						.expect("exactly one packed field element left after folding")
						.get(0)),
				}
				.map_err(Error::MathError)
			})
			.collect()
	}

	/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
	#[instrument(skip_all, level = "debug")]
	pub fn calculate_round_evals<Evaluator, Composition>(
		&self,
		evaluators: &[Evaluator],
	) -> Result<Vec<RoundEvals<F>>, Error>
	where
		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
		Composition: CompositionPoly<P>,
	{
		Ok(self.backend.sumcheck_compute_round_evals(
			self.evaluation_order,
			self.n_vars,
			self.tensor_query.as_ref().map(Into::into),
			&self.multilinears,
			evaluators,
			&self.nontrivial_evaluation_points,
		)?)
	}

	/// Calculate the batched round coefficients from the domain evaluations.
	///
	/// This performs both the polynomial interpolation over the evaluations and the mixing with
	/// the batching coefficient.
	pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
		&mut self,
		interpolators: &[Interpolator],
		batch_coeff: F,
		evals: Vec<RoundEvals<F>>,
	) -> Result<RoundCoeffs<F>, Error> {
		let coeffs = match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(_) => {
				bail!(Error::ExpectedFold);
			}
			ProverStateCoeffsOrSums::Sums(ref sums) => {
				if interpolators.len() != sums.len() {
					bail!(Error::IncorrectNumberOfEvaluators {
						expected: sums.len(),
					});
				}

				let coeffs = izip!(interpolators, sums, evals)
					.map(|(interpolator, &sum, RoundEvals(evals))| {
						let coeffs = interpolator.round_evals_to_coeffs(sum, evals)?;
						Ok::<_, Error>(RoundCoeffs(coeffs))
					})
					.collect::<Result<Vec<_>, _>>()?;
				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
				coeffs
			}
		};

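		// Mix the per-claim round polynomials with successive powers of the batching coefficient:
		// `batched(X) = sum_i batch_coeff^i * R_i(X)`.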
		let batched_coeffs = coeffs
			.into_iter()
			.zip(powers(batch_coeff))
			.map(|(coeffs, scalar)| coeffs * scalar)
			.fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);

		Ok(batched_coeffs)
	}
}