1use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::bail;
7
8use super::{RegularSumcheckProver, ZerocheckProverImpl};
9use crate::{
10 oracle::{Constraint, ConstraintPredicate, ConstraintSet},
11 polynomial::ArithCircuitPoly,
12 protocols::sumcheck::{
13 constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
14 },
15 witness::{MultilinearExtensionIndex, MultilinearWitness},
16};
17
18pub type OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend> = ZerocheckProverImpl<
19 'a,
20 FDomain,
21 FBase,
22 P,
23 ArithCircuitPoly<FBase>,
24 ArithCircuitPoly<<P as PackedField>::Scalar>,
25 MultilinearWitness<'a, P>,
26 DomainFactory,
27 Backend,
28>;
29
30pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
31 'a,
32 FDomain,
33 P,
34 ArithCircuitPoly<<P as PackedField>::Scalar>,
35 MultilinearWitness<'a, P>,
36 Backend,
37>;
38
39pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, DomainFactory, Backend>(
41 constraints: Vec<Constraint<P::Scalar>>,
42 multilinears: Vec<MultilinearWitness<'a, P>>,
43 domain_factory: DomainFactory,
44 zerocheck_challenges: &[F],
45 backend: &'a Backend,
46) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend>, Error>
47where
48 P: PackedField<Scalar = F>
49 + PackedExtension<F, PackedSubfield = P>
50 + PackedExtension<FDomain>
51 + PackedExtension<FBase>,
52 F: TowerField,
53 FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
54 FDomain: Field,
55 DomainFactory: EvaluationDomainFactory<FDomain>,
56 Backend: ComputationBackend,
57{
58 let mut zeros = Vec::with_capacity(constraints.len());
59
60 for Constraint {
61 composition,
62 predicate,
63 name,
64 } in constraints
65 {
66 let composition_base = composition
67 .try_convert_field::<FBase>()
68 .map_err(|_| Error::CircuitFieldDowncastFailed)?;
69 match predicate {
70 ConstraintPredicate::Zero => {
71 zeros.push((
72 name,
73 ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
74 ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
75 ));
76 }
77 _ => bail!(Error::MixedBatchingNotSupported),
78 }
79 }
80
81 let prover = OracleZerocheckProver::<_, _, FDomain, _, _>::new(
82 multilinears,
83 zeros,
84 zerocheck_challenges,
85 domain_factory,
86 backend,
87 )?;
88
89 Ok(prover)
90}
91
92pub fn constraint_set_sumcheck_prover<'a, FW, PW, FDomain, Backend>(
94 evaluation_order: EvaluationOrder,
95 constraint_set: ConstraintSet<FW>,
96 witness: &MultilinearExtensionIndex<'a, PW>,
97 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
98 switchover_fn: impl Fn(usize) -> usize + Clone,
99 backend: &'a Backend,
100) -> Result<OracleSumcheckProver<'a, FDomain, PW, Backend>, Error>
101where
102 PW: PackedField<Scalar = FW> + PackedExtension<FDomain>,
103 FW: TowerField + ExtensionField<FDomain>,
104 FDomain: Field,
105 Backend: ComputationBackend,
106{
107 let (constraints, multilinears) = split_constraint_set::<FW, PW>(constraint_set, witness)?;
108
109 let mut sums = Vec::new();
110
111 for Constraint {
112 composition,
113 predicate,
114 ..
115 } in constraints
116 {
117 match predicate {
118 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
119 composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
120 sum,
121 }),
122 _ => bail!(Error::MixedBatchingNotSupported),
123 }
124 }
125
126 let prover = RegularSumcheckProver::new(
127 evaluation_order,
128 multilinears,
129 sums,
130 evaluation_domain_factory,
131 switchover_fn,
132 backend,
133 )?;
134
135 Ok(prover)
136}
137
138type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
139
140#[allow(clippy::type_complexity)]
141pub fn split_constraint_set<'a, F, PW>(
142 constraint_set: ConstraintSet<F>,
143 witness: &MultilinearExtensionIndex<'a, PW>,
144) -> Result<ConstraintsAndMultilinears<'a, F, PW>, Error>
145where
146 F: Field,
147 PW: PackedField,
148 PW::Scalar: ExtensionField<F>,
149{
150 let ConstraintSet {
151 oracle_ids,
152 constraints,
153 n_vars,
154 } = constraint_set;
155
156 let multilinears = oracle_ids
157 .iter()
158 .map(|&oracle_id| witness.get_multilin_poly(oracle_id))
159 .collect::<Result<Vec<_>, _>>()?;
160
161 if multilinears
162 .iter()
163 .any(|multilin| multilin.n_vars() != n_vars)
164 {
165 bail!(Error::ConstraintSetNumberOfVariablesMismatch);
166 }
167
168 Ok((constraints, multilinears))
169}
170
171pub struct SumcheckProversWithMetas<'a, PW, FDomain, Backend>
172where
173 PW: PackedField,
174 FDomain: Field,
175 Backend: ComputationBackend,
176{
177 pub provers: Vec<OracleSumcheckProver<'a, FDomain, PW, Backend>>,
178 pub metas: Vec<OracleClaimMeta>,
179}
180
181pub fn constraint_sets_sumcheck_provers_metas<'a, PW, FDomain, Backend>(
183 evaluation_order: EvaluationOrder,
184 constraint_sets: Vec<ConstraintSet<PW::Scalar>>,
185 witness: &MultilinearExtensionIndex<'a, PW>,
186 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
187 switchover_fn: impl Fn(usize) -> usize,
188 backend: &'a Backend,
189) -> Result<SumcheckProversWithMetas<'a, PW, FDomain, Backend>, Error>
190where
191 PW: PackedExtension<FDomain>,
192 PW::Scalar: TowerField + ExtensionField<FDomain>,
193 FDomain: Field,
194 Backend: ComputationBackend,
195{
196 let mut provers = Vec::with_capacity(constraint_sets.len());
197 let mut metas = Vec::with_capacity(constraint_sets.len());
198
199 for constraint_set in constraint_sets {
200 let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
201 let prover = constraint_set_sumcheck_prover(
202 evaluation_order,
203 constraint_set,
204 witness,
205 evaluation_domain_factory.clone(),
206 &switchover_fn,
207 backend,
208 )?;
209 metas.push(meta);
210 provers.push(prover);
211 }
212 Ok(SumcheckProversWithMetas { provers, metas })
213}