1use binius_field::{
4 ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
5};
6use binius_hal::ComputationBackend;
7use binius_math::{EvaluationDomainFactory, EvaluationOrder};
8use binius_utils::bail;
9
10use super::{RegularSumcheckProver, UnivariateZerocheck};
11use crate::{
12 oracle::{Constraint, ConstraintPredicate, ConstraintSet},
13 polynomial::ArithCircuitPoly,
14 protocols::sumcheck::{
15 constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
16 },
17 witness::{MultilinearExtensionIndex, MultilinearWitness},
18};
19
20pub type OracleZerocheckProver<
21 'a,
22 P,
23 FBase,
24 FDomain,
25 InterpolationDomainFactory,
26 SwitchoverFn,
27 Backend,
28> = UnivariateZerocheck<
29 'a,
30 FDomain,
31 FBase,
32 P,
33 ArithCircuitPoly<FBase>,
34 ArithCircuitPoly<<P as PackedField>::Scalar>,
35 MultilinearWitness<'a, P>,
36 InterpolationDomainFactory,
37 SwitchoverFn,
38 Backend,
39>;
40
41pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
42 'a,
43 FDomain,
44 P,
45 ArithCircuitPoly<<P as PackedField>::Scalar>,
46 MultilinearWitness<'a, P>,
47 Backend,
48>;
49
50pub fn constraint_set_zerocheck_prover<
52 'a,
53 P,
54 F,
55 FBase,
56 FDomain,
57 InterpolationDomainFactory,
58 SwitchoverFn,
59 Backend,
60>(
61 constraints: Vec<Constraint<P::Scalar>>,
62 multilinears: Vec<MultilinearWitness<'a, P>>,
63 interpolation_domain_factory: InterpolationDomainFactory,
64 switchover_fn: SwitchoverFn,
65 zerocheck_challenges: &[F],
66 backend: &'a Backend,
67) -> Result<
68 OracleZerocheckProver<'a, P, FBase, FDomain, InterpolationDomainFactory, SwitchoverFn, Backend>,
69 Error,
70>
71where
72 P: PackedFieldIndexable<Scalar = F>
73 + PackedExtension<F, PackedSubfield = P>
74 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
75 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
76 F: TowerField,
77 FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
78 FDomain: Field,
79 InterpolationDomainFactory: EvaluationDomainFactory<FDomain>,
80 SwitchoverFn: Fn(usize) -> usize + Clone,
81 Backend: ComputationBackend,
82{
83 let mut zeros = Vec::with_capacity(constraints.len());
84
85 for Constraint {
86 composition,
87 predicate,
88 name,
89 } in constraints
90 {
91 let composition_base = composition
92 .clone()
93 .try_convert_field::<FBase>()
94 .map_err(|_| Error::CircuitFieldDowncastFailed)?;
95 match predicate {
96 ConstraintPredicate::Zero => {
97 zeros.push((
98 name,
99 ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
100 ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
101 ));
102 }
103 _ => bail!(Error::MixedBatchingNotSupported),
104 }
105 }
106
107 let prover = OracleZerocheckProver::<_, _, FDomain, _, _, _>::new(
108 multilinears,
109 zeros,
110 zerocheck_challenges,
111 interpolation_domain_factory,
112 switchover_fn,
113 backend,
114 )?;
115
116 Ok(prover)
117}
118
119pub fn constraint_set_sumcheck_prover<'a, FW, PW, FDomain, Backend>(
121 evaluation_order: EvaluationOrder,
122 constraint_set: ConstraintSet<FW>,
123 witness: &MultilinearExtensionIndex<'a, PW>,
124 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
125 switchover_fn: impl Fn(usize) -> usize + Clone,
126 backend: &'a Backend,
127) -> Result<OracleSumcheckProver<'a, FDomain, PW, Backend>, Error>
128where
129 PW: PackedField<Scalar = FW> + PackedExtension<FDomain>,
130 FW: TowerField + ExtensionField<FDomain>,
131 FDomain: Field,
132 Backend: ComputationBackend,
133{
134 let (constraints, multilinears) = split_constraint_set::<FW, PW>(constraint_set, witness)?;
135
136 let mut sums = Vec::new();
137
138 for Constraint {
139 composition,
140 predicate,
141 ..
142 } in constraints
143 {
144 match predicate {
145 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
146 composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
147 sum,
148 }),
149 _ => bail!(Error::MixedBatchingNotSupported),
150 }
151 }
152
153 let prover = RegularSumcheckProver::new(
154 evaluation_order,
155 multilinears,
156 sums,
157 evaluation_domain_factory,
158 switchover_fn,
159 backend,
160 )?;
161
162 Ok(prover)
163}
164
165type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
166
167#[allow(clippy::type_complexity)]
168pub fn split_constraint_set<'a, F, PW>(
169 constraint_set: ConstraintSet<F>,
170 witness: &MultilinearExtensionIndex<'a, PW>,
171) -> Result<ConstraintsAndMultilinears<'a, F, PW>, Error>
172where
173 F: Field,
174 PW: PackedField,
175 PW::Scalar: ExtensionField<F>,
176{
177 let ConstraintSet {
178 oracle_ids,
179 constraints,
180 n_vars,
181 } = constraint_set;
182
183 let multilinears = oracle_ids
184 .iter()
185 .map(|&oracle_id| witness.get_multilin_poly(oracle_id))
186 .collect::<Result<Vec<_>, _>>()?;
187
188 if multilinears
189 .iter()
190 .any(|multilin| multilin.n_vars() != n_vars)
191 {
192 bail!(Error::ConstraintSetNumberOfVariablesMismatch);
193 }
194
195 Ok((constraints, multilinears))
196}
197
198pub struct SumcheckProversWithMetas<'a, PW, FDomain, Backend>
199where
200 PW: PackedField,
201 FDomain: Field,
202 Backend: ComputationBackend,
203{
204 pub provers: Vec<OracleSumcheckProver<'a, FDomain, PW, Backend>>,
205 pub metas: Vec<OracleClaimMeta>,
206}
207
208pub fn constraint_sets_sumcheck_provers_metas<'a, PW, FDomain, Backend>(
210 evaluation_order: EvaluationOrder,
211 constraint_sets: Vec<ConstraintSet<PW::Scalar>>,
212 witness: &MultilinearExtensionIndex<'a, PW>,
213 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
214 switchover_fn: impl Fn(usize) -> usize,
215 backend: &'a Backend,
216) -> Result<SumcheckProversWithMetas<'a, PW, FDomain, Backend>, Error>
217where
218 PW: PackedExtension<FDomain>,
219 PW::Scalar: TowerField + ExtensionField<FDomain>,
220 FDomain: Field,
221 Backend: ComputationBackend,
222{
223 let mut provers = Vec::with_capacity(constraint_sets.len());
224 let mut metas = Vec::with_capacity(constraint_sets.len());
225
226 for constraint_set in constraint_sets {
227 let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
228 let prover = constraint_set_sumcheck_prover(
229 evaluation_order,
230 constraint_set,
231 witness,
232 evaluation_domain_factory.clone(),
233 &switchover_fn,
234 backend,
235 )?;
236 metas.push(meta);
237 provers.push(prover);
238 }
239 Ok(SumcheckProversWithMetas { provers, metas })
240}