1use std::{marker::PhantomData, ops::Range};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
7use binius_hal::{ComputationBackend, SumcheckEvaluator, SumcheckMultilinear};
8use binius_math::{
9 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
10 MultilinearPoly, RowsBatchRef,
11};
12use binius_maybe_rayon::prelude::*;
13use binius_utils::bail;
14use itertools::izip;
15use stackalloc::stackalloc_with_default;
16use tracing::instrument;
17
18use crate::{
19 polynomial::{Error as PolynomialError, MultilinearComposite},
20 protocols::sumcheck::{
21 common::{
22 CompositeSumClaim, RoundCoeffs, equal_n_vars_check, get_nontrivial_evaluation_points,
23 interpolation_domains_for_composition_degrees,
24 },
25 error::Error,
26 prove::{ProverState, SumcheckInterpolator, SumcheckProver},
27 },
28};
29
30pub fn validate_witness<'a, F, P, M, Composition>(
31 multilinears: &[M],
32 sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
33) -> Result<(), Error>
34where
35 F: Field,
36 P: PackedField<Scalar = F>,
37 M: MultilinearPoly<P> + Send + Sync,
38 Composition: CompositionPoly<P> + 'a,
39{
40 let n_vars = multilinears
41 .first()
42 .map(|multilinear| multilinear.n_vars())
43 .unwrap_or_default();
44 for multilinear in multilinears {
45 if multilinear.n_vars() != n_vars {
46 bail!(Error::NumberOfVariablesMismatch);
47 }
48 }
49
50 let multilinears = multilinears.iter().collect::<Vec<_>>();
51
52 for (i, claim) in sum_claims.into_iter().enumerate() {
53 let CompositeSumClaim {
54 composition,
55 sum: expected_sum,
56 ..
57 } = claim;
58 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
59 let sum = (0..(1 << n_vars))
60 .into_par_iter()
61 .map(|j| witness.evaluate_on_hypercube(j))
62 .try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
63
64 if sum != expected_sum {
65 bail!(Error::SumcheckNaiveValidationFailure {
66 composition_index: i,
67 });
68 }
69 }
70 Ok(())
71}
72
73pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
74where
75 FDomain: Field,
76 P: PackedField,
77 M: MultilinearPoly<P> + Send + Sync,
78 Backend: ComputationBackend,
79{
80 n_vars: usize,
81 state: ProverState<'a, FDomain, P, M, Backend>,
82 compositions: Vec<Composition>,
83 domains: Vec<InterpolationDomain<FDomain>>,
84}
85
86impl<'a, F, FDomain, P, Composition, M, Backend>
87 RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
88where
89 F: Field,
90 FDomain: Field,
91 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
92 Composition: CompositionPoly<P>,
93 M: MultilinearPoly<P> + Send + Sync,
94 Backend: ComputationBackend,
95{
96 #[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
97 pub fn new(
98 evaluation_order: EvaluationOrder,
99 multilinears: Vec<M>,
100 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
101 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
102 switchover_fn: impl Fn(usize) -> usize,
103 backend: &'a Backend,
104 ) -> Result<Self, Error> {
105 let n_vars = equal_n_vars_check(&multilinears)?;
106 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
107
108 #[cfg(feature = "debug_validate_sumcheck")]
109 {
110 let composite_claims = composite_claims
111 .iter()
112 .map(|x| CompositeSumClaim {
113 sum: x.sum,
114 composition: &x.composition,
115 })
116 .collect::<Vec<_>>();
117 validate_witness(&multilinears, composite_claims)?;
118 }
119
120 for claim in &composite_claims {
121 if claim.composition.n_vars() != multilinears.len() {
122 bail!(Error::InvalidComposition {
123 actual: claim.composition.n_vars(),
124 expected: multilinears.len(),
125 });
126 }
127 }
128
129 let claimed_sums = composite_claims
130 .iter()
131 .map(|composite_claim| composite_claim.sum)
132 .collect();
133
134 let domains = interpolation_domains_for_composition_degrees(
135 evaluation_domain_factory,
136 composite_claims
137 .iter()
138 .map(|composite_claim| composite_claim.composition.degree()),
139 )?;
140
141 let compositions = composite_claims
142 .into_iter()
143 .map(|claim| claim.composition)
144 .collect();
145
146 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
147
148 let multilinears = multilinears
149 .into_iter()
150 .map(|multilinear| SumcheckMultilinear::transparent(multilinear, &switchover_fn))
151 .collect();
152
153 let state = ProverState::new(
154 evaluation_order,
155 n_vars,
156 multilinears,
157 claimed_sums,
158 nontrivial_evaluation_points,
159 backend,
160 )?;
161
162 Ok(Self {
163 n_vars,
164 state,
165 compositions,
166 domains,
167 })
168 }
169}
170
171impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
172 for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
173where
174 F: TowerField + ExtensionField<FDomain>,
175 FDomain: Field,
176 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
177 Composition: CompositionPoly<P>,
178 M: MultilinearPoly<P> + Send + Sync,
179 Backend: ComputationBackend,
180{
181 fn n_vars(&self) -> usize {
182 self.n_vars
183 }
184
185 fn evaluation_order(&self) -> EvaluationOrder {
186 self.state.evaluation_order()
187 }
188
189 #[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
190 fn fold(&mut self, challenge: F) -> Result<(), Error> {
191 self.state.fold(challenge)?;
192 Ok(())
193 }
194
195 #[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
196 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
197 let evaluators = izip!(&self.compositions, &self.domains)
198 .map(|(composition, interpolation_domain)| {
199 let composition_at_infinity =
200 ArithCircuitPoly::new(composition.expression().leading_term());
201
202 RegularSumcheckEvaluator {
203 composition,
204 composition_at_infinity,
205 interpolation_domain,
206 _marker: PhantomData,
207 }
208 })
209 .collect::<Vec<_>>();
210
211 let round_evals = self.state.calculate_round_evals(&evaluators)?;
212 self.state
213 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
214 }
215
216 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
217 self.state.finish()
218 }
219}
220
221struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
222where
223 P: PackedField,
224 FDomain: Field,
225{
226 composition: &'a Composition,
227 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
228 interpolation_domain: &'a InterpolationDomain<FDomain>,
229 _marker: PhantomData<P>,
230}
231
232impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
233 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
234where
235 F: TowerField + ExtensionField<FDomain>,
236 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
237 FDomain: Field,
238 Composition: CompositionPoly<P>,
239{
240 fn eval_point_indices(&self) -> Range<usize> {
241 1..self.composition.degree() + 1
244 }
245
246 fn process_subcube_at_eval_point(
247 &self,
248 _subcube_vars: usize,
249 _subcube_index: usize,
250 is_infinity_point: bool,
251 batch_query: &RowsBatchRef<P>,
252 ) -> P {
253 let row_len = batch_query.row_len();
254
255 stackalloc_with_default(row_len, |evals| {
256 if is_infinity_point {
257 self.composition_at_infinity
258 .batch_evaluate(batch_query, evals)
259 .expect("correct by query construction invariant");
260 } else {
261 self.composition
262 .batch_evaluate(batch_query, evals)
263 .expect("correct by query construction invariant");
264 }
265
266 evals.iter().copied().sum()
267 })
268 }
269
270 fn composition(&self) -> &Composition {
271 self.composition
272 }
273
274 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
275 None
276 }
277}
278
279impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
280 for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
281where
282 F: Field,
283 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
284 FDomain: Field,
285{
286 fn round_evals_to_coeffs(
287 &self,
288 last_round_sum: F,
289 mut round_evals: Vec<F>,
290 ) -> Result<Vec<F>, PolynomialError> {
291 round_evals.insert(0, last_round_sum - round_evals[0]);
294
295 if round_evals.len() > 3 {
296 let infinity_round_eval = round_evals.remove(2);
302 round_evals.push(infinity_round_eval);
303 }
304
305 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
306 Ok(coeffs)
307 }
308}