binius_core/protocols/sumcheck/prove/
prover_state.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{util::powers, Field, PackedExtension, PackedField};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6	evaluate_univariate, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15	polynomial::Error as PolynomialError,
16	protocols::sumcheck::{common::RoundCoeffs, error::Error},
17};
18
19pub trait SumcheckInterpolator<F: Field> {
20	/// Given evaluations of the round polynomial, interpolate and return monomial coefficients
21	///
22	/// ## Arguments
23	///
24	/// * `round_evals`: the computed evaluations of the round polynomial
25	fn round_evals_to_coeffs(
26		&self,
27		last_sum: F,
28		round_evals: Vec<F>,
29	) -> Result<Vec<F>, PolynomialError>;
30}
31
32#[derive(Debug)]
33enum ProverStateCoeffsOrSums<F: Field> {
34	Coeffs(Vec<RoundCoeffs<F>>),
35	Sums(Vec<F>),
36}
37
38/// The stored state of a sumcheck prover, which encapsulates common implementation logic.
39///
40/// We expect that CPU sumcheck provers will internally maintain a [`ProverState`] instance and
41/// customize the sumcheck logic through different [`SumcheckEvaluator`] implementations passed to
42/// the common state object.
43#[derive(Debug, CopyGetters)]
44pub struct ProverState<'a, FDomain, P, M, Backend>
45where
46	FDomain: Field,
47	P: PackedField,
48	M: MultilinearPoly<P> + Send + Sync,
49	Backend: ComputationBackend,
50{
51	/// The number of variables in the folded multilinears. This value decrements each round the
52	/// state is folded.
53	#[getset(get_copy = "pub")]
54	n_vars: usize,
55	#[getset(get_copy = "pub")]
56	evaluation_order: EvaluationOrder,
57	multilinears: Vec<SumcheckMultilinear<P, M>>,
58	nontrivial_evaluation_points: Vec<FDomain>,
59	challenges: Vec<P::Scalar>,
60	tensor_query: Option<MultilinearQuery<P>>,
61	last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
62	backend: &'a Backend,
63}
64
65impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
66where
67	FDomain: Field,
68	F: Field,
69	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
70	M: MultilinearPoly<P> + Send + Sync,
71	Backend: ComputationBackend,
72{
73	#[instrument(skip_all, level = "debug", name = "ProverState::new")]
74	pub fn new(
75		evaluation_order: EvaluationOrder,
76		n_vars: usize,
77		multilinears: Vec<SumcheckMultilinear<P, M>>,
78		claimed_sums: Vec<F>,
79		nontrivial_evaluation_points: Vec<FDomain>,
80		backend: &'a Backend,
81	) -> Result<Self, Error> {
82		for multilinear in &multilinears {
83			match *multilinear {
84				SumcheckMultilinear::Transparent {
85					ref multilinear,
86					const_suffix: (_, suffix_len),
87					..
88				} => {
89					if multilinear.n_vars() != n_vars {
90						bail!(Error::NumberOfVariablesMismatch);
91					}
92
93					if suffix_len > 1 << n_vars {
94						bail!(Error::IncorrectConstSuffixes);
95					}
96				}
97
98				SumcheckMultilinear::Folded {
99					large_field_folded_evals: ref evals,
100					..
101				} => {
102					if evals.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
103						bail!(Error::IncorrectConstSuffixes);
104					}
105				}
106			}
107		}
108
109		let tensor_query = multilinears
110			.iter()
111			.filter_map(|multilinear| match multilinear {
112				SumcheckMultilinear::Transparent {
113					switchover_round, ..
114				} => Some(switchover_round),
115				_ => None,
116			})
117			.max()
118			.map(|&max_switchover_round| {
119				MultilinearQuery::with_capacity(max_switchover_round.min(n_vars) + 1)
120			});
121
122		Ok(Self {
123			n_vars,
124			evaluation_order,
125			multilinears,
126			nontrivial_evaluation_points,
127			tensor_query,
128			challenges: Vec::new(),
129			last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
130			backend,
131		})
132	}
133
134	#[instrument(skip_all, name = "ProverState::fold", level = "debug")]
135	pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
136		if self.n_vars == 0 {
137			bail!(Error::ExpectedFinish);
138		}
139
140		// Update the stored multilinear sums.
141		match self.last_coeffs_or_sums {
142			ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
143				let new_sums = round_coeffs
144					.par_iter()
145					.map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
146					.collect();
147				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
148			}
149			ProverStateCoeffsOrSums::Sums(_) => {
150				bail!(Error::ExpectedExecution);
151			}
152		}
153
154		// Update the tensor query.
155		match self.evaluation_order {
156			EvaluationOrder::LowToHigh => self.challenges.push(challenge),
157			EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
158		}
159
160		if let Some(tensor_query) = self.tensor_query.take() {
161			self.tensor_query = match self.evaluation_order {
162				EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
163				// REVIEW: not spending effort to come up with an inplace update method here, as the
164				//         future of switchover is somewhat unclear in light of univariate skip, and
165				//         switchover tensors are small-ish anyway.
166				EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
167			}
168		}
169
170		let any_transparent_left = self.backend.sumcheck_fold_multilinears(
171			self.evaluation_order,
172			self.n_vars,
173			&mut self.multilinears,
174			challenge,
175			self.tensor_query.as_ref().map(Into::into),
176		)?;
177
178		if !any_transparent_left {
179			self.tensor_query = None;
180		}
181
182		self.n_vars -= 1;
183		Ok(())
184	}
185
186	pub fn finish(self) -> Result<Vec<F>, Error> {
187		match self.last_coeffs_or_sums {
188			ProverStateCoeffsOrSums::Coeffs(_) => {
189				bail!(Error::ExpectedFold);
190			}
191			ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
192				0 => {}
193				_ => bail!(Error::ExpectedExecution),
194			},
195		};
196
197		self.multilinears
198			.into_iter()
199			.map(|multilinear| {
200				match multilinear {
201					SumcheckMultilinear::Transparent {
202						multilinear: inner_multilinear,
203						..
204					} => {
205						let tensor_query = self.tensor_query.as_ref()
206							.expect(
207								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
208							);
209						inner_multilinear.evaluate(tensor_query.to_ref())
210					}
211					SumcheckMultilinear::Folded {
212						large_field_folded_evals,
213						suffix_eval,
214					} => Ok(large_field_folded_evals
215						.first()
216						.map_or(suffix_eval, |packed| packed.get(0))
217						.get(0)),
218				}
219				.map_err(Error::MathError)
220			})
221			.collect()
222	}
223
224	/// Calculate the accumulated evaluations for an arbitrary sumcheck round.
225	#[instrument(skip_all, level = "debug")]
226	pub fn calculate_round_evals<Evaluator, Composition>(
227		&self,
228		evaluators: &[Evaluator],
229	) -> Result<Vec<RoundEvals<F>>, Error>
230	where
231		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
232		Composition: CompositionPoly<P>,
233	{
234		Ok(self.backend.sumcheck_compute_round_evals(
235			self.evaluation_order,
236			self.n_vars,
237			self.tensor_query.as_ref().map(Into::into),
238			&self.multilinears,
239			evaluators,
240			&self.nontrivial_evaluation_points,
241		)?)
242	}
243
244	/// Calculate the batched round coefficients from the domain evaluations.
245	///
246	/// This both performs the polynomial interpolation over the evaluations and the mixing with
247	/// the batching coefficient.
248	#[instrument(skip_all, level = "debug")]
249	pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
250		&mut self,
251		interpolators: &[Interpolator],
252		batch_coeff: F,
253		evals: Vec<RoundEvals<F>>,
254	) -> Result<RoundCoeffs<F>, Error> {
255		let coeffs = match self.last_coeffs_or_sums {
256			ProverStateCoeffsOrSums::Coeffs(_) => {
257				bail!(Error::ExpectedFold);
258			}
259			ProverStateCoeffsOrSums::Sums(ref sums) => {
260				if interpolators.len() != sums.len() {
261					bail!(Error::IncorrectNumberOfEvaluators {
262						expected: sums.len(),
263					});
264				}
265
266				let coeffs = izip!(interpolators, sums, evals)
267					.map(|(evaluator, &sum, RoundEvals(evals))| {
268						let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
269						Ok::<_, Error>(RoundCoeffs(coeffs))
270					})
271					.collect::<Result<Vec<_>, _>>()?;
272				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
273				coeffs
274			}
275		};
276
277		let batched_coeffs = coeffs
278			.into_iter()
279			.zip(powers(batch_coeff))
280			.map(|(coeffs, scalar)| coeffs * scalar)
281			.fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
282
283		Ok(batched_coeffs)
284	}
285}