binius_core/protocols/sumcheck/prove/
prover_state.rs1use binius_field::{util::powers, Field, PackedExtension, PackedField};
4use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
5use binius_math::{
6 evaluate_univariate, CompositionPoly, EvaluationOrder, MultilinearPoly, MultilinearQuery,
7};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::bail;
10use getset::CopyGetters;
11use itertools::izip;
12use tracing::instrument;
13
14use crate::{
15 polynomial::Error as PolynomialError,
16 protocols::sumcheck::{common::RoundCoeffs, error::Error},
17};
18
19pub trait SumcheckInterpolator<F: Field> {
20 fn round_evals_to_coeffs(
26 &self,
27 last_sum: F,
28 round_evals: Vec<F>,
29 ) -> Result<Vec<F>, PolynomialError>;
30}
31
32#[derive(Debug)]
33enum ProverStateCoeffsOrSums<F: Field> {
34 Coeffs(Vec<RoundCoeffs<F>>),
35 Sums(Vec<F>),
36}
37
38#[derive(Debug, CopyGetters)]
44pub struct ProverState<'a, FDomain, P, M, Backend>
45where
46 FDomain: Field,
47 P: PackedField,
48 M: MultilinearPoly<P> + Send + Sync,
49 Backend: ComputationBackend,
50{
51 #[getset(get_copy = "pub")]
54 n_vars: usize,
55 #[getset(get_copy = "pub")]
56 evaluation_order: EvaluationOrder,
57 multilinears: Vec<SumcheckMultilinear<P, M>>,
58 nontrivial_evaluation_points: Vec<FDomain>,
59 challenges: Vec<P::Scalar>,
60 tensor_query: Option<MultilinearQuery<P>>,
61 last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
62 backend: &'a Backend,
63}
64
65impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
66where
67 FDomain: Field,
68 F: Field,
69 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
70 M: MultilinearPoly<P> + Send + Sync,
71 Backend: ComputationBackend,
72{
73 #[instrument(skip_all, level = "debug", name = "ProverState::new")]
74 pub fn new(
75 evaluation_order: EvaluationOrder,
76 n_vars: usize,
77 multilinears: Vec<SumcheckMultilinear<P, M>>,
78 claimed_sums: Vec<F>,
79 nontrivial_evaluation_points: Vec<FDomain>,
80 backend: &'a Backend,
81 ) -> Result<Self, Error> {
82 for multilinear in &multilinears {
83 match *multilinear {
84 SumcheckMultilinear::Transparent {
85 ref multilinear,
86 const_suffix: (_, suffix_len),
87 ..
88 } => {
89 if multilinear.n_vars() != n_vars {
90 bail!(Error::NumberOfVariablesMismatch);
91 }
92
93 if suffix_len > 1 << n_vars {
94 bail!(Error::IncorrectConstSuffixes);
95 }
96 }
97
98 SumcheckMultilinear::Folded {
99 large_field_folded_evals: ref evals,
100 ..
101 } => {
102 if evals.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
103 bail!(Error::IncorrectConstSuffixes);
104 }
105 }
106 }
107 }
108
109 let tensor_query = multilinears
110 .iter()
111 .filter_map(|multilinear| match multilinear {
112 SumcheckMultilinear::Transparent {
113 switchover_round, ..
114 } => Some(switchover_round),
115 _ => None,
116 })
117 .max()
118 .map(|&max_switchover_round| {
119 MultilinearQuery::with_capacity(max_switchover_round.min(n_vars) + 1)
120 });
121
122 Ok(Self {
123 n_vars,
124 evaluation_order,
125 multilinears,
126 nontrivial_evaluation_points,
127 tensor_query,
128 challenges: Vec::new(),
129 last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
130 backend,
131 })
132 }
133
134 #[instrument(skip_all, name = "ProverState::fold", level = "debug")]
135 pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
136 if self.n_vars == 0 {
137 bail!(Error::ExpectedFinish);
138 }
139
140 match self.last_coeffs_or_sums {
142 ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
143 let new_sums = round_coeffs
144 .par_iter()
145 .map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
146 .collect();
147 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
148 }
149 ProverStateCoeffsOrSums::Sums(_) => {
150 bail!(Error::ExpectedExecution);
151 }
152 }
153
154 match self.evaluation_order {
156 EvaluationOrder::LowToHigh => self.challenges.push(challenge),
157 EvaluationOrder::HighToLow => self.challenges.insert(0, challenge),
158 }
159
160 if let Some(tensor_query) = self.tensor_query.take() {
161 self.tensor_query = match self.evaluation_order {
162 EvaluationOrder::LowToHigh => Some(tensor_query.update(&[challenge])?),
163 EvaluationOrder::HighToLow => Some(MultilinearQuery::expand(&self.challenges)),
167 }
168 }
169
170 let any_transparent_left = self.backend.sumcheck_fold_multilinears(
171 self.evaluation_order,
172 self.n_vars,
173 &mut self.multilinears,
174 challenge,
175 self.tensor_query.as_ref().map(Into::into),
176 )?;
177
178 if !any_transparent_left {
179 self.tensor_query = None;
180 }
181
182 self.n_vars -= 1;
183 Ok(())
184 }
185
186 pub fn finish(self) -> Result<Vec<F>, Error> {
187 match self.last_coeffs_or_sums {
188 ProverStateCoeffsOrSums::Coeffs(_) => {
189 bail!(Error::ExpectedFold);
190 }
191 ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
192 0 => {}
193 _ => bail!(Error::ExpectedExecution),
194 },
195 };
196
197 self.multilinears
198 .into_iter()
199 .map(|multilinear| {
200 match multilinear {
201 SumcheckMultilinear::Transparent {
202 multilinear: inner_multilinear,
203 ..
204 } => {
205 let tensor_query = self.tensor_query.as_ref()
206 .expect(
207 "tensor_query is guaranteed to be Some while there is still a transparent multilinear"
208 );
209 inner_multilinear.evaluate(tensor_query.to_ref())
210 }
211 SumcheckMultilinear::Folded {
212 large_field_folded_evals,
213 suffix_eval,
214 } => Ok(large_field_folded_evals
215 .first()
216 .map_or(suffix_eval, |packed| packed.get(0))
217 .get(0)),
218 }
219 .map_err(Error::MathError)
220 })
221 .collect()
222 }
223
224 #[instrument(skip_all, level = "debug")]
226 pub fn calculate_round_evals<Evaluator, Composition>(
227 &self,
228 evaluators: &[Evaluator],
229 ) -> Result<Vec<RoundEvals<F>>, Error>
230 where
231 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
232 Composition: CompositionPoly<P>,
233 {
234 Ok(self.backend.sumcheck_compute_round_evals(
235 self.evaluation_order,
236 self.n_vars,
237 self.tensor_query.as_ref().map(Into::into),
238 &self.multilinears,
239 evaluators,
240 &self.nontrivial_evaluation_points,
241 )?)
242 }
243
244 #[instrument(skip_all, level = "debug")]
249 pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
250 &mut self,
251 interpolators: &[Interpolator],
252 batch_coeff: F,
253 evals: Vec<RoundEvals<F>>,
254 ) -> Result<RoundCoeffs<F>, Error> {
255 let coeffs = match self.last_coeffs_or_sums {
256 ProverStateCoeffsOrSums::Coeffs(_) => {
257 bail!(Error::ExpectedFold);
258 }
259 ProverStateCoeffsOrSums::Sums(ref sums) => {
260 if interpolators.len() != sums.len() {
261 bail!(Error::IncorrectNumberOfEvaluators {
262 expected: sums.len(),
263 });
264 }
265
266 let coeffs = izip!(interpolators, sums, evals)
267 .map(|(evaluator, &sum, RoundEvals(evals))| {
268 let coeffs = evaluator.round_evals_to_coeffs(sum, evals)?;
269 Ok::<_, Error>(RoundCoeffs(coeffs))
270 })
271 .collect::<Result<Vec<_>, _>>()?;
272 self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
273 coeffs
274 }
275 };
276
277 let batched_coeffs = coeffs
278 .into_iter()
279 .zip(powers(batch_coeff))
280 .map(|(coeffs, scalar)| coeffs * scalar)
281 .fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);
282
283 Ok(batched_coeffs)
284 }
285}