binius_core/protocols/sumcheck/
oracles.rs1use std::iter;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_math::EvaluationOrder;
7use binius_utils::bail;
8
9use super::{BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, ZerocheckClaim};
10use crate::{
11 oracle::{Constraint, ConstraintPredicate, ConstraintSet, OracleId, TypeErasedComposition},
12 polynomial::ArithCircuitPoly,
13 protocols::evalcheck::EvalcheckMultilinearClaim,
14};
15
16#[derive(Debug)]
17pub enum ConcreteClaim<P: PackedField> {
18 Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
19 Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
20}
21
22pub struct OracleClaimMeta {
23 pub n_vars: usize,
24 pub oracle_ids: Vec<OracleId>,
25}
26
27#[allow(clippy::type_complexity)]
30pub fn constraint_set_sumcheck_claim<F: TowerField>(
31 constraint_set: ConstraintSet<F>,
32) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
33 let (constraints, meta) = split_constraint_set(constraint_set);
34 let n_multilinears = meta.oracle_ids.len();
35
36 let mut sums = Vec::new();
37 for Constraint {
38 composition,
39 predicate,
40 ..
41 } in constraints
42 {
43 match predicate {
44 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
45 composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
46 sum,
47 }),
48 _ => bail!(Error::MixedBatchingNotSupported),
49 }
50 }
51
52 let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
53 Ok((claim, meta))
54}
55
56#[allow(clippy::type_complexity)]
59pub fn constraint_set_zerocheck_claim<F: TowerField>(
60 constraint_set: ConstraintSet<F>,
61) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
62 let (constraints, meta) = split_constraint_set(constraint_set);
63 let n_multilinears = meta.oracle_ids.len();
64
65 let mut zeros = Vec::new();
66 for Constraint {
67 composition,
68 predicate,
69 ..
70 } in constraints
71 {
72 match predicate {
73 ConstraintPredicate::Zero => {
74 zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
75 }
76 _ => bail!(Error::MixedBatchingNotSupported),
77 }
78 }
79
80 let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
81 Ok((claim, meta))
82}
83
84fn split_constraint_set<F: Field>(
85 constraint_set: ConstraintSet<F>,
86) -> (Vec<Constraint<F>>, OracleClaimMeta) {
87 let ConstraintSet {
88 oracle_ids,
89 constraints,
90 n_vars,
91 } = constraint_set;
92 let meta = OracleClaimMeta { n_vars, oracle_ids };
93 (constraints, meta)
94}
95
96pub fn make_eval_claims<F: TowerField>(
98 evaluation_order: EvaluationOrder,
99 metas: impl IntoIterator<Item = OracleClaimMeta>,
100 batch_sumcheck_output: BatchSumcheckOutput<F>,
101) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
102 let metas = metas.into_iter().collect::<Vec<_>>();
103 let max_n_vars = metas.first().map_or(0, |meta| meta.n_vars);
104
105 if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
106 bail!(Error::ClaimProofMismatch);
107 }
108
109 if max_n_vars != batch_sumcheck_output.challenges.len() {
110 bail!(Error::ClaimProofMismatch);
111 }
112
113 let mut evalcheck_claims = Vec::new();
114 for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
115 if meta.oracle_ids.len() != prover_evals.len() {
116 bail!(Error::ClaimProofMismatch);
117 }
118
119 for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
120 let eval_points_range = match evaluation_order {
121 EvaluationOrder::LowToHigh => max_n_vars - meta.n_vars..max_n_vars,
122 EvaluationOrder::HighToLow => 0..meta.n_vars,
123 };
124 let eval_point = batch_sumcheck_output.challenges[eval_points_range].to_vec();
125
126 let claim = EvalcheckMultilinearClaim {
127 id: oracle_id,
128 eval_point: eval_point.into(),
129 eval,
130 };
131
132 evalcheck_claims.push(claim);
133 }
134 }
135
136 Ok(evalcheck_claims)
137}
138
139pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
140 pub claims: Vec<SumcheckClaim<F, C>>,
141 pub metas: Vec<OracleClaimMeta>,
142}
143
144pub fn constraint_set_sumcheck_claims<F: TowerField>(
146 constraint_sets: Vec<ConstraintSet<F>>,
147) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
148 let mut claims = Vec::with_capacity(constraint_sets.len());
149 let mut metas = Vec::with_capacity(constraint_sets.len());
150
151 for constraint_set in constraint_sets {
152 let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
153 metas.push(meta);
154 claims.push(claim);
155 }
156 Ok(SumcheckClaimsWithMeta { claims, metas })
157}