1use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator};
7use binius_math::{
8 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain, MultilinearPoly,
9};
10use binius_maybe_rayon::prelude::*;
11use binius_utils::bail;
12use itertools::izip;
13use stackalloc::stackalloc_with_default;
14use tracing::instrument;
15
16use super::{batch_prove::SumcheckProver, prover_state::ProverState};
17use crate::{
18 polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19 protocols::sumcheck::{
20 common::{get_nontrivial_evaluation_points, CompositeSumClaim, RoundCoeffs},
21 error::Error,
22 prove::prover_state::SumcheckInterpolator,
23 },
24};
25
26pub fn validate_witness<'a, F, P, M, Composition>(
27 multilinears: &[M],
28 sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
29) -> Result<(), Error>
30where
31 F: Field,
32 P: PackedField<Scalar = F>,
33 M: MultilinearPoly<P> + Send + Sync,
34 Composition: CompositionPoly<P> + 'a,
35{
36 let n_vars = multilinears
37 .first()
38 .map(|multilinear| multilinear.n_vars())
39 .unwrap_or_default();
40 for multilinear in multilinears {
41 if multilinear.n_vars() != n_vars {
42 bail!(Error::NumberOfVariablesMismatch);
43 }
44 }
45
46 let multilinears = multilinears.iter().collect::<Vec<_>>();
47
48 for (i, claim) in sum_claims.into_iter().enumerate() {
49 let CompositeSumClaim {
50 composition,
51 sum: expected_sum,
52 ..
53 } = claim;
54 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
55 let sum = (0..(1 << n_vars))
56 .into_par_iter()
57 .map(|j| witness.evaluate_on_hypercube(j))
58 .try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
59
60 if sum != expected_sum {
61 bail!(Error::SumcheckNaiveValidationFailure {
62 composition_index: i,
63 });
64 }
65 }
66 Ok(())
67}
68
69pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
70where
71 FDomain: Field,
72 P: PackedField,
73 M: MultilinearPoly<P> + Send + Sync,
74 Backend: ComputationBackend,
75{
76 n_vars: usize,
77 state: ProverState<'a, FDomain, P, M, Backend>,
78 compositions: Vec<Composition>,
79 domains: Vec<InterpolationDomain<FDomain>>,
80}
81
82impl<'a, F, FDomain, P, Composition, M, Backend>
83 RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
84where
85 F: Field,
86 FDomain: Field,
87 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
88 Composition: CompositionPoly<P>,
89 M: MultilinearPoly<P> + Send + Sync,
90 Backend: ComputationBackend,
91{
92 #[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
93 pub fn new(
94 evaluation_order: EvaluationOrder,
95 multilinears: Vec<M>,
96 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
97 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
98 switchover_fn: impl Fn(usize) -> usize,
99 backend: &'a Backend,
100 ) -> Result<Self, Error> {
101 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
102
103 #[cfg(feature = "debug_validate_sumcheck")]
104 {
105 let composite_claims = composite_claims
106 .iter()
107 .map(|x| CompositeSumClaim {
108 sum: x.sum,
109 composition: &x.composition,
110 })
111 .collect::<Vec<_>>();
112 validate_witness(&multilinears, composite_claims)?;
113 }
114
115 for claim in &composite_claims {
116 if claim.composition.n_vars() != multilinears.len() {
117 bail!(Error::InvalidComposition {
118 actual: claim.composition.n_vars(),
119 expected: multilinears.len(),
120 });
121 }
122 }
123
124 let claimed_sums = composite_claims
125 .iter()
126 .map(|composite_claim| composite_claim.sum)
127 .collect();
128
129 let domains = composite_claims
130 .iter()
131 .map(|composite_claim| {
132 let degree = composite_claim.composition.degree();
133 let domain =
134 evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
135 Ok(domain.into())
136 })
137 .collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
138 .map_err(Error::MathError)?;
139
140 let compositions = composite_claims
141 .into_iter()
142 .map(|claim| claim.composition)
143 .collect();
144
145 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
146
147 let state = ProverState::new(
148 evaluation_order,
149 multilinears,
150 claimed_sums,
151 nontrivial_evaluation_points,
152 switchover_fn,
153 backend,
154 )?;
155 let n_vars = state.n_vars();
156
157 Ok(Self {
158 n_vars,
159 state,
160 compositions,
161 domains,
162 })
163 }
164}
165
166impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
167 for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
168where
169 F: TowerField + ExtensionField<FDomain>,
170 FDomain: Field,
171 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
172 Composition: CompositionPoly<P>,
173 M: MultilinearPoly<P> + Send + Sync,
174 Backend: ComputationBackend,
175{
176 fn n_vars(&self) -> usize {
177 self.n_vars
178 }
179
180 fn evaluation_order(&self) -> EvaluationOrder {
181 self.state.evaluation_order()
182 }
183
184 #[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
185 fn fold(&mut self, challenge: F) -> Result<(), Error> {
186 self.state.fold(challenge)?;
187 Ok(())
188 }
189
190 #[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
191 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
192 let evaluators = izip!(&self.compositions, &self.domains)
193 .map(|(composition, interpolation_domain)| {
194 let composition_at_infinity =
195 ArithCircuitPoly::new(composition.expression().leading_term());
196
197 RegularSumcheckEvaluator {
198 composition,
199 composition_at_infinity,
200 interpolation_domain,
201 _marker: PhantomData,
202 }
203 })
204 .collect::<Vec<_>>();
205
206 let evals = self.state.calculate_round_evals(&evaluators)?;
207 self.state
208 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)
209 }
210
211 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
212 self.state.finish()
213 }
214}
215
216struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
217where
218 P: PackedField,
219 FDomain: Field,
220{
221 composition: &'a Composition,
222 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
223 interpolation_domain: &'a InterpolationDomain<FDomain>,
224 _marker: PhantomData<P>,
225}
226
227impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
228 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
229where
230 F: TowerField + ExtensionField<FDomain>,
231 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
232 FDomain: Field,
233 Composition: CompositionPoly<P>,
234{
235 fn eval_point_indices(&self) -> Range<usize> {
236 1..self.composition.degree() + 1
239 }
240
241 fn process_subcube_at_eval_point(
242 &self,
243 _subcube_vars: usize,
244 _subcube_index: usize,
245 is_infinity_point: bool,
246 batch_query: &[&[P]],
247 ) -> P {
248 let row_len = batch_query.first().map_or(0, |row| row.len());
249
250 stackalloc_with_default(row_len, |evals| {
251 if is_infinity_point {
252 self.composition_at_infinity
253 .batch_evaluate(batch_query, evals)
254 .expect("correct by query construction invariant");
255 } else {
256 self.composition
257 .batch_evaluate(batch_query, evals)
258 .expect("correct by query construction invariant");
259 }
260
261 evals.iter().copied().sum()
262 })
263 }
264
265 fn composition(&self) -> &Composition {
266 self.composition
267 }
268
269 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
270 None
271 }
272}
273
274impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
275 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
276where
277 F: Field,
278 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
279 FDomain: Field,
280{
281 fn round_evals_to_coeffs(
282 &self,
283 last_round_sum: F,
284 mut round_evals: Vec<F>,
285 ) -> Result<Vec<F>, PolynomialError> {
286 round_evals.insert(0, last_round_sum - round_evals[0]);
289
290 if round_evals.len() > 3 {
291 let infinity_round_eval = round_evals.remove(2);
296 round_evals.push(infinity_round_eval);
297 }
298
299 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
300 Ok(coeffs)
301 }
302}