binius_core/protocols/sumcheck/prove/
prover_state.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, PackedExtension, PackedField, util::powers};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6	CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery, evaluate_univariate,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15	polynomial::Error as PolynomialError,
16	protocols::sumcheck::{common::RoundCoeffs, error::Error},
17};
18
19pub trait SumcheckInterpolator<F: Field> {
20	/// Given evaluations of the round polynomial, interpolate and return monomial coefficients
21	///
22	/// ## Arguments
23	///
24	/// * `round_evals`: the computed evaluations of the round polynomial
25	fn round_evals_to_coeffs(
26		&self,
27		last_sum: F,
28		round_evals: Vec<F>,
29	) -> Result<Vec<F>, PolynomialError>;
30}
31
32#[derive(Debug)]
33enum ProverStateCoeffsOrSums<F: Field> {
34	Coeffs(Vec<RoundCoeffs<F>>),
35	Sums(Vec<F>),
36}
37
38/// The stored state of a sumcheck prover, which encapsulates common implementation logic.
39///
40/// We expect that CPU sumcheck provers will internally maintain a [`ProverState`] instance and
41/// customize the sumcheck logic through different [`SumcheckEvaluator`] implementations passed to
42/// the common state object.
43#[derive(Debug, CopyGetters)]
44pub struct ProverState<'a, FDomain, P, M, Backend>
45where
46	FDomain: Field,
47	P: PackedField,
48	M: MultilinearPoly<P> + Send + Sync,
49	Backend: ComputationBackend,
50{
51	/// The number of variables in the folded multilinears. This value decrements each round the
52	/// state is folded.
53	#[getset(get_copy = "pub")]
54	n_vars: usize,
55	#[getset(get_copy = "pub")]
56	evaluation_order: EvaluationOrder,
57	multilinears: Vec<SumcheckMultilinear<P, M>>,
58	nontrivial_evaluation_points: Vec<FDomain>,
59	challenges: Vec<P::Scalar>,
60	tensor_query: Option<MultilinearQuery<P>>,
61	last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
62	backend: &'a Backend,
63}
64
65impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
66where
67	FDomain: Field,
68	F: Field,
69	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
70	M: MultilinearPoly<P> + Send + Sync,
71	Backend: ComputationBackend,
72{
73	#[instrument(skip_all, level = "debug", name = "ProverState::new")]
74	pub fn new(
75		evaluation_order: EvaluationOrder,
76		n_vars: usize,
77		multilinears: Vec<SumcheckMultilinear<P, M>>,
78		claimed_sums: Vec<F>,
79		nontrivial_evaluation_points: Vec<FDomain>,
80		backend: &'a Backend,
81	) -> Result<Self, Error> {
82		for multilinear in &multilinears {
83			match *multilinear {
84				SumcheckMultilinear::Transparent {
85					ref multilinear,
86					const_suffix: (_, suffix_len),
87					..
88				} => {
89					if multilinear.n_vars() != n_vars {
90						bail!(Error::NumberOfVariablesMismatch);
91					}
92
93					if suffix_len > 1 << n_vars {
94						bail!(Error::IncorrectConstSuffixes);
95					}
96				}
97
98				SumcheckMultilinear::Folded {
99					large_field_folded_evals: ref evals,
100					..
101				} => {
102					if evals.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
103						bail!(Error::IncorrectConstSuffixes);
104					}
105				}
106			}
107		}
108
109		let tensor_query = multilinears
110			.iter()
111			.filter_map(|multilinear| match multilinear {
112				SumcheckMultilinear::Transparent {
113					switchover_round, ..
114				} => Some(switchover_round),
115				_ => None,
116			})
117			.max()
118			.map(|&max_switchover_round| {
119				MultilinearQuery::with_capacity(max_switchover_round.min(n_vars) + 1)
120			});
121
122		Ok(Self {
123			n_vars,
124			evaluation_order,
125			multilinears,
126			nontrivial_evaluation_points,
127			tensor_query,
128			challenges: Vec::new(),
129			last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
130			backend,
131		})
132	}
133
134	pub fn n_multilinears(&self) -> usize {
135		self.multilinears.len()
136	}
137
138	#[instrument(skip_all, name = "ProverState::fold", level = "debug")]
139	pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
140		if self.n_vars == 0 {
141			bail!(Error::ExpectedFinish);
142		}
143
144		// Update the stored multilinear sums.
145		match self.last_coeffs_or_sums {
146			ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
147				let new_sums = round_coeffs
148					.par_iter()
149					.map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
150					.collect();
151				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
152			}
153			ProverStateCoeffsOrSums::Sums(_) => {
154				bail!(Error::ExpectedExecution);
155			}
156		}
157
158		// Update the tensor query.
159		match self.evaluation_order {
160			EvaluationOrder::LowToHigh => self.challenges.push(challenge),
161			EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
162		}
163
164		if let Some(tensor_query) = self.tensor_query.take() {
165			self.tensor_query = match self.evaluation_order {
166				EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
167				// REVIEW: not spending effort to come up with an inplace update method here, as the
168				//         future of switchover is somewhat unclear in light of univariate skip, and
169				//         switchover tensors are small-ish anyway.
170				EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
171			}
172		}
173
174		let any_transparent_left = self.backend.sumcheck_fold_multilinears(
175			self.evaluation_order,
176			self.n_vars,
177			&mut self.multilinears,
178			challenge,
179			self.tensor_query.as_ref().map(Into::into),
180		)?;
181
182		if !any_transparent_left {
183			self.tensor_query = None;
184		}
185
186		self.n_vars -= 1;
187		Ok(())
188	}
189
190	pub fn finish(self) -> Result<Vec<F>, Error> {
191		match self.last_coeffs_or_sums {
192			ProverStateCoeffsOrSums::Coeffs(_) => {
193				bail!(Error::ExpectedFold);
194			}
195			ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
196				0 => {}
197				_ => bail!(Error::ExpectedExecution),
198			},
199		};
200
201		self.multilinears
202			.into_iter()
203			.map(|multilinear| {
204				match multilinear {
205					SumcheckMultilinear::Transparent {
206						multilinear: inner_multilinear,
207						..
208					} => {
209						let tensor_query = self.tensor_query.as_ref().expect(
210							"tensor_query is guaranteed to be Some while there is still a transparent multilinear",
211						);
212						inner_multilinear.evaluate(tensor_query.to_ref())
213					}
214					SumcheckMultilinear::Folded {
215						large_field_folded_evals,
216						suffix_eval,
217					} => Ok(large_field_folded_evals
218						.first()
219						.map_or(suffix_eval, |packed| packed.get(0))
220						.get(0)),
221				}
222				.map_err(Error::MathError)
223			})
224			.collect()
225	}
226
227	/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
228	pub fn calculate_round_evals<Evaluator, Composition>(
229		&self,
230		evaluators: &[Evaluator],
231	) -> Result<Vec<RoundEvals<F>>, Error>
232	where
233		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
234		Composition: CompositionPoly<P>,
235	{
236		let min_const_suffix = evaluators
237			.iter()
238			.map(SumcheckEvaluator::const_eval_suffix)
239			.min()
240			.unwrap_or(1 << self.n_vars);
241		let max_degree = evaluators
242			.iter()
243			.map(|evaluator| evaluator.composition().degree())
244			.max()
245			.unwrap_or(0);
246
247		let _scope = tracing::debug_span!(
248			"calculate_round_evals",
249			n_vars = self.n_vars,
250			n_multilinears = self.multilinears.len(),
251			n_compositions = evaluators.len(),
252			min_const_suffix,
253			max_degree,
254		)
255		.entered();
256
257		Ok(self.backend.sumcheck_compute_round_evals(
258			self.evaluation_order,
259			self.n_vars,
260			self.tensor_query.as_ref().map(Into::into),
261			&self.multilinears,
262			evaluators,
263			&self.nontrivial_evaluation_points,
264		)?)
265	}
266
267	/// Calculate the batched round coefficients from the domain evaluations.
268	///
269	/// This both performs the polynomial interpolation over the evaluations and the mixing with
270	/// the batching coefficient.
271	#[instrument(skip_all, level = "debug")]
272	pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
273		&mut self,
274		interpolators: &[Interpolator],
275		batch_coeff: F,
276		evals: Vec<RoundEvals<F>>,
277	) -> Result<RoundCoeffs<F>, Error> {
278		let coeffs = match self.last_coeffs_or_sums {
279			ProverStateCoeffsOrSums::Coeffs(_) => {
280				bail!(Error::ExpectedFold);
281			}
282			ProverStateCoeffsOrSums::Sums(ref sums) => {
283				if interpolators.len() != sums.len() {
284					bail!(Error::IncorrectNumberOfEvaluators {
285						expected: sums.len(),
286					});
287				}
288
289				let coeffs = izip!(interpolators, sums, evals)
290					.map(|(evaluator, &sum, RoundEvals(evals))| {
291						let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
292						Ok::<_, Error>(RoundCoeffs(coeffs))
293					})
294					.collect::<Result<Vec<_>, _>>()?;
295				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
296				coeffs
297			}
298		};
299
300		let batched_coeffs = coeffs
301			.into_iter()
302			.zip(powers(batch_coeff))
303			.map(|(coeffs, scalar)| coeffs * scalar)
304			.fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
305
306		Ok(batched_coeffs)
307	}
308}