binius_core/protocols/sumcheck/
oracles.rs1use std::iter;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_utils::bail;
7
8use super::{BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, ZerocheckClaim};
9use crate::{
10 oracle::{Constraint, ConstraintPredicate, ConstraintSet, OracleId, TypeErasedComposition},
11 polynomial::ArithCircuitPoly,
12 protocols::evalcheck::EvalcheckMultilinearClaim,
13};
14
15#[derive(Debug)]
16pub enum ConcreteClaim<P: PackedField> {
17 Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
18 Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
19}
20
21pub struct OracleClaimMeta {
22 pub n_vars: usize,
23 pub oracle_ids: Vec<OracleId>,
24}
25
26#[allow(clippy::type_complexity)]
29pub fn constraint_set_sumcheck_claim<F: TowerField>(
30 constraint_set: ConstraintSet<F>,
31) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
32 let (constraints, meta) = split_constraint_set(constraint_set);
33 let n_multilinears = meta.oracle_ids.len();
34
35 let mut sums = Vec::new();
36 for Constraint {
37 composition,
38 predicate,
39 ..
40 } in constraints
41 {
42 match predicate {
43 ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
44 composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
45 sum,
46 }),
47 _ => bail!(Error::MixedBatchingNotSupported),
48 }
49 }
50
51 let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
52 Ok((claim, meta))
53}
54
55#[allow(clippy::type_complexity)]
58pub fn constraint_set_zerocheck_claim<F: TowerField>(
59 constraint_set: ConstraintSet<F>,
60) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
61 let (constraints, meta) = split_constraint_set(constraint_set);
62 let n_multilinears = meta.oracle_ids.len();
63
64 let mut zeros = Vec::new();
65 for Constraint {
66 composition,
67 predicate,
68 ..
69 } in constraints
70 {
71 match predicate {
72 ConstraintPredicate::Zero => {
73 zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
74 }
75 _ => bail!(Error::MixedBatchingNotSupported),
76 }
77 }
78
79 let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
80 Ok((claim, meta))
81}
82
83fn split_constraint_set<F: Field>(
84 constraint_set: ConstraintSet<F>,
85) -> (Vec<Constraint<F>>, OracleClaimMeta) {
86 let ConstraintSet {
87 oracle_ids,
88 constraints,
89 n_vars,
90 } = constraint_set;
91 let meta = OracleClaimMeta { n_vars, oracle_ids };
92 (constraints, meta)
93}
94
95pub fn make_eval_claims<F: TowerField>(
97 metas: impl IntoIterator<Item = OracleClaimMeta>,
98 batch_sumcheck_output: BatchSumcheckOutput<F>,
99) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
100 let metas = metas.into_iter().collect::<Vec<_>>();
101 let max_n_vars = metas.first().map_or(0, |meta| meta.n_vars);
102
103 if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
104 bail!(Error::ClaimProofMismatch);
105 }
106
107 if max_n_vars != batch_sumcheck_output.challenges.len() {
108 bail!(Error::ClaimProofMismatch);
109 }
110
111 let mut evalcheck_claims = Vec::new();
112 for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
113 if meta.oracle_ids.len() != prover_evals.len() {
114 bail!(Error::ClaimProofMismatch);
115 }
116
117 for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
118 let eval_point = batch_sumcheck_output.challenges[max_n_vars - meta.n_vars..].to_vec();
119
120 let claim = EvalcheckMultilinearClaim {
121 id: oracle_id,
122 eval_point: eval_point.into(),
123 eval,
124 };
125
126 evalcheck_claims.push(claim);
127 }
128 }
129
130 Ok(evalcheck_claims)
131}
132
133pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
134 pub claims: Vec<SumcheckClaim<F, C>>,
135 pub metas: Vec<OracleClaimMeta>,
136}
137
138pub fn constraint_set_sumcheck_claims<F: TowerField>(
140 constraint_sets: Vec<ConstraintSet<F>>,
141) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
142 let mut claims = Vec::with_capacity(constraint_sets.len());
143 let mut metas = Vec::with_capacity(constraint_sets.len());
144
145 for constraint_set in constraint_sets {
146 let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
147 metas.push(meta);
148 claims.push(claim);
149 }
150 Ok(SumcheckClaimsWithMeta { claims, metas })
151}