binius_core/protocols/sumcheck/prove/
oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackend;
6use binius_math::{EvaluationDomainFactory, EvaluationOrder, MultilinearPoly};
7use binius_utils::bail;
8
9use super::{
10	RegularSumcheckProver, ZerocheckProverImpl,
11	eq_ind::{EqIndSumcheckProver, EqIndSumcheckProverBuilder},
12};
13use crate::{
14	oracle::{Constraint, ConstraintPredicate, SizedConstraintSet},
15	protocols::{
16		evalcheck::{EvalPoint, subclaims::MemoizedData},
17		sumcheck::{
18			CompositeSumClaim, Error, OracleClaimMeta, constraint_set_mlecheck_claim,
19			constraint_set_sumcheck_claim,
20		},
21	},
22	witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
23};
24
25pub type OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend> = ZerocheckProverImpl<
26	'a,
27	FDomain,
28	FBase,
29	P,
30	ArithCircuitPoly<FBase>,
31	ArithCircuitPoly<<P as PackedField>::Scalar>,
32	MultilinearWitness<'a, P>,
33	DomainFactory,
34	Backend,
35>;
36
37pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
38	'a,
39	FDomain,
40	P,
41	ArithCircuitPoly<<P as PackedField>::Scalar>,
42	MultilinearWitness<'a, P>,
43	Backend,
44>;
45
46pub type OracleMLECheckProver<'a, FDomain, P, Backend> = EqIndSumcheckProver<
47	'a,
48	FDomain,
49	P,
50	ArithCircuitPoly<<P as PackedField>::Scalar>,
51	MultilinearWitness<'a, P>,
52	Backend,
53>;
54
55/// Construct zerocheck prover from the constraint set. Fails when constraint set contains regular
56/// sumchecks.
57pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, DomainFactory, Backend>(
58	constraints: Vec<Constraint<P::Scalar>>,
59	multilinears: Vec<MultilinearWitness<'a, P>>,
60	domain_factory: DomainFactory,
61	zerocheck_challenges: &[F],
62	backend: &'a Backend,
63) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend>, Error>
64where
65	P: PackedField<Scalar = F>
66		+ PackedExtension<F, PackedSubfield = P>
67		+ PackedExtension<FDomain>
68		+ PackedExtension<FBase>,
69	F: TowerField,
70	FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
71	FDomain: Field,
72	DomainFactory: EvaluationDomainFactory<FDomain>,
73	Backend: ComputationBackend,
74{
75	let mut zeros = Vec::with_capacity(constraints.len());
76
77	for Constraint {
78		composition,
79		predicate,
80		name,
81	} in constraints
82	{
83		let composition_base = composition
84			.try_convert_field::<FBase>()
85			.map_err(|_| Error::CircuitFieldDowncastFailed)?;
86		match predicate {
87			ConstraintPredicate::Zero => {
88				zeros.push((
89					name,
90					ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
91					ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
92				));
93			}
94			_ => bail!(Error::MixedBatchingNotSupported),
95		}
96	}
97
98	let prover = OracleZerocheckProver::<_, _, FDomain, _, _>::new(
99		multilinears,
100		zeros,
101		zerocheck_challenges,
102		domain_factory,
103		backend,
104	)?;
105
106	Ok(prover)
107}
108
109/// Construct regular sumcheck prover from the constraint set. Fails when constraint set contains
110/// zerochecks.
111pub fn constraint_set_sumcheck_prover<'a, F, P, FDomain, Backend>(
112	evaluation_order: EvaluationOrder,
113	constraint_set: SizedConstraintSet<F>,
114	witness: &MultilinearExtensionIndex<'a, P>,
115	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
116	switchover_fn: impl Fn(usize) -> usize + Clone,
117	backend: &'a Backend,
118) -> Result<OracleSumcheckProver<'a, FDomain, P, Backend>, Error>
119where
120	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
121	F: TowerField + ExtensionField<FDomain>,
122	FDomain: Field,
123	Backend: ComputationBackend,
124{
125	let (constraints, multilinears) = split_constraint_set::<F, P>(constraint_set, witness)?;
126
127	let mut sums = Vec::new();
128
129	for Constraint {
130		composition,
131		predicate,
132		..
133	} in constraints
134	{
135		match predicate {
136			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
137				composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
138				sum,
139			}),
140			_ => bail!(Error::MixedBatchingNotSupported),
141		}
142	}
143
144	let prover = RegularSumcheckProver::new(
145		evaluation_order,
146		multilinears,
147		sums,
148		evaluation_domain_factory,
149		switchover_fn,
150		backend,
151	)?;
152
153	Ok(prover)
154}
155
156/// Construct mlecheck prover from the constraint set. Fails when constraint set contains
157/// zerochecks.
158#[allow(clippy::too_many_arguments)]
159pub fn constraint_set_mlecheck_prover<'a, 'b, F, P, FDomain, Backend>(
160	evaluation_order: EvaluationOrder,
161	constraint_set: SizedConstraintSet<F>,
162	eq_ind_challenges: &[F],
163	memoized_data: &mut MemoizedData<'b, P>,
164	witness: &MultilinearExtensionIndex<'a, P>,
165	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
166	switchover_fn: impl Fn(usize) -> usize + Clone,
167	backend: &'a Backend,
168) -> Result<OracleMLECheckProver<'a, FDomain, P, Backend>, Error>
169where
170	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
171	F: TowerField + ExtensionField<FDomain>,
172	FDomain: Field,
173	Backend: ComputationBackend,
174{
175	let SizedConstraintSet {
176		oracle_ids,
177		constraints,
178		n_vars,
179	} = constraint_set;
180
181	let mut multilinears = Vec::with_capacity(oracle_ids.len());
182	let mut const_suffixes = Vec::with_capacity(oracle_ids.len());
183
184	for id in oracle_ids {
185		let IndexEntry {
186			multilin_poly,
187			nonzero_scalars_prefix,
188		} = witness.get_index_entry(id)?;
189
190		if multilin_poly.n_vars() != n_vars {
191			bail!(Error::ConstraintSetNumberOfVariablesMismatch);
192		}
193
194		multilinears.push(multilin_poly);
195		const_suffixes.push((F::ZERO, ((1 << n_vars) - nonzero_scalars_prefix)))
196	}
197
198	let mut sums = Vec::new();
199
200	for Constraint {
201		composition,
202		predicate,
203		..
204	} in constraints
205	{
206		match predicate {
207			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
208				composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
209				sum,
210			}),
211			_ => bail!(Error::MixedBatchingNotSupported),
212		}
213	}
214
215	let n_vars = eq_ind_challenges.len();
216
217	let eq_ind_partial = match evaluation_order {
218		EvaluationOrder::LowToHigh => &eq_ind_challenges[n_vars.min(1)..],
219		EvaluationOrder::HighToLow => &eq_ind_challenges[..n_vars.saturating_sub(1)],
220	};
221
222	let eq_ind_partial_evals = memoized_data
223		.full_query(eq_ind_partial)?
224		.expansion()
225		.to_vec();
226
227	let prover = EqIndSumcheckProverBuilder::with_switchover(multilinears, switchover_fn, backend)?
228		.with_eq_ind_partial_evals(Backend::to_hal_slice(eq_ind_partial_evals))
229		.with_const_suffixes(&const_suffixes)?
230		.build(evaluation_order, eq_ind_challenges, sums, evaluation_domain_factory)?;
231
232	Ok(prover)
233}
234
235type ConstraintsAndMultilinears<'a, F, P> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, P>>);
236
237#[allow(clippy::type_complexity)]
238pub fn split_constraint_set<'a, F, P>(
239	constraint_set: SizedConstraintSet<F>,
240	witness: &MultilinearExtensionIndex<'a, P>,
241) -> Result<ConstraintsAndMultilinears<'a, F, P>, Error>
242where
243	F: Field,
244	P: PackedField,
245	P::Scalar: ExtensionField<F>,
246{
247	let SizedConstraintSet {
248		oracle_ids,
249		constraints,
250		n_vars,
251	} = constraint_set;
252
253	let multilinears = oracle_ids
254		.iter()
255		.map(|&oracle_id| witness.get_multilin_poly(oracle_id))
256		.collect::<Result<Vec<_>, _>>()?;
257
258	if multilinears
259		.iter()
260		.any(|multilin| multilin.n_vars() != n_vars)
261	{
262		bail!(Error::ConstraintSetNumberOfVariablesMismatch);
263	}
264
265	Ok((constraints, multilinears))
266}
267
268pub struct SumcheckProversWithMetas<'a, P, FDomain, Backend>
269where
270	P: PackedField,
271	FDomain: Field,
272	Backend: ComputationBackend,
273{
274	pub provers: Vec<OracleSumcheckProver<'a, FDomain, P, Backend>>,
275	pub metas: Vec<OracleClaimMeta>,
276}
277
278/// Constructs sumcheck provers and metas from the vector of [`SizedConstraintSet`]
279pub fn constraint_sets_sumcheck_provers_metas<'a, P, FDomain, Backend>(
280	evaluation_order: EvaluationOrder,
281	constraint_sets: Vec<SizedConstraintSet<P::Scalar>>,
282	witness: &MultilinearExtensionIndex<'a, P>,
283	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
284	switchover_fn: impl Fn(usize) -> usize,
285	backend: &'a Backend,
286) -> Result<SumcheckProversWithMetas<'a, P, FDomain, Backend>, Error>
287where
288	P: PackedExtension<FDomain>,
289	P::Scalar: TowerField + ExtensionField<FDomain>,
290	FDomain: Field,
291	Backend: ComputationBackend,
292{
293	let mut provers = Vec::with_capacity(constraint_sets.len());
294	let mut metas = Vec::with_capacity(constraint_sets.len());
295
296	for constraint_set in constraint_sets {
297		let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
298		let prover = constraint_set_sumcheck_prover(
299			evaluation_order,
300			constraint_set,
301			witness,
302			evaluation_domain_factory.clone(),
303			&switchover_fn,
304			backend,
305		)?;
306		metas.push(meta);
307		provers.push(prover);
308	}
309	Ok(SumcheckProversWithMetas { provers, metas })
310}
311
312pub struct MLECheckProverWithMeta<'a, P, FDomain, Backend>
313where
314	P: PackedField,
315	FDomain: Field,
316	Backend: ComputationBackend,
317{
318	pub prover: OracleMLECheckProver<'a, FDomain, P, Backend>,
319	pub meta: OracleClaimMeta,
320}
321
322/// Constructs sumcheck provers and metas from the vector of [`SizedConstraintSet`]
323#[allow(clippy::too_many_arguments)]
324pub fn constraint_sets_mlecheck_prover_meta<'a, 'b, P, FDomain, Backend>(
325	evaluation_order: EvaluationOrder,
326	constraint_set: SizedConstraintSet<P::Scalar>,
327	eq_ind_challenges: EvalPoint<P::Scalar>,
328	memoized_data: &mut MemoizedData<'b, P>,
329	witness: &MultilinearExtensionIndex<'a, P>,
330	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
331	switchover_fn: impl Fn(usize) -> usize,
332	backend: &'a Backend,
333) -> Result<MLECheckProverWithMeta<'a, P, FDomain, Backend>, Error>
334where
335	P: PackedExtension<FDomain>,
336	P::Scalar: TowerField + ExtensionField<FDomain>,
337	FDomain: Field,
338	Backend: ComputationBackend,
339{
340	let (_, meta) = constraint_set_mlecheck_claim(constraint_set.clone())?;
341	let prover = constraint_set_mlecheck_prover(
342		evaluation_order,
343		constraint_set,
344		&eq_ind_challenges,
345		memoized_data,
346		witness,
347		evaluation_domain_factory,
348		&switchover_fn,
349		backend,
350	)?;
351
352	Ok(MLECheckProverWithMeta { prover, meta })
353}