// binius_core/protocols/sumcheck/oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_utils::bail;
7
8use super::{BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, ZerocheckClaim};
9use crate::{
10	oracle::{Constraint, ConstraintPredicate, ConstraintSet, OracleId, TypeErasedComposition},
11	polynomial::ArithCircuitPoly,
12	protocols::evalcheck::EvalcheckMultilinearClaim,
13};
14
15#[derive(Debug)]
16pub enum ConcreteClaim<P: PackedField> {
17	Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
18	Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
19}
20
21pub struct OracleClaimMeta {
22	pub n_vars: usize,
23	pub oracle_ids: Vec<OracleId>,
24}
25
26/// Create a sumcheck claim out of constraint set. Fails when the constraint set contains zerochecks.
27/// Returns claim and metadata used for evalcheck claim construction.
28#[allow(clippy::type_complexity)]
29pub fn constraint_set_sumcheck_claim<F: TowerField>(
30	constraint_set: ConstraintSet<F>,
31) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
32	let (constraints, meta) = split_constraint_set(constraint_set);
33	let n_multilinears = meta.oracle_ids.len();
34
35	let mut sums = Vec::new();
36	for Constraint {
37		composition,
38		predicate,
39		..
40	} in constraints
41	{
42		match predicate {
43			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
44				composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
45				sum,
46			}),
47			_ => bail!(Error::MixedBatchingNotSupported),
48		}
49	}
50
51	let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
52	Ok((claim, meta))
53}
54
55/// Create a zerocheck claim from the constraint set. Fails when the constraint set contains regular sumchecks.
56/// Returns claim and metadata used for evalcheck claim construction.
57#[allow(clippy::type_complexity)]
58pub fn constraint_set_zerocheck_claim<F: TowerField>(
59	constraint_set: ConstraintSet<F>,
60) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
61	let (constraints, meta) = split_constraint_set(constraint_set);
62	let n_multilinears = meta.oracle_ids.len();
63
64	let mut zeros = Vec::new();
65	for Constraint {
66		composition,
67		predicate,
68		..
69	} in constraints
70	{
71		match predicate {
72			ConstraintPredicate::Zero => {
73				zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
74			}
75			_ => bail!(Error::MixedBatchingNotSupported),
76		}
77	}
78
79	let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
80	Ok((claim, meta))
81}
82
83fn split_constraint_set<F: Field>(
84	constraint_set: ConstraintSet<F>,
85) -> (Vec<Constraint<F>>, OracleClaimMeta) {
86	let ConstraintSet {
87		oracle_ids,
88		constraints,
89		n_vars,
90	} = constraint_set;
91	let meta = OracleClaimMeta { n_vars, oracle_ids };
92	(constraints, meta)
93}
94
95/// Constructs evalcheck claims from metadata returned by constraint set claim constructors.
96pub fn make_eval_claims<F: TowerField>(
97	metas: impl IntoIterator<Item = OracleClaimMeta>,
98	batch_sumcheck_output: BatchSumcheckOutput<F>,
99) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
100	let metas = metas.into_iter().collect::<Vec<_>>();
101	let max_n_vars = metas.first().map_or(0, |meta| meta.n_vars);
102
103	if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
104		bail!(Error::ClaimProofMismatch);
105	}
106
107	if max_n_vars != batch_sumcheck_output.challenges.len() {
108		bail!(Error::ClaimProofMismatch);
109	}
110
111	let mut evalcheck_claims = Vec::new();
112	for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
113		if meta.oracle_ids.len() != prover_evals.len() {
114			bail!(Error::ClaimProofMismatch);
115		}
116
117		for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
118			let eval_point = batch_sumcheck_output.challenges[max_n_vars - meta.n_vars..].to_vec();
119
120			let claim = EvalcheckMultilinearClaim {
121				id: oracle_id,
122				eval_point: eval_point.into(),
123				eval,
124			};
125
126			evalcheck_claims.push(claim);
127		}
128	}
129
130	Ok(evalcheck_claims)
131}
132
133pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
134	pub claims: Vec<SumcheckClaim<F, C>>,
135	pub metas: Vec<OracleClaimMeta>,
136}
137
138/// Constructs sumcheck claims and metas from the vector of [`ConstraintSet`]
139pub fn constraint_set_sumcheck_claims<F: TowerField>(
140	constraint_sets: Vec<ConstraintSet<F>>,
141) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
142	let mut claims = Vec::with_capacity(constraint_sets.len());
143	let mut metas = Vec::with_capacity(constraint_sets.len());
144
145	for constraint_set in constraint_sets {
146		let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
147		metas.push(meta);
148		claims.push(claim);
149	}
150	Ok(SumcheckClaimsWithMeta { claims, metas })
151}