1use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackend;
6use binius_math::{EvaluationDomainFactory, EvaluationOrder, MultilinearPoly};
7use binius_utils::bail;
8
9use super::{
10 RegularSumcheckProver, ZerocheckProverImpl,
11 eq_ind::{EqIndSumcheckProver, EqIndSumcheckProverBuilder},
12};
13use crate::{
14 oracle::{Constraint, ConstraintPredicate, SizedConstraintSet},
15 protocols::{
16 evalcheck::{EvalPoint, subclaims::MemoizedData},
17 sumcheck::{
18 CompositeSumClaim, Error, OracleClaimMeta, constraint_set_mlecheck_claim,
19 constraint_set_sumcheck_claim,
20 },
21 },
22 witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
23};
24
/// Zerocheck prover over witness multilinears.
///
/// Each constraint composition is an [`ArithCircuitPoly`] and is carried twice: once with
/// coefficients in the base field `FBase`, and once over the full scalar field of `P`.
pub type OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend> = ZerocheckProverImpl<
	'a,
	FDomain,
	FBase,
	P,
	// Composition with coefficients downcast to the base field.
	ArithCircuitPoly<FBase>,
	// Same composition over the full scalar field of `P`.
	ArithCircuitPoly<<P as PackedField>::Scalar>,
	MultilinearWitness<'a, P>,
	DomainFactory,
	Backend,
>;
36
/// Regular sumcheck prover over witness multilinears.
///
/// Its compositions are [`ArithCircuitPoly`]s over the scalar field of `P`.
pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	ArithCircuitPoly<<P as PackedField>::Scalar>,
	MultilinearWitness<'a, P>,
	Backend,
>;
45
/// MLE-check prover, built on [`EqIndSumcheckProver`], over witness multilinears.
///
/// Its compositions are [`ArithCircuitPoly`]s over the scalar field of `P`.
pub type OracleMLECheckProver<'a, FDomain, P, Backend> = EqIndSumcheckProver<
	'a,
	FDomain,
	P,
	ArithCircuitPoly<<P as PackedField>::Scalar>,
	MultilinearWitness<'a, P>,
	Backend,
>;
54
55pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, DomainFactory, Backend>(
58 constraints: Vec<Constraint<P::Scalar>>,
59 multilinears: Vec<MultilinearWitness<'a, P>>,
60 domain_factory: DomainFactory,
61 zerocheck_challenges: &[F],
62 backend: &'a Backend,
63) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend>, Error>
64where
65 P: PackedField<Scalar = F>
66 + PackedExtension<F, PackedSubfield = P>
67 + PackedExtension<FDomain>
68 + PackedExtension<FBase>,
69 F: TowerField,
70 FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
71 FDomain: Field,
72 DomainFactory: EvaluationDomainFactory<FDomain>,
73 Backend: ComputationBackend,
74{
75 let mut zeros = Vec::with_capacity(constraints.len());
76
77 for Constraint {
78 composition,
79 predicate,
80 name,
81 } in constraints
82 {
83 let composition_base = composition
84 .try_convert_field::<FBase>()
85 .map_err(|_| Error::CircuitFieldDowncastFailed)?;
86 match predicate {
87 ConstraintPredicate::Zero => {
88 zeros.push((
89 name,
90 ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
91 ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
92 ));
93 }
94 _ => bail!(Error::MixedBatchingNotSupported),
95 }
96 }
97
98 let prover = OracleZerocheckProver::<_, _, FDomain, _, _>::new(
99 multilinears,
100 zeros,
101 zerocheck_challenges,
102 domain_factory,
103 backend,
104 )?;
105
106 Ok(prover)
107}
108
109pub fn constraint_set_sumcheck_prover<'a, F, P, FDomain, Backend>(
112 evaluation_order: EvaluationOrder,
113 constraint_set: SizedConstraintSet<F>,
114 witness: &MultilinearExtensionIndex<'a, P>,
115 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
116 switchover_fn: impl Fn(usize) -> usize + Clone,
117 backend: &'a Backend,
118) -> Result<OracleSumcheckProver<'a, FDomain, P, Backend>, Error>
119where
120 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
121 F: TowerField + ExtensionField<FDomain>,
122 FDomain: Field,
123 Backend: ComputationBackend,
124{
125 let (constraints, multilinears) = split_constraint_set::<F, P>(constraint_set, witness)?;
126
127 let mut sums = Vec::new();
128
129 for Constraint {
130 composition,
131 predicate,
132 ..
133 } in constraints
134 {
135 match predicate {
136 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
137 composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
138 sum,
139 }),
140 _ => bail!(Error::MixedBatchingNotSupported),
141 }
142 }
143
144 let prover = RegularSumcheckProver::new(
145 evaluation_order,
146 multilinears,
147 sums,
148 evaluation_domain_factory,
149 switchover_fn,
150 backend,
151 )?;
152
153 Ok(prover)
154}
155
156#[allow(clippy::too_many_arguments)]
159pub fn constraint_set_mlecheck_prover<'a, 'b, F, P, FDomain, Backend>(
160 evaluation_order: EvaluationOrder,
161 constraint_set: SizedConstraintSet<F>,
162 eq_ind_challenges: &[F],
163 memoized_data: &mut MemoizedData<'b, P>,
164 witness: &MultilinearExtensionIndex<'a, P>,
165 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
166 switchover_fn: impl Fn(usize) -> usize + Clone,
167 backend: &'a Backend,
168) -> Result<OracleMLECheckProver<'a, FDomain, P, Backend>, Error>
169where
170 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
171 F: TowerField + ExtensionField<FDomain>,
172 FDomain: Field,
173 Backend: ComputationBackend,
174{
175 let SizedConstraintSet {
176 oracle_ids,
177 constraints,
178 n_vars,
179 } = constraint_set;
180
181 let mut multilinears = Vec::with_capacity(oracle_ids.len());
182 let mut const_suffixes = Vec::with_capacity(oracle_ids.len());
183
184 for id in oracle_ids {
185 let IndexEntry {
186 multilin_poly,
187 nonzero_scalars_prefix,
188 } = witness.get_index_entry(id)?;
189
190 if multilin_poly.n_vars() != n_vars {
191 bail!(Error::ConstraintSetNumberOfVariablesMismatch);
192 }
193
194 multilinears.push(multilin_poly);
195 const_suffixes.push((F::ZERO, ((1 << n_vars) - nonzero_scalars_prefix)))
196 }
197
198 let mut sums = Vec::new();
199
200 for Constraint {
201 composition,
202 predicate,
203 ..
204 } in constraints
205 {
206 match predicate {
207 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
208 composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
209 sum,
210 }),
211 _ => bail!(Error::MixedBatchingNotSupported),
212 }
213 }
214
215 let n_vars = eq_ind_challenges.len();
216
217 let eq_ind_partial = match evaluation_order {
218 EvaluationOrder::LowToHigh => &eq_ind_challenges[n_vars.min(1)..],
219 EvaluationOrder::HighToLow => &eq_ind_challenges[..n_vars.saturating_sub(1)],
220 };
221
222 let eq_ind_partial_evals = memoized_data
223 .full_query(eq_ind_partial)?
224 .expansion()
225 .to_vec();
226
227 let prover = EqIndSumcheckProverBuilder::with_switchover(multilinears, switchover_fn, backend)?
228 .with_eq_ind_partial_evals(Backend::to_hal_slice(eq_ind_partial_evals))
229 .with_const_suffixes(&const_suffixes)?
230 .build(evaluation_order, eq_ind_challenges, sums, evaluation_domain_factory)?;
231
232 Ok(prover)
233}
234
/// The constraints of a constraint set paired with the witness multilinears they refer to,
/// as returned by [`split_constraint_set`].
type ConstraintsAndMultilinears<'a, F, P> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, P>>);
236
237#[allow(clippy::type_complexity)]
238pub fn split_constraint_set<'a, F, P>(
239 constraint_set: SizedConstraintSet<F>,
240 witness: &MultilinearExtensionIndex<'a, P>,
241) -> Result<ConstraintsAndMultilinears<'a, F, P>, Error>
242where
243 F: Field,
244 P: PackedField,
245 P::Scalar: ExtensionField<F>,
246{
247 let SizedConstraintSet {
248 oracle_ids,
249 constraints,
250 n_vars,
251 } = constraint_set;
252
253 let multilinears = oracle_ids
254 .iter()
255 .map(|&oracle_id| witness.get_multilin_poly(oracle_id))
256 .collect::<Result<Vec<_>, _>>()?;
257
258 if multilinears
259 .iter()
260 .any(|multilin| multilin.n_vars() != n_vars)
261 {
262 bail!(Error::ConstraintSetNumberOfVariablesMismatch);
263 }
264
265 Ok((constraints, multilinears))
266}
267
/// A batch of sumcheck provers together with the claim metadata for each one.
///
/// `provers[i]` and `metas[i]` come from the same constraint set.
pub struct SumcheckProversWithMetas<'a, P, FDomain, Backend>
where
	P: PackedField,
	FDomain: Field,
	Backend: ComputationBackend,
{
	/// One sumcheck prover per constraint set.
	pub provers: Vec<OracleSumcheckProver<'a, FDomain, P, Backend>>,
	/// Claim metadata, index-aligned with `provers`.
	pub metas: Vec<OracleClaimMeta>,
}
277
278pub fn constraint_sets_sumcheck_provers_metas<'a, P, FDomain, Backend>(
280 evaluation_order: EvaluationOrder,
281 constraint_sets: Vec<SizedConstraintSet<P::Scalar>>,
282 witness: &MultilinearExtensionIndex<'a, P>,
283 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
284 switchover_fn: impl Fn(usize) -> usize,
285 backend: &'a Backend,
286) -> Result<SumcheckProversWithMetas<'a, P, FDomain, Backend>, Error>
287where
288 P: PackedExtension<FDomain>,
289 P::Scalar: TowerField + ExtensionField<FDomain>,
290 FDomain: Field,
291 Backend: ComputationBackend,
292{
293 let mut provers = Vec::with_capacity(constraint_sets.len());
294 let mut metas = Vec::with_capacity(constraint_sets.len());
295
296 for constraint_set in constraint_sets {
297 let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
298 let prover = constraint_set_sumcheck_prover(
299 evaluation_order,
300 constraint_set,
301 witness,
302 evaluation_domain_factory.clone(),
303 &switchover_fn,
304 backend,
305 )?;
306 metas.push(meta);
307 provers.push(prover);
308 }
309 Ok(SumcheckProversWithMetas { provers, metas })
310}
311
/// An MLE-check prover paired with the metadata of the claim it proves.
pub struct MLECheckProverWithMeta<'a, P, FDomain, Backend>
where
	P: PackedField,
	FDomain: Field,
	Backend: ComputationBackend,
{
	/// The MLE-check prover built from the constraint set.
	pub prover: OracleMLECheckProver<'a, FDomain, P, Backend>,
	/// Claim metadata derived from the same constraint set.
	pub meta: OracleClaimMeta,
}
321
322#[allow(clippy::too_many_arguments)]
324pub fn constraint_sets_mlecheck_prover_meta<'a, 'b, P, FDomain, Backend>(
325 evaluation_order: EvaluationOrder,
326 constraint_set: SizedConstraintSet<P::Scalar>,
327 eq_ind_challenges: EvalPoint<P::Scalar>,
328 memoized_data: &mut MemoizedData<'b, P>,
329 witness: &MultilinearExtensionIndex<'a, P>,
330 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
331 switchover_fn: impl Fn(usize) -> usize,
332 backend: &'a Backend,
333) -> Result<MLECheckProverWithMeta<'a, P, FDomain, Backend>, Error>
334where
335 P: PackedExtension<FDomain>,
336 P::Scalar: TowerField + ExtensionField<FDomain>,
337 FDomain: Field,
338 Backend: ComputationBackend,
339{
340 let (_, meta) = constraint_set_mlecheck_claim(constraint_set.clone())?;
341 let prover = constraint_set_mlecheck_prover(
342 evaluation_order,
343 constraint_set,
344 &eq_ind_challenges,
345 memoized_data,
346 witness,
347 evaluation_domain_factory,
348 &switchover_fn,
349 backend,
350 )?;
351
352 Ok(MLECheckProverWithMeta { prover, meta })
353}