binius_core/protocols/sumcheck/oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_math::EvaluationOrder;
7use binius_utils::bail;
8
9use super::{BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, ZerocheckClaim};
10use crate::{
11	oracle::{Constraint, ConstraintPredicate, ConstraintSet, OracleId, TypeErasedComposition},
12	polynomial::ArithCircuitPoly,
13	protocols::evalcheck::EvalcheckMultilinearClaim,
14};
15
16#[derive(Debug)]
17pub enum ConcreteClaim<P: PackedField> {
18	Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
19	Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
20}
21
22pub struct OracleClaimMeta {
23	pub n_vars: usize,
24	pub oracle_ids: Vec<OracleId>,
25}
26
27/// Create a sumcheck claim out of constraint set. Fails when the constraint set contains zerochecks.
28/// Returns claim and metadata used for evalcheck claim construction.
29#[allow(clippy::type_complexity)]
30pub fn constraint_set_sumcheck_claim<F: TowerField>(
31	constraint_set: ConstraintSet<F>,
32) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
33	let (constraints, meta) = split_constraint_set(constraint_set);
34	let n_multilinears = meta.oracle_ids.len();
35
36	let mut sums = Vec::new();
37	for Constraint {
38		composition,
39		predicate,
40		..
41	} in constraints
42	{
43		match predicate {
44			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
45				composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
46				sum,
47			}),
48			_ => bail!(Error::MixedBatchingNotSupported),
49		}
50	}
51
52	let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
53	Ok((claim, meta))
54}
55
56/// Create a zerocheck claim from the constraint set. Fails when the constraint set contains regular sumchecks.
57/// Returns claim and metadata used for evalcheck claim construction.
58#[allow(clippy::type_complexity)]
59pub fn constraint_set_zerocheck_claim<F: TowerField>(
60	constraint_set: ConstraintSet<F>,
61) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
62	let (constraints, meta) = split_constraint_set(constraint_set);
63	let n_multilinears = meta.oracle_ids.len();
64
65	let mut zeros = Vec::new();
66	for Constraint {
67		composition,
68		predicate,
69		..
70	} in constraints
71	{
72		match predicate {
73			ConstraintPredicate::Zero => {
74				zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
75			}
76			_ => bail!(Error::MixedBatchingNotSupported),
77		}
78	}
79
80	let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
81	Ok((claim, meta))
82}
83
84fn split_constraint_set<F: Field>(
85	constraint_set: ConstraintSet<F>,
86) -> (Vec<Constraint<F>>, OracleClaimMeta) {
87	let ConstraintSet {
88		oracle_ids,
89		constraints,
90		n_vars,
91	} = constraint_set;
92	let meta = OracleClaimMeta { n_vars, oracle_ids };
93	(constraints, meta)
94}
95
96/// Constructs evalcheck claims from metadata returned by constraint set claim constructors.
97pub fn make_eval_claims<F: TowerField>(
98	evaluation_order: EvaluationOrder,
99	metas: impl IntoIterator<Item = OracleClaimMeta>,
100	batch_sumcheck_output: BatchSumcheckOutput<F>,
101) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
102	let metas = metas.into_iter().collect::<Vec<_>>();
103	let max_n_vars = metas.first().map_or(0, |meta| meta.n_vars);
104
105	if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
106		bail!(Error::ClaimProofMismatch);
107	}
108
109	if max_n_vars != batch_sumcheck_output.challenges.len() {
110		bail!(Error::ClaimProofMismatch);
111	}
112
113	let mut evalcheck_claims = Vec::new();
114	for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
115		if meta.oracle_ids.len() != prover_evals.len() {
116			bail!(Error::ClaimProofMismatch);
117		}
118
119		for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
120			let eval_points_range = match evaluation_order {
121				EvaluationOrder::LowToHigh => max_n_vars - meta.n_vars..max_n_vars,
122				EvaluationOrder::HighToLow => 0..meta.n_vars,
123			};
124			let eval_point = batch_sumcheck_output.challenges[eval_points_range].to_vec();
125
126			let claim = EvalcheckMultilinearClaim {
127				id: oracle_id,
128				eval_point: eval_point.into(),
129				eval,
130			};
131
132			evalcheck_claims.push(claim);
133		}
134	}
135
136	Ok(evalcheck_claims)
137}
138
139pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
140	pub claims: Vec<SumcheckClaim<F, C>>,
141	pub metas: Vec<OracleClaimMeta>,
142}
143
144/// Constructs sumcheck claims and metas from the vector of [`ConstraintSet`]
145pub fn constraint_set_sumcheck_claims<F: TowerField>(
146	constraint_sets: Vec<ConstraintSet<F>>,
147) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
148	let mut claims = Vec::with_capacity(constraint_sets.len());
149	let mut metas = Vec::with_capacity(constraint_sets.len());
150
151	for constraint_set in constraint_sets {
152		let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
153		metas.push(meta);
154		claims.push(claim);
155	}
156	Ok(SumcheckClaimsWithMeta { claims, metas })
157}