1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 underlier::UnderlierType,
6 ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{EvaluationDomainFactory, EvaluationOrder};
10use binius_utils::bail;
11
12use super::{RegularSumcheckProver, UnivariateZerocheck};
13use crate::{
14 oracle::{Constraint, ConstraintPredicate, ConstraintSet},
15 polynomial::ArithCircuitPoly,
16 protocols::sumcheck::{
17 constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
18 },
19 witness::{MultilinearExtensionIndex, MultilinearWitness},
20};
21
/// Univariate zerocheck prover over witness multilinears, whose compositions are held as
/// arithmetic circuits both over the base field `FBase` and over the packed scalar field
/// `P::Scalar`.
///
/// This is the prover type returned by [`constraint_set_zerocheck_prover`].
pub type OracleZerocheckProver<'a, P, FBase, FDomain, Backend> = UnivariateZerocheck<
	'a,
	'a,
	FDomain,
	FBase,
	P,
	ArithCircuitPoly<FBase>,
	ArithCircuitPoly<<P as PackedField>::Scalar>,
	MultilinearWitness<'a, P>,
	Backend,
>;
33
/// Regular sumcheck prover over witness multilinears, whose compositions are held as
/// arithmetic circuits over the packed scalar field `P::Scalar`.
///
/// This is the prover type returned by [`constraint_set_sumcheck_prover`].
pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	ArithCircuitPoly<<P as PackedField>::Scalar>,
	MultilinearWitness<'a, P>,
	Backend,
>;
42
43pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, Backend>(
45 constraints: Vec<Constraint<P::Scalar>>,
46 multilinears: Vec<MultilinearWitness<'a, P>>,
47 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
48 switchover_fn: impl Fn(usize) -> usize + Clone,
49 zerocheck_challenges: &[P::Scalar],
50 backend: &'a Backend,
51) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, Backend>, Error>
52where
53 P: PackedFieldIndexable<Scalar = F>
54 + PackedExtension<F, PackedSubfield = P>
55 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
56 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
57 F: TowerField,
58 FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
59 FDomain: Field,
60 Backend: ComputationBackend,
61{
62 let mut zeros = Vec::with_capacity(constraints.len());
63
64 for Constraint {
65 composition,
66 predicate,
67 name,
68 } in constraints
69 {
70 let composition_base = composition
71 .clone()
72 .try_convert_field::<FBase>()
73 .map_err(|_| Error::CircuitFieldDowncastFailed)?;
74 match predicate {
75 ConstraintPredicate::Zero => {
76 zeros.push((
77 name,
78 ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
79 ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
80 ));
81 }
82 _ => bail!(Error::MixedBatchingNotSupported),
83 }
84 }
85
86 let prover = OracleZerocheckProver::<_, _, FDomain, _>::new(
87 multilinears,
88 zeros,
89 zerocheck_challenges,
90 evaluation_domain_factory,
91 switchover_fn,
92 backend,
93 )?;
94
95 Ok(prover)
96}
97
98pub fn constraint_set_sumcheck_prover<'a, U, FW, FDomain, Backend>(
100 evaluation_order: EvaluationOrder,
101 constraint_set: ConstraintSet<FW>,
102 witness: &MultilinearExtensionIndex<'a, U, FW>,
103 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
104 switchover_fn: impl Fn(usize) -> usize + Clone,
105 backend: &'a Backend,
106) -> Result<OracleSumcheckProver<'a, FDomain, PackedType<U, FW>, Backend>, Error>
107where
108 U: UnderlierType + PackScalar<FW> + PackScalar<FDomain>,
109 FW: TowerField + ExtensionField<FDomain>,
110 FDomain: Field,
111 Backend: ComputationBackend,
112{
113 let (constraints, multilinears) = split_constraint_set::<_, FW, _>(constraint_set, witness)?;
114
115 let mut sums = Vec::new();
116
117 for Constraint {
118 composition,
119 predicate,
120 ..
121 } in constraints
122 {
123 match predicate {
124 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
125 composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
126 sum,
127 }),
128 _ => bail!(Error::MixedBatchingNotSupported),
129 }
130 }
131
132 let prover = RegularSumcheckProver::new(
133 evaluation_order,
134 multilinears,
135 sums,
136 evaluation_domain_factory,
137 switchover_fn,
138 backend,
139 )?;
140
141 Ok(prover)
142}
143
/// Constraints of a [`ConstraintSet`] paired with the witness multilinears of its oracles, as
/// returned by [`split_constraint_set`].
type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
145
146#[allow(clippy::type_complexity)]
147pub fn split_constraint_set<'a, U, F, FW>(
148 constraint_set: ConstraintSet<F>,
149 witness: &MultilinearExtensionIndex<'a, U, FW>,
150) -> Result<ConstraintsAndMultilinears<'a, F, PackedType<U, FW>>, Error>
151where
152 U: UnderlierType + PackScalar<F> + PackScalar<FW>,
153 F: Field,
154 FW: ExtensionField<F>,
155{
156 let ConstraintSet {
157 oracle_ids,
158 constraints,
159 n_vars,
160 } = constraint_set;
161
162 let multilinears = oracle_ids
163 .iter()
164 .map(|&oracle_id| witness.get_multilin_poly(oracle_id))
165 .collect::<Result<Vec<_>, _>>()?;
166
167 if multilinears
168 .iter()
169 .any(|multilin| multilin.n_vars() != n_vars)
170 {
171 bail!(Error::ConstraintSetNumberOfVariablesMismatch);
172 }
173
174 Ok((constraints, multilinears))
175}
176
/// Sumcheck provers bundled with the claim metadata of the constraint sets they were built from.
///
/// When built by [`constraint_sets_sumcheck_provers_metas`], `provers` and `metas` have one
/// entry per input constraint set, in input order, so `provers[i]` corresponds to `metas[i]`.
pub struct SumcheckProversWithMetas<'a, U, FW, FDomain, Backend>
where
	U: UnderlierType + PackScalar<FW>,
	FW: TowerField,
	FDomain: Field,
	Backend: ComputationBackend,
{
	/// The sumcheck provers.
	pub provers: Vec<OracleSumcheckProver<'a, FDomain, PackedType<U, FW>, Backend>>,
	/// Claim metadata; see the type-level docs for how entries line up with `provers`.
	pub metas: Vec<OracleClaimMeta>,
}
187
188pub fn constraint_sets_sumcheck_provers_metas<'a, U, FW, FDomain, Backend>(
190 constraint_sets: Vec<ConstraintSet<FW>>,
191 witness: &MultilinearExtensionIndex<'a, U, FW>,
192 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
193 switchover_fn: impl Fn(usize) -> usize,
194 backend: &'a Backend,
195) -> Result<SumcheckProversWithMetas<'a, U, FW, FDomain, Backend>, Error>
196where
197 U: UnderlierType + PackScalar<FW> + PackScalar<FDomain>,
198 FW: TowerField + ExtensionField<FDomain>,
199 FDomain: Field,
200 Backend: ComputationBackend,
201{
202 let mut provers = Vec::with_capacity(constraint_sets.len());
203 let mut metas = Vec::with_capacity(constraint_sets.len());
204
205 for constraint_set in constraint_sets {
206 let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
207 let prover = constraint_set_sumcheck_prover(
208 EvaluationOrder::LowToHigh,
209 constraint_set,
210 witness,
211 evaluation_domain_factory.clone(),
212 &switchover_fn,
213 backend,
214 )?;
215 metas.push(meta);
216 provers.push(prover);
217 }
218 Ok(SumcheckProversWithMetas { provers, metas })
219}