1use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator};
7use binius_math::{
8 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
9 MultilinearPoly, RowsBatchRef,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::bail;
13use itertools::izip;
14use stackalloc::stackalloc_with_default;
15use tracing::instrument;
16
17use crate::{
18 polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19 protocols::sumcheck::{
20 common::{
21 get_nontrivial_evaluation_points, interpolation_domains_for_composition_degrees,
22 CompositeSumClaim, RoundCoeffs,
23 },
24 error::Error,
25 prove::{MultilinearInput, ProverState, SumcheckInterpolator, SumcheckProver},
26 },
27};
28
29pub fn validate_witness<'a, F, P, M, Composition>(
30 multilinears: &[M],
31 sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
32) -> Result<(), Error>
33where
34 F: Field,
35 P: PackedField<Scalar = F>,
36 M: MultilinearPoly<P> + Send + Sync,
37 Composition: CompositionPoly<P> + 'a,
38{
39 let n_vars = multilinears
40 .first()
41 .map(|multilinear| multilinear.n_vars())
42 .unwrap_or_default();
43 for multilinear in multilinears {
44 if multilinear.n_vars() != n_vars {
45 bail!(Error::NumberOfVariablesMismatch);
46 }
47 }
48
49 let multilinears = multilinears.iter().collect::<Vec<_>>();
50
51 for (i, claim) in sum_claims.into_iter().enumerate() {
52 let CompositeSumClaim {
53 composition,
54 sum: expected_sum,
55 ..
56 } = claim;
57 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
58 let sum = (0..(1 << n_vars))
59 .into_par_iter()
60 .map(|j| witness.evaluate_on_hypercube(j))
61 .try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
62
63 if sum != expected_sum {
64 bail!(Error::SumcheckNaiveValidationFailure {
65 composition_index: i,
66 });
67 }
68 }
69 Ok(())
70}
71
72pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
73where
74 FDomain: Field,
75 P: PackedField,
76 M: MultilinearPoly<P> + Send + Sync,
77 Backend: ComputationBackend,
78{
79 n_vars: usize,
80 state: ProverState<'a, FDomain, P, M, Backend>,
81 compositions: Vec<Composition>,
82 domains: Vec<InterpolationDomain<FDomain>>,
83}
84
85impl<'a, F, FDomain, P, Composition, M, Backend>
86 RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
87where
88 F: Field,
89 FDomain: Field,
90 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
91 Composition: CompositionPoly<P>,
92 M: MultilinearPoly<P> + Send + Sync,
93 Backend: ComputationBackend,
94{
95 #[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
96 pub fn new(
97 evaluation_order: EvaluationOrder,
98 multilinears: Vec<M>,
99 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
100 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
101 switchover_fn: impl Fn(usize) -> usize,
102 backend: &'a Backend,
103 ) -> Result<Self, Error> {
104 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
105
106 #[cfg(feature = "debug_validate_sumcheck")]
107 {
108 let composite_claims = composite_claims
109 .iter()
110 .map(|x| CompositeSumClaim {
111 sum: x.sum,
112 composition: &x.composition,
113 })
114 .collect::<Vec<_>>();
115 validate_witness(&multilinears, composite_claims)?;
116 }
117
118 for claim in &composite_claims {
119 if claim.composition.n_vars() != multilinears.len() {
120 bail!(Error::InvalidComposition {
121 actual: claim.composition.n_vars(),
122 expected: multilinears.len(),
123 });
124 }
125 }
126
127 let claimed_sums = composite_claims
128 .iter()
129 .map(|composite_claim| composite_claim.sum)
130 .collect();
131
132 let domains = interpolation_domains_for_composition_degrees(
133 evaluation_domain_factory,
134 composite_claims
135 .iter()
136 .map(|composite_claim| composite_claim.composition.degree()),
137 )?;
138
139 let compositions = composite_claims
140 .into_iter()
141 .map(|claim| claim.composition)
142 .collect();
143
144 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
145
146 let multilinears_input = multilinears
147 .into_iter()
148 .map(|multilinear| MultilinearInput {
149 multilinear,
150 zero_scalars_suffix: 0,
151 })
152 .collect();
153
154 let state = ProverState::new(
155 evaluation_order,
156 multilinears_input,
157 claimed_sums,
158 nontrivial_evaluation_points,
159 switchover_fn,
160 backend,
161 )?;
162 let n_vars = state.n_vars();
163
164 Ok(Self {
165 n_vars,
166 state,
167 compositions,
168 domains,
169 })
170 }
171}
172
173impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
174 for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
175where
176 F: TowerField + ExtensionField<FDomain>,
177 FDomain: Field,
178 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
179 Composition: CompositionPoly<P>,
180 M: MultilinearPoly<P> + Send + Sync,
181 Backend: ComputationBackend,
182{
183 fn n_vars(&self) -> usize {
184 self.n_vars
185 }
186
187 fn evaluation_order(&self) -> EvaluationOrder {
188 self.state.evaluation_order()
189 }
190
191 #[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
192 fn fold(&mut self, challenge: F) -> Result<(), Error> {
193 self.state.fold(challenge)?;
194 Ok(())
195 }
196
197 #[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
198 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
199 let evaluators = izip!(&self.compositions, &self.domains)
200 .map(|(composition, interpolation_domain)| {
201 let composition_at_infinity =
202 ArithCircuitPoly::new(composition.expression().leading_term());
203
204 RegularSumcheckEvaluator {
205 composition,
206 composition_at_infinity,
207 interpolation_domain,
208 _marker: PhantomData,
209 }
210 })
211 .collect::<Vec<_>>();
212
213 let round_evals = self.state.calculate_round_evals(&evaluators)?;
214 self.state
215 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
216 }
217
218 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
219 self.state.finish()
220 }
221}
222
223struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
224where
225 P: PackedField,
226 FDomain: Field,
227{
228 composition: &'a Composition,
229 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
230 interpolation_domain: &'a InterpolationDomain<FDomain>,
231 _marker: PhantomData<P>,
232}
233
234impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
235 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
236where
237 F: TowerField + ExtensionField<FDomain>,
238 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
239 FDomain: Field,
240 Composition: CompositionPoly<P>,
241{
242 fn eval_point_indices(&self) -> Range<usize> {
243 1..self.composition.degree() + 1
246 }
247
248 fn process_subcube_at_eval_point(
249 &self,
250 _subcube_vars: usize,
251 _subcube_index: usize,
252 is_infinity_point: bool,
253 batch_query: &RowsBatchRef<P>,
254 ) -> P {
255 let row_len = batch_query.row_len();
256
257 stackalloc_with_default(row_len, |evals| {
258 if is_infinity_point {
259 self.composition_at_infinity
260 .batch_evaluate(batch_query, evals)
261 .expect("correct by query construction invariant");
262 } else {
263 self.composition
264 .batch_evaluate(batch_query, evals)
265 .expect("correct by query construction invariant");
266 }
267
268 evals.iter().copied().sum()
269 })
270 }
271
272 fn composition(&self) -> &Composition {
273 self.composition
274 }
275
276 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
277 None
278 }
279}
280
281impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
282 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
283where
284 F: Field,
285 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
286 FDomain: Field,
287{
288 fn round_evals_to_coeffs(
289 &self,
290 last_round_sum: F,
291 mut round_evals: Vec<F>,
292 ) -> Result<Vec<F>, PolynomialError> {
293 round_evals.insert(0, last_round_sum - round_evals[0]);
296
297 if round_evals.len() > 3 {
298 let infinity_round_eval = round_evals.remove(2);
303 round_evals.push(infinity_round_eval);
304 }
305
306 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
307 Ok(coeffs)
308 }
309}