binius_core/protocols/sumcheck/
oracles.rs1use std::iter;
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::EvaluationOrder;
8use binius_utils::{bail, sorting::is_sorted_ascending};
9
10use super::{
11 BatchSumcheckOutput, BatchZerocheckOutput, CompositeSumClaim, EqIndSumcheckClaim, Error,
12 SumcheckClaim, ZerocheckClaim,
13};
14use crate::{
15 oracle::{
16 Constraint, ConstraintPredicate, OracleId, SizedConstraintSet, TypeErasedComposition,
17 },
18 protocols::evalcheck::EvalcheckMultilinearClaim,
19};
20
21#[derive(Debug)]
22pub enum ConcreteClaim<P: PackedField> {
23 Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
24 Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
25}
26
27pub struct OracleClaimMeta {
28 pub n_vars: usize,
29 pub oracle_ids: Vec<OracleId>,
30}
31
32#[allow(clippy::type_complexity)]
35pub fn constraint_set_sumcheck_claim<F: TowerField>(
36 constraint_set: SizedConstraintSet<F>,
37) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
38 let (constraints, meta) = split_constraint_set(constraint_set);
39 let n_multilinears = meta.oracle_ids.len();
40
41 let mut sums = Vec::new();
42 for Constraint {
43 composition,
44 predicate,
45 ..
46 } in constraints
47 {
48 match predicate {
49 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
50 composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
51 sum,
52 }),
53 _ => bail!(Error::MixedBatchingNotSupported),
54 }
55 }
56
57 let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
58 Ok((claim, meta))
59}
60
61#[allow(clippy::type_complexity)]
64pub fn constraint_set_zerocheck_claim<F: TowerField>(
65 constraint_set: SizedConstraintSet<F>,
66) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
67 let (constraints, meta) = split_constraint_set(constraint_set);
68 let n_multilinears = meta.oracle_ids.len();
69
70 let mut zeros = Vec::new();
71 for Constraint {
72 composition,
73 predicate,
74 ..
75 } in constraints
76 {
77 match predicate {
78 ConstraintPredicate::Zero => {
79 zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
80 }
81 _ => bail!(Error::MixedBatchingNotSupported),
82 }
83 }
84
85 let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
86 Ok((claim, meta))
87}
88
89#[allow(clippy::type_complexity)]
90pub fn constraint_set_mlecheck_claim<F: TowerField>(
91 constraint_set: SizedConstraintSet<F>,
92) -> Result<(EqIndSumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
93 let (constraints, meta) = split_constraint_set(constraint_set);
94 let n_multilinears = meta.oracle_ids.len();
95
96 let mut sums = Vec::new();
97 for Constraint {
98 composition,
99 predicate,
100 ..
101 } in constraints
102 {
103 match predicate {
104 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
105 composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
106 sum,
107 }),
108 _ => bail!(Error::MixedBatchingNotSupported),
109 }
110 }
111
112 let claim = EqIndSumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
113 Ok((claim, meta))
114}
115
116fn split_constraint_set<F: Field>(
117 constraint_set: SizedConstraintSet<F>,
118) -> (Vec<Constraint<F>>, OracleClaimMeta) {
119 let SizedConstraintSet {
120 oracle_ids,
121 constraints,
122 n_vars,
123 } = constraint_set;
124 let meta = OracleClaimMeta { n_vars, oracle_ids };
125 (constraints, meta)
126}
127
128pub fn make_eval_claims<F: TowerField>(
130 evaluation_order: EvaluationOrder,
131 metas: impl IntoIterator<Item = OracleClaimMeta>,
132 batch_sumcheck_output: BatchSumcheckOutput<F>,
133) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
134 let metas = metas.into_iter().collect::<Vec<_>>();
135
136 if !is_sorted_ascending(metas.iter().map(|meta| meta.n_vars)) {
137 bail!(Error::ClaimsOutOfOrder);
138 }
139
140 let max_n_vars = metas.last().map_or(0, |meta| meta.n_vars);
141
142 if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
143 bail!(Error::ClaimProofMismatch);
144 }
145
146 if max_n_vars != batch_sumcheck_output.challenges.len() {
147 bail!(Error::ClaimProofMismatch);
148 }
149
150 let mut evalcheck_claims = Vec::new();
151 for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
152 if meta.oracle_ids.len() != prover_evals.len() {
153 bail!(Error::ClaimProofMismatch);
154 }
155
156 for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
157 let eval_points_range = match evaluation_order {
158 EvaluationOrder::LowToHigh => 0..meta.n_vars,
159 EvaluationOrder::HighToLow => max_n_vars - meta.n_vars..max_n_vars,
160 };
161 let eval_point = batch_sumcheck_output.challenges[eval_points_range].to_vec();
162
163 let claim = EvalcheckMultilinearClaim {
164 id: oracle_id,
165 eval_point: eval_point.into(),
166 eval,
167 };
168
169 evalcheck_claims.push(claim);
170 }
171 }
172
173 Ok(evalcheck_claims)
174}
175
176pub fn make_zerocheck_eval_claims<F: Field>(
178 metas: impl IntoIterator<Item = OracleClaimMeta>,
179 batch_zerocheck_output: BatchZerocheckOutput<F>,
180) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
181 let BatchZerocheckOutput {
182 skipped_challenges,
183 unskipped_challenges,
184 concat_multilinear_evals,
185 } = batch_zerocheck_output;
186
187 let metas = metas.into_iter().collect::<Vec<_>>();
188
189 if !is_sorted_ascending(metas.iter().map(|meta| meta.n_vars)) {
190 bail!(Error::ClaimsOutOfOrder);
191 }
192
193 let max_n_vars = metas.last().map_or(0, |meta| meta.n_vars);
194 let n_multilinears = metas
195 .iter()
196 .map(|meta| meta.oracle_ids.len())
197 .sum::<usize>();
198
199 if n_multilinears != concat_multilinear_evals.len() {
200 bail!(Error::ClaimProofMismatch);
201 }
202
203 if max_n_vars != skipped_challenges.len() + unskipped_challenges.len() {
204 bail!(Error::ClaimProofMismatch);
205 }
206
207 let ids_with_n_vars = metas.into_iter().flat_map(|meta| {
208 meta.oracle_ids
209 .into_iter()
210 .map(move |oracle_id| (oracle_id, meta.n_vars))
211 });
212
213 let mut evalcheck_claims = Vec::new();
214 for ((oracle_id, n_vars), eval) in iter::zip(ids_with_n_vars, concat_multilinear_evals) {
215 let eval_point = [
218 &skipped_challenges[..n_vars.min(skipped_challenges.len())],
219 &unskipped_challenges[(max_n_vars - n_vars).min(unskipped_challenges.len())..],
220 ]
221 .concat()
222 .into();
223
224 let claim = EvalcheckMultilinearClaim {
225 id: oracle_id,
226 eval_point,
227 eval,
228 };
229
230 evalcheck_claims.push(claim);
231 }
232
233 Ok(evalcheck_claims)
234}
235
236pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
237 pub claims: Vec<SumcheckClaim<F, C>>,
238 pub metas: Vec<OracleClaimMeta>,
239}
240
241pub fn constraint_set_sumcheck_claims<F: TowerField>(
243 constraint_sets: Vec<SizedConstraintSet<F>>,
244) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
245 let mut claims = Vec::with_capacity(constraint_sets.len());
246 let mut metas = Vec::with_capacity(constraint_sets.len());
247
248 for constraint_set in constraint_sets {
249 let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
250 metas.push(meta);
251 claims.push(claim);
252 }
253 Ok(SumcheckClaimsWithMeta { claims, metas })
254}
255
256pub struct MLEcheckClaimsWithMeta<F: TowerField, C> {
257 pub claims: Vec<EqIndSumcheckClaim<F, C>>,
258 pub metas: Vec<OracleClaimMeta>,
259}
260
261pub fn constraint_set_mlecheck_claims<F: TowerField>(
263 constraint_sets: Vec<SizedConstraintSet<F>>,
264) -> Result<MLEcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
265 let mut claims = Vec::with_capacity(constraint_sets.len());
266 let mut metas = Vec::with_capacity(constraint_sets.len());
267
268 for constraint_set in constraint_sets {
269 let (claim, meta) = constraint_set_mlecheck_claim(constraint_set)?;
270 metas.push(meta);
271 claims.push(claim);
272 }
273 Ok(MLEcheckClaimsWithMeta { claims, metas })
274}