1use std::{marker::PhantomData, sync::Arc};
4
5use binius_field::{
6 util::powers, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
7 PackedSubfield, TowerField,
8};
9use binius_hal::ComputationBackend;
10use binius_math::{
11 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearPoly,
12 MultilinearQuery,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_utils::bail;
16use bytemuck::zeroed_vec;
17use getset::Getters;
18use itertools::izip;
19use tracing::instrument;
20
21use crate::{
22 polynomial::MultilinearComposite,
23 protocols::sumcheck::{
24 common::{equal_n_vars_check, CompositeSumClaim},
25 prove::{
26 eq_ind::EqIndSumcheckProverBuilder,
27 univariate::{
28 zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
29 ZerocheckUnivariateFoldResult,
30 },
31 SumcheckProver, UnivariateZerocheckProver,
32 },
33 univariate::LagrangeRoundEvals,
34 univariate_zerocheck::domain_size,
35 Error,
36 },
37 witness::MultilinearWitness,
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41 multilinears: &[M],
42 zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45 F: Field,
46 P: PackedField<Scalar = F>,
47 M: MultilinearPoly<P> + Send + Sync,
48 Composition: CompositionPoly<P> + 'a,
49{
50 let n_vars = multilinears
51 .first()
52 .map(|multilinear| multilinear.n_vars())
53 .unwrap_or_default();
54 for multilinear in multilinears {
55 if multilinear.n_vars() != n_vars {
56 bail!(Error::NumberOfVariablesMismatch);
57 }
58 }
59
60 let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62 for (name, composition) in zero_claims {
63 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64 (0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65 if witness.evaluate_on_hypercube(j)? != F::ZERO {
66 return Err(Error::ZerocheckNaiveValidationFailure {
67 composition_name: name.to_string(),
68 vertex_index: j,
69 });
70 }
71 Ok(())
72 })?;
73 }
74 Ok(())
75}
76
/// A zerocheck prover that can use a univariate-skip round for its first `skip_rounds`
/// variables.
///
/// Every zero claim is stored in two forms:
/// * a composition over `<P as PackedExtension<FBase>>::PackedSubfield`, which the
///   univariate round evaluates;
/// * the same constraint as a composition over `P`, which the regular sumcheck rounds use.
///
/// The prover is used in one of two ways:
/// * call `into_regular_zerocheck` to skip the univariate round entirely; or
/// * call `execute_univariate_round` and then `fold_univariate_round`, which returns a regular
///   sumcheck prover for the remaining rounds.
#[derive(Debug, Getters)]
pub struct UnivariateZerocheck<
    'a,
    FDomain,
    FBase,
    P,
    CompositionBase,
    Composition,
    M,
    DomainFactory,
    SwitchoverFn,
    Backend,
> where
    FDomain: Field,
    FBase: Field,
    P: PackedField,
    Backend: ComputationBackend,
{
    /// Number of variables shared by all witness multilinears.
    n_vars: usize,
    /// Witness multilinears that the zero claims are composed over.
    #[getset(get = "pub")]
    multilinears: Vec<M>,
    /// One `(name, base-field composition, extension-field composition)` triple per zero claim.
    compositions: Vec<(String, CompositionBase, Composition)>,
    /// Zerocheck challenges, forwarded to the eq-indicator sumcheck prover builder.
    zerocheck_challenges: Vec<P::Scalar>,
    /// Factory for the evaluation domains handed to the regular sumcheck prover.
    domain_factory: DomainFactory,
    /// Switchover function handed to the regular sumcheck prover (offset by `skip_rounds` after
    /// a univariate fold).
    switchover_fn: SwitchoverFn,
    /// Computation backend used by the univariate round and by the regular prover.
    backend: &'a Backend,
    /// `Some` after `execute_univariate_round` has run, and consumed by `fold_univariate_round`.
    univariate_evals_output: Option<ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>>,
    /// Marker for the `FBase` type parameter, which is otherwise unused in the fields.
    _p_base_marker: PhantomData<FBase>,
    /// Marker for the `FDomain` type parameter, which is otherwise unused in the fields.
    _fdomain_marker: PhantomData<FDomain>,
}
118
119impl<
120 'a,
121 F,
122 FDomain,
123 FBase,
124 P,
125 CompositionBase,
126 Composition,
127 M,
128 DomainFactory,
129 SwitchoverFn,
130 Backend,
131 >
132 UnivariateZerocheck<
133 'a,
134 FDomain,
135 FBase,
136 P,
137 CompositionBase,
138 Composition,
139 M,
140 DomainFactory,
141 SwitchoverFn,
142 Backend,
143 >
144where
145 F: TowerField,
146 FDomain: Field,
147 FBase: ExtensionField<FDomain>,
148 P: PackedFieldIndexable<Scalar = F>
149 + PackedExtension<F, PackedSubfield = P>
150 + PackedExtension<FBase>
151 + PackedExtension<FDomain>,
152 CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
153 Composition: CompositionPoly<P> + 'a,
154 M: MultilinearPoly<P> + Send + Sync + 'a,
155 DomainFactory: EvaluationDomainFactory<FDomain>,
156 SwitchoverFn: Fn(usize) -> usize,
157 Backend: ComputationBackend,
158{
159 pub fn new(
160 multilinears: Vec<M>,
161 zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
162 zerocheck_challenges: &[F],
163 domain_factory: DomainFactory,
164 switchover_fn: SwitchoverFn,
165 backend: &'a Backend,
166 ) -> Result<Self, Error> {
167 let n_vars = equal_n_vars_check(&multilinears)?;
168
169 let compositions = zero_claims.into_iter().collect::<Vec<_>>();
170 for (_, composition_base, composition) in &compositions {
171 if composition_base.n_vars() != multilinears.len()
172 || composition.n_vars() != multilinears.len()
173 || composition_base.degree() != composition.degree()
174 {
175 bail!(Error::InvalidComposition {
176 actual: composition.n_vars(),
177 expected: multilinears.len(),
178 });
179 }
180 }
181 #[cfg(feature = "debug_validate_sumcheck")]
182 {
183 let compositions = compositions
184 .iter()
185 .map(|(name, _, a)| (name.clone(), a))
186 .collect::<Vec<_>>();
187 validate_witness(&multilinears, &compositions)?;
188 }
189
190 let zerocheck_challenges = zerocheck_challenges.to_vec();
191
192 Ok(Self {
193 n_vars,
194 multilinears,
195 compositions,
196 zerocheck_challenges,
197 domain_factory,
198 switchover_fn,
199 backend,
200 univariate_evals_output: None,
201 _p_base_marker: PhantomData,
202 _fdomain_marker: PhantomData,
203 })
204 }
205
206 #[instrument(skip_all, level = "debug")]
207 #[allow(clippy::type_complexity)]
208 pub fn into_regular_zerocheck(self) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
209 if self.univariate_evals_output.is_some() {
210 bail!(Error::ExpectedFold);
211 }
212
213 #[cfg(feature = "debug_validate_sumcheck")]
214 {
215 let compositions = self
216 .compositions
217 .iter()
218 .map(|(name, _, a)| (name.clone(), a))
219 .collect::<Vec<_>>();
220 validate_witness(&self.multilinears, &compositions)?;
221 }
222
223 let composite_claims = self
224 .compositions
225 .into_iter()
226 .map(|(_, _, composition)| CompositeSumClaim {
227 composition,
228 sum: F::ZERO,
229 })
230 .collect::<Vec<_>>();
231
232 let first_round_eval_1s = composite_claims.iter().map(|_| F::ZERO).collect::<Vec<_>>();
233
234 let prover = EqIndSumcheckProverBuilder::new(self.backend)
235 .with_first_round_eval_1s(&first_round_eval_1s)
236 .build(
237 EvaluationOrder::LowToHigh,
238 self.multilinears,
239 &self.zerocheck_challenges,
240 composite_claims,
241 self.domain_factory,
242 self.switchover_fn,
243 )?;
244
245 Ok(Box::new(prover) as Box<dyn SumcheckProver<F> + 'a>)
246 }
247}
248
249impl<
250 'a,
251 F,
252 FDomain,
253 FBase,
254 P,
255 CompositionBase,
256 Composition,
257 M,
258 InterpolationDomainFactory,
259 SwitchoverFn,
260 Backend,
261 > UnivariateZerocheckProver<'a, F>
262 for UnivariateZerocheck<
263 'a,
264 FDomain,
265 FBase,
266 P,
267 CompositionBase,
268 Composition,
269 M,
270 InterpolationDomainFactory,
271 SwitchoverFn,
272 Backend,
273 >
274where
275 F: TowerField,
276 FDomain: TowerField,
277 FBase: ExtensionField<FDomain>,
278 P: PackedFieldIndexable<Scalar = F>
279 + PackedExtension<F, PackedSubfield = P>
280 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
281 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
282 CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
283 Composition: CompositionPoly<P> + 'static,
284 M: MultilinearPoly<P> + Send + Sync + 'a,
285 InterpolationDomainFactory: EvaluationDomainFactory<FDomain>,
286 SwitchoverFn: Fn(usize) -> usize,
287 Backend: ComputationBackend,
288{
289 fn n_vars(&self) -> usize {
290 self.n_vars
291 }
292
293 fn domain_size(&self, skip_rounds: usize) -> usize {
294 self.compositions
295 .iter()
296 .map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
297 .max()
298 .unwrap_or(0)
299 }
300
301 #[instrument(skip_all, level = "debug")]
302 fn execute_univariate_round(
303 &mut self,
304 skip_rounds: usize,
305 max_domain_size: usize,
306 batch_coeff: F,
307 ) -> Result<LagrangeRoundEvals<F>, Error> {
308 if self.univariate_evals_output.is_some() {
309 bail!(Error::ExpectedFold);
310 }
311
312 let compositions_base = self
314 .compositions
315 .iter()
316 .map(|(_, composition_base, _)| composition_base)
317 .collect::<Vec<_>>();
318
319 let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
322 &self.multilinears,
323 &compositions_base,
324 &self.zerocheck_challenges,
325 skip_rounds,
326 max_domain_size,
327 self.backend,
328 )?;
329
330 let zeros_prefix_len = 1 << skip_rounds;
332 let batched_round_evals = univariate_evals_output
333 .round_evals
334 .iter()
335 .zip(powers(batch_coeff))
336 .map(|(evals, scalar)| {
337 let round_evals = LagrangeRoundEvals {
338 zeros_prefix_len,
339 evals: evals.clone(),
340 };
341 round_evals * scalar
342 })
343 .try_fold(
344 LagrangeRoundEvals::zeros(max_domain_size),
345 |mut accum, evals| -> Result<_, Error> {
346 accum.add_assign_lagrange(&evals)?;
347 Ok(accum)
348 },
349 )?;
350
351 self.univariate_evals_output = Some(univariate_evals_output);
352
353 Ok(batched_round_evals)
354 }
355
356 #[instrument(skip_all, level = "debug")]
357 fn fold_univariate_round(
358 self: Box<Self>,
359 challenge: F,
360 ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
361 if self.univariate_evals_output.is_none() {
362 bail!(Error::ExpectedExecution);
363 }
364
365 let ZerocheckUnivariateFoldResult {
368 skip_rounds,
369 subcube_lagrange_coeffs,
370 claimed_sums,
371 partial_eq_ind_evals,
372 } = self
373 .univariate_evals_output
374 .expect("validated to be Some")
375 .fold::<FDomain>(challenge)?;
376
377 let mut packed_subcube_lagrange_coeffs =
386 zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
387 P::unpack_scalars_mut(&mut packed_subcube_lagrange_coeffs)[..1 << skip_rounds]
388 .copy_from_slice(&subcube_lagrange_coeffs);
389 let lagrange_coeffs_query =
390 MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
391
392 let partial_low_multilinears = self
393 .multilinears
394 .into_par_iter()
395 .map(|multilinear| -> Result<_, Error> {
396 let multilinear =
397 multilinear.evaluate_partial_low(lagrange_coeffs_query.to_ref())?;
398 let mle_adapter = Arc::new(MLEDirectAdapter::from(multilinear));
399 Ok(mle_adapter as MultilinearWitness<'static, P>)
400 })
401 .collect::<Result<Vec<_>, _>>()?;
402
403 let composite_claims = izip!(self.compositions, claimed_sums)
404 .map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
405 .collect::<Vec<_>>();
406
407 let regular_prover = EqIndSumcheckProverBuilder::new(self.backend)
410 .with_eq_ind_partial_evals(partial_eq_ind_evals)
411 .build(
412 EvaluationOrder::LowToHigh,
413 partial_low_multilinears,
414 &self.zerocheck_challenges,
415 composite_claims,
416 self.domain_factory,
417 |extension_degree| {
418 (self.switchover_fn)(extension_degree).saturating_sub(skip_rounds)
419 },
420 )?;
421
422 Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
423 }
424}