1use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator, SumcheckMultilinear};
7use binius_math::{
8 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
9 MultilinearPoly, RowsBatchRef,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::bail;
13use itertools::izip;
14use stackalloc::stackalloc_with_default;
15use tracing::instrument;
16
17use crate::{
18 polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19 protocols::sumcheck::{
20 common::{
21 equal_n_vars_check, get_nontrivial_evaluation_points,
22 interpolation_domains_for_composition_degrees, CompositeSumClaim, RoundCoeffs,
23 },
24 error::Error,
25 prove::{ProverState, SumcheckInterpolator, SumcheckProver},
26 },
27};
28
29pub fn validate_witness<'a, F, P, M, Composition>(
30 multilinears: &[M],
31 sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
32) -> Result<(), Error>
33where
34 F: Field,
35 P: PackedField<Scalar = F>,
36 M: MultilinearPoly<P> + Send + Sync,
37 Composition: CompositionPoly<P> + 'a,
38{
39 let n_vars = multilinears
40 .first()
41 .map(|multilinear| multilinear.n_vars())
42 .unwrap_or_default();
43 for multilinear in multilinears {
44 if multilinear.n_vars() != n_vars {
45 bail!(Error::NumberOfVariablesMismatch);
46 }
47 }
48
49 let multilinears = multilinears.iter().collect::<Vec<_>>();
50
51 for (i, claim) in sum_claims.into_iter().enumerate() {
52 let CompositeSumClaim {
53 composition,
54 sum: expected_sum,
55 ..
56 } = claim;
57 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
58 let sum = (0..(1 << n_vars))
59 .into_par_iter()
60 .map(|j| witness.evaluate_on_hypercube(j))
61 .try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
62
63 if sum != expected_sum {
64 bail!(Error::SumcheckNaiveValidationFailure {
65 composition_index: i,
66 });
67 }
68 }
69 Ok(())
70}
71
72pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
73where
74 FDomain: Field,
75 P: PackedField,
76 M: MultilinearPoly<P> + Send + Sync,
77 Backend: ComputationBackend,
78{
79 n_vars: usize,
80 state: ProverState<'a, FDomain, P, M, Backend>,
81 compositions: Vec<Composition>,
82 domains: Vec<InterpolationDomain<FDomain>>,
83}
84
85impl<'a, F, FDomain, P, Composition, M, Backend>
86 RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
87where
88 F: Field,
89 FDomain: Field,
90 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
91 Composition: CompositionPoly<P>,
92 M: MultilinearPoly<P> + Send + Sync,
93 Backend: ComputationBackend,
94{
95 #[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
96 pub fn new(
97 evaluation_order: EvaluationOrder,
98 multilinears: Vec<M>,
99 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
100 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
101 switchover_fn: impl Fn(usize) -> usize,
102 backend: &'a Backend,
103 ) -> Result<Self, Error> {
104 let n_vars = equal_n_vars_check(&multilinears)?;
105 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
106
107 #[cfg(feature = "debug_validate_sumcheck")]
108 {
109 let composite_claims = composite_claims
110 .iter()
111 .map(|x| CompositeSumClaim {
112 sum: x.sum,
113 composition: &x.composition,
114 })
115 .collect::<Vec<_>>();
116 validate_witness(&multilinears, composite_claims)?;
117 }
118
119 for claim in &composite_claims {
120 if claim.composition.n_vars() != multilinears.len() {
121 bail!(Error::InvalidComposition {
122 actual: claim.composition.n_vars(),
123 expected: multilinears.len(),
124 });
125 }
126 }
127
128 let claimed_sums = composite_claims
129 .iter()
130 .map(|composite_claim| composite_claim.sum)
131 .collect();
132
133 let domains = interpolation_domains_for_composition_degrees(
134 evaluation_domain_factory,
135 composite_claims
136 .iter()
137 .map(|composite_claim| composite_claim.composition.degree()),
138 )?;
139
140 let compositions = composite_claims
141 .into_iter()
142 .map(|claim| claim.composition)
143 .collect();
144
145 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
146
147 let multilinears = multilinears
148 .into_iter()
149 .map(|multilinear| SumcheckMultilinear::transparent(multilinear, &switchover_fn))
150 .collect();
151
152 let state = ProverState::new(
153 evaluation_order,
154 n_vars,
155 multilinears,
156 claimed_sums,
157 nontrivial_evaluation_points,
158 backend,
159 )?;
160
161 Ok(Self {
162 n_vars,
163 state,
164 compositions,
165 domains,
166 })
167 }
168}
169
170impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
171 for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
172where
173 F: TowerField + ExtensionField<FDomain>,
174 FDomain: Field,
175 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
176 Composition: CompositionPoly<P>,
177 M: MultilinearPoly<P> + Send + Sync,
178 Backend: ComputationBackend,
179{
180 fn n_vars(&self) -> usize {
181 self.n_vars
182 }
183
184 fn evaluation_order(&self) -> EvaluationOrder {
185 self.state.evaluation_order()
186 }
187
188 #[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
189 fn fold(&mut self, challenge: F) -> Result<(), Error> {
190 self.state.fold(challenge)?;
191 Ok(())
192 }
193
194 #[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
195 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
196 let evaluators = izip!(&self.compositions, &self.domains)
197 .map(|(composition, interpolation_domain)| {
198 let composition_at_infinity =
199 ArithCircuitPoly::new(composition.expression().leading_term());
200
201 RegularSumcheckEvaluator {
202 composition,
203 composition_at_infinity,
204 interpolation_domain,
205 _marker: PhantomData,
206 }
207 })
208 .collect::<Vec<_>>();
209
210 let round_evals = self.state.calculate_round_evals(&evaluators)?;
211 self.state
212 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
213 }
214
215 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
216 self.state.finish()
217 }
218}
219
220struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
221where
222 P: PackedField,
223 FDomain: Field,
224{
225 composition: &'a Composition,
226 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
227 interpolation_domain: &'a InterpolationDomain<FDomain>,
228 _marker: PhantomData<P>,
229}
230
231impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
232 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
233where
234 F: TowerField + ExtensionField<FDomain>,
235 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
236 FDomain: Field,
237 Composition: CompositionPoly<P>,
238{
239 fn eval_point_indices(&self) -> Range<usize> {
240 1..self.composition.degree() + 1
243 }
244
245 fn process_subcube_at_eval_point(
246 &self,
247 _subcube_vars: usize,
248 _subcube_index: usize,
249 is_infinity_point: bool,
250 batch_query: &RowsBatchRef<P>,
251 ) -> P {
252 let row_len = batch_query.row_len();
253
254 stackalloc_with_default(row_len, |evals| {
255 if is_infinity_point {
256 self.composition_at_infinity
257 .batch_evaluate(batch_query, evals)
258 .expect("correct by query construction invariant");
259 } else {
260 self.composition
261 .batch_evaluate(batch_query, evals)
262 .expect("correct by query construction invariant");
263 }
264
265 evals.iter().copied().sum()
266 })
267 }
268
269 fn composition(&self) -> &Composition {
270 self.composition
271 }
272
273 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
274 None
275 }
276}
277
278impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
279 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
280where
281 F: Field,
282 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
283 FDomain: Field,
284{
285 fn round_evals_to_coeffs(
286 &self,
287 last_round_sum: F,
288 mut round_evals: Vec<F>,
289 ) -> Result<Vec<F>, PolynomialError> {
290 round_evals.insert(0, last_round_sum - round_evals[0]);
293
294 if round_evals.len() > 3 {
295 let infinity_round_eval = round_evals.remove(2);
300 round_evals.push(infinity_round_eval);
301 }
302
303 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
304 Ok(coeffs)
305 }
306}