binius_core/protocols/sumcheck/prove/prover_state.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{util::powers, Field, PackedExtension, PackedField};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6	evaluate_univariate, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15	polynomial::Error as PolynomialError,
16	protocols::sumcheck::{
17		common::{equal_n_vars_check, RoundCoeffs},
18		error::Error,
19	},
20};
21
22pub trait SumcheckInterpolator<F: Field> {
23	/// Given evaluations of the round polynomial, interpolate and return monomial coefficients
24	///
25	/// ## Arguments
26	///
27	/// * `round_evals`: the computed evaluations of the round polynomial
28	fn round_evals_to_coeffs(
29		&self,
30		last_sum: F,
31		round_evals: Vec<F>,
32	) -> Result<Vec<F>, PolynomialError>;
33}
34
35#[derive(Debug)]
36enum ProverStateCoeffsOrSums<F: Field> {
37	Coeffs(Vec<RoundCoeffs<F>>),
38	Sums(Vec<F>),
39}
40
/// A multilinear passed to [`ProverState::new`], bundled with its zero-suffix hint.
pub struct MultilinearInput<M> {
	/// The multilinear polynomial to be reduced by the sumcheck.
	pub multilinear: M,
	/// Length of a suffix of the multilinear's scalars. NOTE(review): presumably these scalars
	/// are known to be zero so the backend can skip them — confirm. [`ProverState::new`]
	/// rejects values larger than `2^n_vars`.
	pub zero_scalars_suffix: usize,
}
45
46/// The stored state of a sumcheck prover, which encapsulates common implementation logic.
47///
48/// We expect that CPU sumcheck provers will internally maintain a [`ProverState`] instance and
49/// customize the sumcheck logic through different [`SumcheckEvaluator`] implementations passed to
50/// the common state object.
51#[derive(Debug, CopyGetters)]
52pub struct ProverState<'a, FDomain, P, M, Backend>
53where
54	FDomain: Field,
55	P: PackedField,
56	M: MultilinearPoly<P> + Send + Sync,
57	Backend: ComputationBackend,
58{
59	/// The number of variables in the folded multilinears. This value decrements each round the
60	/// state is folded.
61	#[getset(get_copy = "pub")]
62	n_vars: usize,
63	#[getset(get_copy = "pub")]
64	evaluation_order: EvaluationOrder,
65	multilinears: Vec<SumcheckMultilinear<P, M>>,
66	nontrivial_evaluation_points: Vec<FDomain>,
67	challenges: Vec<P::Scalar>,
68	tensor_query: Option<MultilinearQuery<P>>,
69	last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
70	backend: &'a Backend,
71}
72
73impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
74where
75	FDomain: Field,
76	F: Field,
77	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
78	M: MultilinearPoly<P> + Send + Sync,
79	Backend: ComputationBackend,
80{
81	#[instrument(skip_all, level = "debug", name = "ProverState::new")]
82	pub fn new(
83		evaluation_order: EvaluationOrder,
84		multilinears: Vec<MultilinearInput<M>>,
85		claimed_sums: Vec<F>,
86		nontrivial_evaluation_points: Vec<FDomain>,
87		switchover_fn: impl Fn(usize) -> usize,
88		backend: &'a Backend,
89	) -> Result<Self, Error> {
90		let n_vars = equal_n_vars_check(multilinears.iter().map(|input| &input.multilinear))?;
91
92		if multilinears
93			.iter()
94			.any(|input| input.zero_scalars_suffix > 1 << n_vars)
95		{
96			bail!(Error::IncorrectZeroScalarsSuffixes);
97		}
98
99		let switchover_rounds = multilinears
100			.iter()
101			.map(|input| switchover_fn(1 << input.multilinear.log_extension_degree()))
102			.collect::<Vec<_>>();
103		let max_switchover_round = switchover_rounds.iter().copied().max().unwrap_or_default();
104
105		let multilinears = izip!(multilinears, switchover_rounds)
106			.map(|(input, switchover_round)| {
107				let MultilinearInput {
108					multilinear,
109					zero_scalars_suffix,
110				} = input;
111				SumcheckMultilinear::Transparent {
112					multilinear,
113					switchover_round,
114					zero_scalars_suffix,
115				}
116			})
117			.collect::<Vec<_>>();
118
119		let tensor_query = MultilinearQuery::with_capacity(max_switchover_round + 1);
120
121		Ok(Self {
122			n_vars,
123			evaluation_order,
124			multilinears,
125			nontrivial_evaluation_points,
126			challenges: Vec::new(),
127			tensor_query: Some(tensor_query),
128			last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
129			backend,
130		})
131	}
132
133	#[instrument(skip_all, name = "ProverState::fold", level = "debug")]
134	pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
135		if self.n_vars == 0 {
136			bail!(Error::ExpectedFinish);
137		}
138
139		// Update the stored multilinear sums.
140		match self.last_coeffs_or_sums {
141			ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
142				let new_sums = round_coeffs
143					.par_iter()
144					.map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
145					.collect();
146				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
147			}
148			ProverStateCoeffsOrSums::Sums(_) => {
149				bail!(Error::ExpectedExecution);
150			}
151		}
152
153		// Update the tensor query.
154		match self.evaluation_order {
155			EvaluationOrder::LowToHigh => self.challenges.push(challenge),
156			EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
157		}
158
159		if let Some(tensor_query) = self.tensor_query.take() {
160			self.tensor_query = match self.evaluation_order {
161				EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
162				// REVIEW: not spending effort to come up with an inplace update method here, as the
163				//         future of switchover is somewhat unclear in light of univariate skip, and
164				//         switchover tensors are small-ish anyway.
165				EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
166			}
167		}
168
169		let any_transparent_left = self.backend.sumcheck_fold_multilinears(
170			self.evaluation_order,
171			self.n_vars,
172			&mut self.multilinears,
173			challenge,
174			self.tensor_query.as_ref().map(Into::into),
175		)?;
176
177		if !any_transparent_left {
178			self.tensor_query = None;
179		}
180
181		self.n_vars -= 1;
182		Ok(())
183	}
184
185	pub fn finish(self) -> Result<Vec<F>, Error> {
186		match self.last_coeffs_or_sums {
187			ProverStateCoeffsOrSums::Coeffs(_) => {
188				bail!(Error::ExpectedFold);
189			}
190			ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
191				0 => {}
192				_ => bail!(Error::ExpectedExecution),
193			},
194		};
195
196		self.multilinears
197			.into_iter()
198			.map(|multilinear| {
199				match multilinear {
200					SumcheckMultilinear::Transparent {
201						multilinear: inner_multilinear,
202						..
203					} => {
204						let tensor_query = self.tensor_query.as_ref()
205							.expect(
206								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
207							);
208						inner_multilinear.evaluate(tensor_query.to_ref())
209					}
210					SumcheckMultilinear::Folded {
211						large_field_folded_evals,
212					} => Ok(large_field_folded_evals
213						.first()
214						.map_or(F::ZERO, |packed| packed.get(0))
215						.get(0)),
216				}
217				.map_err(Error::MathError)
218			})
219			.collect()
220	}
221
222	/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
223	#[instrument(skip_all, level = "debug")]
224	pub fn calculate_round_evals<Evaluator, Composition>(
225		&self,
226		evaluators: &[Evaluator],
227	) -> Result<Vec<RoundEvals<F>>, Error>
228	where
229		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
230		Composition: CompositionPoly<P>,
231	{
232		Ok(self.backend.sumcheck_compute_round_evals(
233			self.evaluation_order,
234			self.n_vars,
235			self.tensor_query.as_ref().map(Into::into),
236			&self.multilinears,
237			evaluators,
238			&self.nontrivial_evaluation_points,
239		)?)
240	}
241
242	/// Calculate the batched round coefficients from the domain evaluations.
243	///
244	/// This both performs the polynomial interpolation over the evaluations and the mixing with
245	/// the batching coefficient.
246	pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
247		&mut self,
248		interpolators: &[Interpolator],
249		batch_coeff: F,
250		evals: Vec<RoundEvals<F>>,
251	) -> Result<RoundCoeffs<F>, Error> {
252		let coeffs = match self.last_coeffs_or_sums {
253			ProverStateCoeffsOrSums::Coeffs(_) => {
254				bail!(Error::ExpectedFold);
255			}
256			ProverStateCoeffsOrSums::Sums(ref sums) => {
257				if interpolators.len() != sums.len() {
258					bail!(Error::IncorrectNumberOfEvaluators {
259						expected: sums.len(),
260					});
261				}
262
263				let coeffs = izip!(interpolators, sums, evals)
264					.map(|(evaluator, &sum, RoundEvals(evals))| {
265						let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
266						Ok::<_, Error>(RoundCoeffs(coeffs))
267					})
268					.collect::<Result<Vec<_>, _>>()?;
269				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
270				coeffs
271			}
272		};
273
274		let batched_coeffs = coeffs
275			.into_iter()
276			.zip(powers(batch_coeff))
277			.map(|(coeffs, scalar)| coeffs * scalar)
278			.fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
279
280		Ok(batched_coeffs)
281	}
282}