binius_core/protocols/sumcheck/
oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::EvaluationOrder;
8use binius_utils::{bail, sorting::is_sorted_ascending};
9
10use super::{
11	BatchSumcheckOutput, BatchZerocheckOutput, CompositeSumClaim, EqIndSumcheckClaim, Error,
12	SumcheckClaim, ZerocheckClaim,
13};
14use crate::{
15	oracle::{
16		Constraint, ConstraintPredicate, OracleId, SizedConstraintSet, TypeErasedComposition,
17	},
18	protocols::evalcheck::EvalcheckMultilinearClaim,
19};
20
21#[derive(Debug)]
22pub enum ConcreteClaim<P: PackedField> {
23	Sumcheck(SumcheckClaim<P::Scalar, TypeErasedComposition<P>>),
24	Zerocheck(ZerocheckClaim<P::Scalar, TypeErasedComposition<P>>),
25}
26
27pub struct OracleClaimMeta {
28	pub n_vars: usize,
29	pub oracle_ids: Vec<OracleId>,
30}
31
32/// Create a sumcheck claim out of constraint set. Fails when the constraint set contains
33/// zerochecks. Returns claim and metadata used for evalcheck claim construction.
34#[allow(clippy::type_complexity)]
35pub fn constraint_set_sumcheck_claim<F: TowerField>(
36	constraint_set: SizedConstraintSet<F>,
37) -> Result<(SumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
38	let (constraints, meta) = split_constraint_set(constraint_set);
39	let n_multilinears = meta.oracle_ids.len();
40
41	let mut sums = Vec::new();
42	for Constraint {
43		composition,
44		predicate,
45		..
46	} in constraints
47	{
48		match predicate {
49			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
50				composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
51				sum,
52			}),
53			_ => bail!(Error::MixedBatchingNotSupported),
54		}
55	}
56
57	let claim = SumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
58	Ok((claim, meta))
59}
60
61/// Create a zerocheck claim from the constraint set. Fails when the constraint set contains regular
62/// sumchecks. Returns claim and metadata used for evalcheck claim construction.
63#[allow(clippy::type_complexity)]
64pub fn constraint_set_zerocheck_claim<F: TowerField>(
65	constraint_set: SizedConstraintSet<F>,
66) -> Result<(ZerocheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
67	let (constraints, meta) = split_constraint_set(constraint_set);
68	let n_multilinears = meta.oracle_ids.len();
69
70	let mut zeros = Vec::new();
71	for Constraint {
72		composition,
73		predicate,
74		..
75	} in constraints
76	{
77		match predicate {
78			ConstraintPredicate::Zero => {
79				zeros.push(ArithCircuitPoly::with_n_vars(n_multilinears, composition)?)
80			}
81			_ => bail!(Error::MixedBatchingNotSupported),
82		}
83	}
84
85	let claim = ZerocheckClaim::new(meta.n_vars, n_multilinears, zeros)?;
86	Ok((claim, meta))
87}
88
89#[allow(clippy::type_complexity)]
90pub fn constraint_set_mlecheck_claim<F: TowerField>(
91	constraint_set: SizedConstraintSet<F>,
92) -> Result<(EqIndSumcheckClaim<F, ArithCircuitPoly<F>>, OracleClaimMeta), Error> {
93	let (constraints, meta) = split_constraint_set(constraint_set);
94	let n_multilinears = meta.oracle_ids.len();
95
96	let mut sums = Vec::new();
97	for Constraint {
98		composition,
99		predicate,
100		..
101	} in constraints
102	{
103		match predicate {
104			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
105				composition: ArithCircuitPoly::with_n_vars(n_multilinears, composition)?,
106				sum,
107			}),
108			_ => bail!(Error::MixedBatchingNotSupported),
109		}
110	}
111
112	let claim = EqIndSumcheckClaim::new(meta.n_vars, n_multilinears, sums)?;
113	Ok((claim, meta))
114}
115
116fn split_constraint_set<F: Field>(
117	constraint_set: SizedConstraintSet<F>,
118) -> (Vec<Constraint<F>>, OracleClaimMeta) {
119	let SizedConstraintSet {
120		oracle_ids,
121		constraints,
122		n_vars,
123	} = constraint_set;
124	let meta = OracleClaimMeta { n_vars, oracle_ids };
125	(constraints, meta)
126}
127
128/// Constructs evalcheck claims from metadata returned by constraint set claim constructors.
129pub fn make_eval_claims<F: TowerField>(
130	evaluation_order: EvaluationOrder,
131	metas: impl IntoIterator<Item = OracleClaimMeta>,
132	batch_sumcheck_output: BatchSumcheckOutput<F>,
133) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
134	let metas = metas.into_iter().collect::<Vec<_>>();
135
136	if !is_sorted_ascending(metas.iter().map(|meta| meta.n_vars)) {
137		bail!(Error::ClaimsOutOfOrder);
138	}
139
140	let max_n_vars = metas.last().map_or(0, |meta| meta.n_vars);
141
142	if metas.len() != batch_sumcheck_output.multilinear_evals.len() {
143		bail!(Error::ClaimProofMismatch);
144	}
145
146	if max_n_vars != batch_sumcheck_output.challenges.len() {
147		bail!(Error::ClaimProofMismatch);
148	}
149
150	let mut evalcheck_claims = Vec::new();
151	for (meta, prover_evals) in iter::zip(metas, batch_sumcheck_output.multilinear_evals) {
152		if meta.oracle_ids.len() != prover_evals.len() {
153			bail!(Error::ClaimProofMismatch);
154		}
155
156		for (oracle_id, eval) in iter::zip(meta.oracle_ids, prover_evals) {
157			let eval_points_range = match evaluation_order {
158				EvaluationOrder::LowToHigh => 0..meta.n_vars,
159				EvaluationOrder::HighToLow => max_n_vars - meta.n_vars..max_n_vars,
160			};
161			let eval_point = batch_sumcheck_output.challenges[eval_points_range].to_vec();
162
163			let claim = EvalcheckMultilinearClaim {
164				id: oracle_id,
165				eval_point: eval_point.into(),
166				eval,
167			};
168
169			evalcheck_claims.push(claim);
170		}
171	}
172
173	Ok(evalcheck_claims)
174}
175
176/// Construct eval claims from the batched zerocheck output.
177pub fn make_zerocheck_eval_claims<F: Field>(
178	metas: impl IntoIterator<Item = OracleClaimMeta>,
179	batch_zerocheck_output: BatchZerocheckOutput<F>,
180) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
181	let BatchZerocheckOutput {
182		skipped_challenges,
183		unskipped_challenges,
184		concat_multilinear_evals,
185	} = batch_zerocheck_output;
186
187	let metas = metas.into_iter().collect::<Vec<_>>();
188
189	if !is_sorted_ascending(metas.iter().map(|meta| meta.n_vars)) {
190		bail!(Error::ClaimsOutOfOrder);
191	}
192
193	let max_n_vars = metas.last().map_or(0, |meta| meta.n_vars);
194	let n_multilinears = metas
195		.iter()
196		.map(|meta| meta.oracle_ids.len())
197		.sum::<usize>();
198
199	if n_multilinears != concat_multilinear_evals.len() {
200		bail!(Error::ClaimProofMismatch);
201	}
202
203	if max_n_vars != skipped_challenges.len() + unskipped_challenges.len() {
204		bail!(Error::ClaimProofMismatch);
205	}
206
207	let ids_with_n_vars = metas.into_iter().flat_map(|meta| {
208		meta.oracle_ids
209			.into_iter()
210			.map(move |oracle_id| (oracle_id, meta.n_vars))
211	});
212
213	let mut evalcheck_claims = Vec::new();
214	for ((oracle_id, n_vars), eval) in iter::zip(ids_with_n_vars, concat_multilinear_evals) {
215		// NB. Two stages of zerocheck reduction (univariate skip and front-loaded high-to-low
216		// sumchecks)     may result in a "gap" between challenges prefix and suffix.
217		let eval_point = [
218			&skipped_challenges[..n_vars.min(skipped_challenges.len())],
219			&unskipped_challenges[(max_n_vars - n_vars).min(unskipped_challenges.len())..],
220		]
221		.concat()
222		.into();
223
224		let claim = EvalcheckMultilinearClaim {
225			id: oracle_id,
226			eval_point,
227			eval,
228		};
229
230		evalcheck_claims.push(claim);
231	}
232
233	Ok(evalcheck_claims)
234}
235
236pub struct SumcheckClaimsWithMeta<F: TowerField, C> {
237	pub claims: Vec<SumcheckClaim<F, C>>,
238	pub metas: Vec<OracleClaimMeta>,
239}
240
241/// Constructs sumcheck claims and metas from the vector of [`SizedConstraintSet`]
242pub fn constraint_set_sumcheck_claims<F: TowerField>(
243	constraint_sets: Vec<SizedConstraintSet<F>>,
244) -> Result<SumcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
245	let mut claims = Vec::with_capacity(constraint_sets.len());
246	let mut metas = Vec::with_capacity(constraint_sets.len());
247
248	for constraint_set in constraint_sets {
249		let (claim, meta) = constraint_set_sumcheck_claim(constraint_set)?;
250		metas.push(meta);
251		claims.push(claim);
252	}
253	Ok(SumcheckClaimsWithMeta { claims, metas })
254}
255
256pub struct MLEcheckClaimsWithMeta<F: TowerField, C> {
257	pub claims: Vec<EqIndSumcheckClaim<F, C>>,
258	pub metas: Vec<OracleClaimMeta>,
259}
260
261/// Constructs sumcheck claims and metas from the vector of [`SizedConstraintSet`]
262pub fn constraint_set_mlecheck_claims<F: TowerField>(
263	constraint_sets: Vec<SizedConstraintSet<F>>,
264) -> Result<MLEcheckClaimsWithMeta<F, ArithCircuitPoly<F>>, Error> {
265	let mut claims = Vec::with_capacity(constraint_sets.len());
266	let mut metas = Vec::with_capacity(constraint_sets.len());
267
268	for constraint_set in constraint_sets {
269		let (claim, meta) = constraint_set_mlecheck_claim(constraint_set)?;
270		metas.push(meta);
271		claims.push(claim);
272	}
273	Ok(MLEcheckClaimsWithMeta { claims, metas })
274}