binius_core/protocols/sumcheck/prove/
prover_state.rs1use binius_field::{Field, PackedExtension, PackedField, util::powers};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6 CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery, evaluate_univariate,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15 polynomial::Error as PolynomialError,
16 protocols::sumcheck::{common::RoundCoeffs, error::Error},
17};
18
19pub trait SumcheckInterpolator<F: Field> {
20 fn round_evals_to_coeffs(
26 &self,
27 last_sum: F,
28 round_evals: Vec<F>,
29 ) -> Result<Vec<F>, PolynomialError>;
30}
31
32#[derive(Debug)]
33enum ProverStateCoeffsOrSums<F: Field> {
34 Coeffs(Vec<RoundCoeffs<F>>),
35 Sums(Vec<F>),
36}
37
38#[derive(Debug, CopyGetters)]
44pub struct ProverState<'a, FDomain, P, M, Backend>
45where
46 FDomain: Field,
47 P: PackedField,
48 M: MultilinearPoly<P> + Send + Sync,
49 Backend: ComputationBackend,
50{
51 #[getset(get_copy = "pub")]
54 n_vars: usize,
55 #[getset(get_copy = "pub")]
56 evaluation_order: EvaluationOrder,
57 multilinears: Vec<SumcheckMultilinear<P, M>>,
58 nontrivial_evaluation_points: Vec<FDomain>,
59 challenges: Vec<P::Scalar>,
60 tensor_query: Option<MultilinearQuery<P>>,
61 last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
62 backend: &'a Backend,
63}
64
65impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
66where
67 FDomain: Field,
68 F: Field,
69 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
70 M: MultilinearPoly<P> + Send + Sync,
71 Backend: ComputationBackend,
72{
73 #[instrument(skip_all, level = "debug", name = "ProverState::new")]
74 pub fn new(
75 evaluation_order: EvaluationOrder,
76 n_vars: usize,
77 multilinears: Vec<SumcheckMultilinear<P, M>>,
78 claimed_sums: Vec<F>,
79 nontrivial_evaluation_points: Vec<FDomain>,
80 backend: &'a Backend,
81 ) -> Result<Self, Error> {
82 for multilinear in &multilinears {
83 match *multilinear {
84 SumcheckMultilinear::Transparent {
85 ref multilinear,
86 const_suffix: (_, suffix_len),
87 ..
88 } => {
89 if multilinear.n_vars() != n_vars {
90 bail!(Error::NumberOfVariablesMismatch);
91 }
92
93 if suffix_len > 1 << n_vars {
94 bail!(Error::IncorrectConstSuffixes);
95 }
96 }
97
98 SumcheckMultilinear::Folded {
99 large_field_folded_evals: ref evals,
100 ..
101 } => {
102 if evals.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
103 bail!(Error::IncorrectConstSuffixes);
104 }
105 }
106 }
107 }
108
109 let tensor_query = multilinears
110 .iter()
111 .filter_map(|multilinear| match multilinear {
112 SumcheckMultilinear::Transparent {
113 switchover_round, ..
114 } => Some(switchover_round),
115 _ => None,
116 })
117 .max()
118 .map(|&max_switchover_round| {
119 MultilinearQuery::with_capacity(max_switchover_round.min(n_vars) + 1)
120 });
121
122 Ok(Self {
123 n_vars,
124 evaluation_order,
125 multilinears,
126 nontrivial_evaluation_points,
127 tensor_query,
128 challenges: Vec::new(),
129 last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
130 backend,
131 })
132 }
133
134 pub fn n_multilinears(&self) -> usize {
135 self.multilinears.len()
136 }
137
138 #[instrument(skip_all, name = "ProverState::fold", level = "debug")]
139 pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
140 if self.n_vars == 0 {
141 bail!(Error::ExpectedFinish);
142 }
143
144 match self.last_coeffs_or_sums {
146 ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
147 let new_sums = round_coeffs
148 .par_iter()
149 .map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
150 .collect();
151 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
152 }
153 ProverStateCoeffsOrSums::Sums(_) => {
154 bail!(Error::ExpectedExecution);
155 }
156 }
157
158 match self.evaluation_order {
160 EvaluationOrder::LowToHigh => self.challenges.push(challenge),
161 EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
162 }
163
164 if let Some(tensor_query) = self.tensor_query.take() {
165 self.tensor_query = match self.evaluation_order {
166 EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
167 EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
171 }
172 }
173
174 let any_transparent_left = self.backend.sumcheck_fold_multilinears(
175 self.evaluation_order,
176 self.n_vars,
177 &mut self.multilinears,
178 challenge,
179 self.tensor_query.as_ref().map(Into::into),
180 )?;
181
182 if !any_transparent_left {
183 self.tensor_query = None;
184 }
185
186 self.n_vars -= 1;
187 Ok(())
188 }
189
190 pub fn finish(self) -> Result<Vec<F>, Error> {
191 match self.last_coeffs_or_sums {
192 ProverStateCoeffsOrSums::Coeffs(_) => {
193 bail!(Error::ExpectedFold);
194 }
195 ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
196 0 => {}
197 _ => bail!(Error::ExpectedExecution),
198 },
199 };
200
201 self.multilinears
202 .into_iter()
203 .map(|multilinear| {
204 match multilinear {
205 SumcheckMultilinear::Transparent {
206 multilinear: inner_multilinear,
207 ..
208 } => {
209 let tensor_query = self.tensor_query.as_ref().expect(
210 "tensor_query is guaranteed to be Some while there is still a transparent multilinear",
211 );
212 inner_multilinear.evaluate(tensor_query.to_ref())
213 }
214 SumcheckMultilinear::Folded {
215 large_field_folded_evals,
216 suffix_eval,
217 } => Ok(large_field_folded_evals
218 .first()
219 .map_or(suffix_eval, |packed| packed.get(0))
220 .get(0)),
221 }
222 .map_err(Error::MathError)
223 })
224 .collect()
225 }
226
227 pub fn calculate_round_evals<Evaluator, Composition>(
229 &self,
230 evaluators: &[Evaluator],
231 ) -> Result<Vec<RoundEvals<F>>, Error>
232 where
233 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
234 Composition: CompositionPoly<P>,
235 {
236 let min_const_suffix = evaluators
237 .iter()
238 .map(SumcheckEvaluator::const_eval_suffix)
239 .min()
240 .unwrap_or(1 << self.n_vars);
241 let max_degree = evaluators
242 .iter()
243 .map(|evaluator| evaluator.composition().degree())
244 .max()
245 .unwrap_or(0);
246
247 let _scope = tracing::debug_span!(
248 "calculate_round_evals",
249 n_vars = self.n_vars,
250 n_multilinears = self.multilinears.len(),
251 n_compositions = evaluators.len(),
252 min_const_suffix,
253 max_degree,
254 )
255 .entered();
256
257 Ok(self.backend.sumcheck_compute_round_evals(
258 self.evaluation_order,
259 self.n_vars,
260 self.tensor_query.as_ref().map(Into::into),
261 &self.multilinears,
262 evaluators,
263 &self.nontrivial_evaluation_points,
264 )?)
265 }
266
267 #[instrument(skip_all, level = "debug")]
272 pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
273 &mut self,
274 interpolators: &[Interpolator],
275 batch_coeff: F,
276 evals: Vec<RoundEvals<F>>,
277 ) -> Result<RoundCoeffs<F>, Error> {
278 let coeffs = match self.last_coeffs_or_sums {
279 ProverStateCoeffsOrSums::Coeffs(_) => {
280 bail!(Error::ExpectedFold);
281 }
282 ProverStateCoeffsOrSums::Sums(ref sums) => {
283 if interpolators.len() != sums.len() {
284 bail!(Error::IncorrectNumberOfEvaluators {
285 expected: sums.len(),
286 });
287 }
288
289 let coeffs = izip!(interpolators, sums, evals)
290 .map(|(evaluator, &sum, RoundEvals(evals))| {
291 let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
292 Ok::<_, Error>(RoundCoeffs(coeffs))
293 })
294 .collect::<Result<Vec<_>, _>>()?;
295 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
296 coeffs
297 }
298 };
299
300 let batched_coeffs = coeffs
301 .into_iter()
302 .zip(powers(batch_coeff))
303 .map(|(coeffs, scalar)| coeffs * scalar)
304 .fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
305
306 Ok(batched_coeffs)
307 }
308}