binius_core/protocols/sumcheck/prove/
prover_state.rs1use binius_field::{util::powers, Field, PackedExtension, PackedField};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6 evaluate_univariate, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15 polynomial::Error as PolynomialError,
16 protocols::sumcheck::{
17 common::{equal_n_vars_check, RoundCoeffs},
18 error::Error,
19 },
20};
21
22pub trait SumcheckInterpolator<F: Field> {
23 fn round_evals_to_coeffs(
29 &self,
30 last_sum: F,
31 round_evals: Vec<F>,
32 ) -> Result<Vec<F>, PolynomialError>;
33}
34
35#[derive(Debug)]
36enum ProverStateCoeffsOrSums<F: Field> {
37 Coeffs(Vec<RoundCoeffs<F>>),
38 Sums(Vec<F>),
39}
40
/// A multilinear handed to the sumcheck prover, together with suffix metadata.
pub struct MultilinearInput<M> {
    /// The multilinear polynomial itself.
    pub multilinear: M,
    /// Length of a trailing run of the multilinear's scalars that, judging by the name, the
    /// caller asserts are zero (confirm with the backend). The prover state rejects values
    /// greater than `2^n_vars`.
    pub zero_scalars_suffix: usize,
}
45
46#[derive(Debug, CopyGetters)]
52pub struct ProverState<'a, FDomain, P, M, Backend>
53where
54 FDomain: Field,
55 P: PackedField,
56 M: MultilinearPoly<P> + Send + Sync,
57 Backend: ComputationBackend,
58{
59 #[getset(get_copy = "pub")]
62 n_vars: usize,
63 #[getset(get_copy = "pub")]
64 evaluation_order: EvaluationOrder,
65 multilinears: Vec<SumcheckMultilinear<P, M>>,
66 nontrivial_evaluation_points: Vec<FDomain>,
67 challenges: Vec<P::Scalar>,
68 tensor_query: Option<MultilinearQuery<P>>,
69 last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
70 backend: &'a Backend,
71}
72
73impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
74where
75 FDomain: Field,
76 F: Field,
77 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
78 M: MultilinearPoly<P> + Send + Sync,
79 Backend: ComputationBackend,
80{
81 #[instrument(skip_all, level = "debug", name = "ProverState::new")]
82 pub fn new(
83 evaluation_order: EvaluationOrder,
84 multilinears: Vec<MultilinearInput<M>>,
85 claimed_sums: Vec<F>,
86 nontrivial_evaluation_points: Vec<FDomain>,
87 switchover_fn: impl Fn(usize) -> usize,
88 backend: &'a Backend,
89 ) -> Result<Self, Error> {
90 let n_vars = equal_n_vars_check(multilinears.iter().map(|input| &input.multilinear))?;
91
92 if multilinears
93 .iter()
94 .any(|input| input.zero_scalars_suffix > 1 << n_vars)
95 {
96 bail!(Error::IncorrectZeroScalarsSuffixes);
97 }
98
99 let switchover_rounds = multilinears
100 .iter()
101 .map(|input| switchover_fn(1 << input.multilinear.log_extension_degree()))
102 .collect::<Vec<_>>();
103 let max_switchover_round = switchover_rounds.iter().copied().max().unwrap_or_default();
104
105 let multilinears = izip!(multilinears, switchover_rounds)
106 .map(|(input, switchover_round)| {
107 let MultilinearInput {
108 multilinear,
109 zero_scalars_suffix,
110 } = input;
111 SumcheckMultilinear::Transparent {
112 multilinear,
113 switchover_round,
114 zero_scalars_suffix,
115 }
116 })
117 .collect::<Vec<_>>();
118
119 let tensor_query = MultilinearQuery::with_capacity(max_switchover_round + 1);
120
121 Ok(Self {
122 n_vars,
123 evaluation_order,
124 multilinears,
125 nontrivial_evaluation_points,
126 challenges: Vec::new(),
127 tensor_query: Some(tensor_query),
128 last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
129 backend,
130 })
131 }
132
133 #[instrument(skip_all, name = "ProverState::fold", level = "debug")]
134 pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
135 if self.n_vars == 0 {
136 bail!(Error::ExpectedFinish);
137 }
138
139 match self.last_coeffs_or_sums {
141 ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
142 let new_sums = round_coeffs
143 .par_iter()
144 .map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
145 .collect();
146 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
147 }
148 ProverStateCoeffsOrSums::Sums(_) => {
149 bail!(Error::ExpectedExecution);
150 }
151 }
152
153 match self.evaluation_order {
155 EvaluationOrder::LowToHigh => self.challenges.push(challenge),
156 EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
157 }
158
159 if let Some(tensor_query) = self.tensor_query.take() {
160 self.tensor_query = match self.evaluation_order {
161 EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
162 EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
166 }
167 }
168
169 let any_transparent_left = self.backend.sumcheck_fold_multilinears(
170 self.evaluation_order,
171 self.n_vars,
172 &mut self.multilinears,
173 challenge,
174 self.tensor_query.as_ref().map(Into::into),
175 )?;
176
177 if !any_transparent_left {
178 self.tensor_query = None;
179 }
180
181 self.n_vars -= 1;
182 Ok(())
183 }
184
185 pub fn finish(self) -> Result<Vec<F>, Error> {
186 match self.last_coeffs_or_sums {
187 ProverStateCoeffsOrSums::Coeffs(_) => {
188 bail!(Error::ExpectedFold);
189 }
190 ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
191 0 => {}
192 _ => bail!(Error::ExpectedExecution),
193 },
194 };
195
196 self.multilinears
197 .into_iter()
198 .map(|multilinear| {
199 match multilinear {
200 SumcheckMultilinear::Transparent {
201 multilinear: inner_multilinear,
202 ..
203 } => {
204 let tensor_query = self.tensor_query.as_ref()
205 .expect(
206 "tensor_query is guaranteed to be Some while there is still a transparent multilinear"
207 );
208 inner_multilinear.evaluate(tensor_query.to_ref())
209 }
210 SumcheckMultilinear::Folded {
211 large_field_folded_evals,
212 } => Ok(large_field_folded_evals
213 .first()
214 .map_or(F::ZERO, |packed| packed.get(0))
215 .get(0)),
216 }
217 .map_err(Error::MathError)
218 })
219 .collect()
220 }
221
222 #[instrument(skip_all, level = "debug")]
224 pub fn calculate_round_evals<Evaluator, Composition>(
225 &self,
226 evaluators: &[Evaluator],
227 ) -> Result<Vec<RoundEvals<F>>, Error>
228 where
229 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
230 Composition: CompositionPoly<P>,
231 {
232 Ok(self.backend.sumcheck_compute_round_evals(
233 self.evaluation_order,
234 self.n_vars,
235 self.tensor_query.as_ref().map(Into::into),
236 &self.multilinears,
237 evaluators,
238 &self.nontrivial_evaluation_points,
239 )?)
240 }
241
242 pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
247 &mut self,
248 interpolators: &[Interpolator],
249 batch_coeff: F,
250 evals: Vec<RoundEvals<F>>,
251 ) -> Result<RoundCoeffs<F>, Error> {
252 let coeffs = match self.last_coeffs_or_sums {
253 ProverStateCoeffsOrSums::Coeffs(_) => {
254 bail!(Error::ExpectedFold);
255 }
256 ProverStateCoeffsOrSums::Sums(ref sums) => {
257 if interpolators.len() != sums.len() {
258 bail!(Error::IncorrectNumberOfEvaluators {
259 expected: sums.len(),
260 });
261 }
262
263 let coeffs = izip!(interpolators, sums, evals)
264 .map(|(evaluator, &sum, RoundEvals(evals))| {
265 let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
266 Ok::<_, Error>(RoundCoeffs(coeffs))
267 })
268 .collect::<Result<Vec<_>, _>>()?;
269 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
270 coeffs
271 }
272 };
273
274 let batched_coeffs = coeffs
275 .into_iter()
276 .zip(powers(batch_coeff))
277 .map(|(coeffs, scalar)| coeffs * scalar)
278 .fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
279
280 Ok(batched_coeffs)
281 }
282}