binius_core/protocols/sumcheck/prove/
prover_state.rs1use std::{
4 iter,
5 sync::atomic::{AtomicBool, Ordering},
6};
7
8use binius_field::{util::powers, Field, PackedExtension, PackedField};
9use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
10use binius_math::{
11 evaluate_univariate, fold_left_lerp_inplace, fold_right_lerp, CompositionPoly, EvaluationOrder,
12 MultilinearPoly, MultilinearQuery,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_utils::bail;
16use bytemuck::zeroed_vec;
17use getset::CopyGetters;
18use itertools::izip;
19use tracing::instrument;
20
21use crate::{
22 polynomial::Error as PolynomialError,
23 protocols::sumcheck::{
24 common::{determine_switchovers, equal_n_vars_check, RoundCoeffs},
25 error::Error,
26 },
27};
28
29pub trait SumcheckInterpolator<F: Field> {
30 fn round_evals_to_coeffs(
36 &self,
37 last_sum: F,
38 round_evals: Vec<F>,
39 ) -> Result<Vec<F>, PolynomialError>;
40}
41
42#[derive(Debug)]
43enum ProverStateCoeffsOrSums<F: Field> {
44 Coeffs(Vec<RoundCoeffs<F>>),
45 Sums(Vec<F>),
46}
47
48#[derive(Debug, CopyGetters)]
54pub struct ProverState<'a, FDomain, P, M, Backend>
55where
56 FDomain: Field,
57 P: PackedField,
58 M: MultilinearPoly<P> + Send + Sync,
59 Backend: ComputationBackend,
60{
61 #[getset(get_copy = "pub")]
64 n_vars: usize,
65 #[getset(get_copy = "pub")]
66 evaluation_order: EvaluationOrder,
67 multilinears: Vec<SumcheckMultilinear<P, M>>,
68 nontrivial_evaluation_points: Vec<FDomain>,
69 challenges: Vec<P::Scalar>,
70 tensor_query: Option<MultilinearQuery<P>>,
71 last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
72 backend: &'a Backend,
73}
74
75impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
76where
77 FDomain: Field,
78 F: Field,
79 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
80 M: MultilinearPoly<P> + Send + Sync,
81 Backend: ComputationBackend,
82{
83 pub fn new(
84 evaluation_order: EvaluationOrder,
85 multilinears: Vec<M>,
86 claimed_sums: Vec<F>,
87 nontrivial_evaluation_points: Vec<FDomain>,
88 switchover_fn: impl Fn(usize) -> usize,
89 backend: &'a Backend,
90 ) -> Result<Self, Error> {
91 let switchover_rounds = determine_switchovers(&multilinears, switchover_fn);
92 Self::new_with_switchover_rounds(
93 evaluation_order,
94 multilinears,
95 &switchover_rounds,
96 claimed_sums,
97 nontrivial_evaluation_points,
98 backend,
99 )
100 }
101
102 #[instrument(
103 skip_all,
104 level = "debug",
105 name = "ProverState::new_with_switchover_rounds"
106 )]
107 pub fn new_with_switchover_rounds(
108 evaluation_order: EvaluationOrder,
109 multilinears: Vec<M>,
110 switchover_rounds: &[usize],
111 claimed_sums: Vec<F>,
112 nontrivial_evaluation_points: Vec<FDomain>,
113 backend: &'a Backend,
114 ) -> Result<Self, Error> {
115 let n_vars = equal_n_vars_check(&multilinears)?;
116
117 if multilinears.len() != switchover_rounds.len() {
118 bail!(Error::MultilinearSwitchoverSizeMismatch);
119 }
120
121 let max_switchover_round = switchover_rounds.iter().copied().max().unwrap_or_default();
122
123 let multilinears = iter::zip(multilinears, switchover_rounds)
124 .map(|(multilinear, &switchover_round)| SumcheckMultilinear::Transparent {
125 multilinear,
126 switchover_round,
127 })
128 .collect();
129
130 let tensor_query = MultilinearQuery::with_capacity(max_switchover_round + 1);
131
132 Ok(Self {
133 n_vars,
134 evaluation_order,
135 multilinears,
136 nontrivial_evaluation_points,
137 challenges: Vec::new(),
138 tensor_query: Some(tensor_query),
139 last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
140 backend,
141 })
142 }
143
144 #[instrument(skip_all, name = "ProverState::fold", level = "debug")]
145 pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
146 if self.n_vars == 0 {
147 bail!(Error::ExpectedFinish);
148 }
149
150 match self.last_coeffs_or_sums {
152 ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
153 let new_sums = round_coeffs
154 .par_iter()
155 .map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
156 .collect();
157 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
158 }
159 ProverStateCoeffsOrSums::Sums(_) => {
160 bail!(Error::ExpectedExecution);
161 }
162 }
163
164 match self.evaluation_order {
166 EvaluationOrder::LowToHigh => self.challenges.push(challenge),
167 EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
168 }
169
170 if let Some(tensor_query) = self.tensor_query.take() {
171 self.tensor_query = match self.evaluation_order {
172 EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
173 EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
177 }
178 }
179
180 let any_transparent_left = AtomicBool::new(false);
184 self.multilinears
185 .par_iter_mut()
186 .try_for_each(|multilinear| {
187 match multilinear {
188 SumcheckMultilinear::Transparent {
189 multilinear: inner_multilinear,
190 ref mut switchover_round,
191 } => {
192 if *switchover_round == 0 {
193 let tensor_query = self.tensor_query.as_ref()
194 .expect(
195 "tensor_query is guaranteed to be Some while there is still a transparent multilinear"
196 );
197
198 let large_field_folded_evals = match self.evaluation_order {
200 EvaluationOrder::LowToHigh => inner_multilinear
201 .evaluate_partial_low(tensor_query.to_ref())?
202 .into_evals(),
203 EvaluationOrder::HighToLow => inner_multilinear
204 .evaluate_partial_high(tensor_query.to_ref())?
205 .into_evals(),
206 };
207
208 *multilinear = SumcheckMultilinear::Folded {
209 large_field_folded_evals,
210 };
211 } else {
212 *switchover_round -= 1;
213 any_transparent_left.store(true, Ordering::Relaxed);
214 }
215 }
216 SumcheckMultilinear::Folded {
217 ref mut large_field_folded_evals,
218 } => {
219 match self.evaluation_order {
222 EvaluationOrder::LowToHigh => {
225 let mut new_large_field_folded_evals =
226 zeroed_vec(1 << self.n_vars.saturating_sub(1 + P::LOG_WIDTH));
227
228 fold_right_lerp(
229 &*large_field_folded_evals,
230 self.n_vars,
231 challenge,
232 &mut new_large_field_folded_evals,
233 )?;
234
235 *large_field_folded_evals = new_large_field_folded_evals;
236 }
237
238 EvaluationOrder::HighToLow => {
240 fold_left_lerp_inplace(
243 large_field_folded_evals,
244 self.n_vars,
245 challenge,
246 )?;
247 }
248 }
249 }
250 };
251 Ok::<(), Error>(())
252 })?;
253
254 if !any_transparent_left.load(Ordering::Relaxed) {
255 self.tensor_query = None;
256 }
257
258 self.n_vars -= 1;
259 Ok(())
260 }
261
262 pub fn finish(self) -> Result<Vec<F>, Error> {
263 match self.last_coeffs_or_sums {
264 ProverStateCoeffsOrSums::Coeffs(_) => {
265 bail!(Error::ExpectedFold);
266 }
267 ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
268 0 => {}
269 _ => bail!(Error::ExpectedExecution),
270 },
271 };
272
273 self.multilinears
274 .into_iter()
275 .map(|multilinear| {
276 match multilinear {
277 SumcheckMultilinear::Transparent {
278 multilinear: inner_multilinear,
279 ..
280 } => {
281 let tensor_query = self.tensor_query.as_ref()
282 .expect(
283 "tensor_query is guaranteed to be Some while there is still a transparent multilinear"
284 );
285 inner_multilinear.evaluate(tensor_query.to_ref())
286 }
287 SumcheckMultilinear::Folded {
288 large_field_folded_evals,
289 } => Ok(large_field_folded_evals
290 .first()
291 .expect("exactly one packed field element left after folding")
292 .get(0)),
293 }
294 .map_err(Error::MathError)
295 })
296 .collect()
297 }
298
299 #[instrument(skip_all, level = "debug")]
301 pub fn calculate_round_evals<Evaluator, Composition>(
302 &self,
303 evaluators: &[Evaluator],
304 ) -> Result<Vec<RoundEvals<F>>, Error>
305 where
306 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
307 Composition: CompositionPoly<P>,
308 {
309 Ok(self.backend.sumcheck_compute_round_evals(
310 self.evaluation_order,
311 self.n_vars,
312 self.tensor_query.as_ref().map(Into::into),
313 &self.multilinears,
314 evaluators,
315 &self.nontrivial_evaluation_points,
316 )?)
317 }
318
319 pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
324 &mut self,
325 interpolators: &[Interpolator],
326 batch_coeff: F,
327 evals: Vec<RoundEvals<F>>,
328 ) -> Result<RoundCoeffs<F>, Error> {
329 let coeffs = match self.last_coeffs_or_sums {
330 ProverStateCoeffsOrSums::Coeffs(_) => {
331 bail!(Error::ExpectedFold);
332 }
333 ProverStateCoeffsOrSums::Sums(ref sums) => {
334 if interpolators.len() != sums.len() {
335 bail!(Error::IncorrectNumberOfEvaluators {
336 expected: sums.len(),
337 });
338 }
339
340 let coeffs = izip!(interpolators, sums, evals)
341 .map(|(evaluator, &sum, RoundEvals(evals))| {
342 let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
343 Ok::<_, Error>(RoundCoeffs(coeffs))
344 })
345 .collect::<Result<Vec<_>, _>>()?;
346 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
347 coeffs
348 }
349 };
350
351 let batched_coeffs = coeffs
352 .into_iter()
353 .zip(powers(batch_coeff))
354 .map(|(coeffs, scalar)| coeffs * scalar)
355 .fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
356
357 Ok(batched_coeffs)
358 }
359}