binius_core/protocols/sumcheck/prove/oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	as_packed_field::{PackScalar, PackedType},
5	underlier::UnderlierType,
6	ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{EvaluationDomainFactory, EvaluationOrder};
10use binius_utils::bail;
11
12use super::{RegularSumcheckProver, UnivariateZerocheck};
13use crate::{
14	oracle::{Constraint, ConstraintPredicate, ConstraintSet},
15	polynomial::ArithCircuitPoly,
16	protocols::sumcheck::{
17		constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
18	},
19	witness::{MultilinearExtensionIndex, MultilinearWitness},
20};
21
22pub type OracleZerocheckProver<'a, P, FBase, FDomain, Backend> = UnivariateZerocheck<
23	'a,
24	'a,
25	FDomain,
26	FBase,
27	P,
28	ArithCircuitPoly<FBase>,
29	ArithCircuitPoly<<P as PackedField>::Scalar>,
30	MultilinearWitness<'a, P>,
31	Backend,
32>;
33
34pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
35	'a,
36	FDomain,
37	P,
38	ArithCircuitPoly<<P as PackedField>::Scalar>,
39	MultilinearWitness<'a, P>,
40	Backend,
41>;
42
43/// Construct zerocheck prover from the constraint set. Fails when constraint set contains regular sumchecks.
44pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, Backend>(
45	constraints: Vec<Constraint<P::Scalar>>,
46	multilinears: Vec<MultilinearWitness<'a, P>>,
47	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
48	switchover_fn: impl Fn(usize) -> usize + Clone,
49	zerocheck_challenges: &[P::Scalar],
50	backend: &'a Backend,
51) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, Backend>, Error>
52where
53	P: PackedFieldIndexable<Scalar = F>
54		+ PackedExtension<F, PackedSubfield = P>
55		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
56		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
57	F: TowerField,
58	FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
59	FDomain: Field,
60	Backend: ComputationBackend,
61{
62	let mut zeros = Vec::with_capacity(constraints.len());
63
64	for Constraint {
65		composition,
66		predicate,
67		name,
68	} in constraints
69	{
70		let composition_base = composition
71			.clone()
72			.try_convert_field::<FBase>()
73			.map_err(|_| Error::CircuitFieldDowncastFailed)?;
74		match predicate {
75			ConstraintPredicate::Zero => {
76				zeros.push((
77					name,
78					ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
79					ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
80				));
81			}
82			_ => bail!(Error::MixedBatchingNotSupported),
83		}
84	}
85
86	let prover = OracleZerocheckProver::<_, _, FDomain, _>::new(
87		multilinears,
88		zeros,
89		zerocheck_challenges,
90		evaluation_domain_factory,
91		switchover_fn,
92		backend,
93	)?;
94
95	Ok(prover)
96}
97
98/// Construct regular sumcheck prover from the constraint set. Fails when constraint set contains zerochecks.
99pub fn constraint_set_sumcheck_prover<'a, U, FW, FDomain, Backend>(
100	evaluation_order: EvaluationOrder,
101	constraint_set: ConstraintSet<FW>,
102	witness: &MultilinearExtensionIndex<'a, U, FW>,
103	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
104	switchover_fn: impl Fn(usize) -> usize + Clone,
105	backend: &'a Backend,
106) -> Result<OracleSumcheckProver<'a, FDomain, PackedType<U, FW>, Backend>, Error>
107where
108	U: UnderlierType + PackScalar<FW> + PackScalar<FDomain>,
109	FW: TowerField + ExtensionField<FDomain>,
110	FDomain: Field,
111	Backend: ComputationBackend,
112{
113	let (constraints, multilinears) = split_constraint_set::<_, FW, _>(constraint_set, witness)?;
114
115	let mut sums = Vec::new();
116
117	for Constraint {
118		composition,
119		predicate,
120		..
121	} in constraints
122	{
123		match predicate {
124			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
125				composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
126				sum,
127			}),
128			_ => bail!(Error::MixedBatchingNotSupported),
129		}
130	}
131
132	let prover = RegularSumcheckProver::new(
133		evaluation_order,
134		multilinears,
135		sums,
136		evaluation_domain_factory,
137		switchover_fn,
138		backend,
139	)?;
140
141	Ok(prover)
142}
143
144type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
145
146#[allow(clippy::type_complexity)]
147pub fn split_constraint_set<'a, U, F, FW>(
148	constraint_set: ConstraintSet<F>,
149	witness: &MultilinearExtensionIndex<'a, U, FW>,
150) -> Result<ConstraintsAndMultilinears<'a, F, PackedType<U, FW>>, Error>
151where
152	U: UnderlierType + PackScalar<F> + PackScalar<FW>,
153	F: Field,
154	FW: ExtensionField<F>,
155{
156	let ConstraintSet {
157		oracle_ids,
158		constraints,
159		n_vars,
160	} = constraint_set;
161
162	let multilinears = oracle_ids
163		.iter()
164		.map(|&oracle_id| witness.get_multilin_poly(oracle_id))
165		.collect::<Result<Vec<_>, _>>()?;
166
167	if multilinears
168		.iter()
169		.any(|multilin| multilin.n_vars() != n_vars)
170	{
171		bail!(Error::ConstraintSetNumberOfVariablesMismatch);
172	}
173
174	Ok((constraints, multilinears))
175}
176
177pub struct SumcheckProversWithMetas<'a, U, FW, FDomain, Backend>
178where
179	U: UnderlierType + PackScalar<FW>,
180	FW: TowerField,
181	FDomain: Field,
182	Backend: ComputationBackend,
183{
184	pub provers: Vec<OracleSumcheckProver<'a, FDomain, PackedType<U, FW>, Backend>>,
185	pub metas: Vec<OracleClaimMeta>,
186}
187
188/// Constructs sumcheck provers and metas from the vector of [`ConstraintSet`]
189pub fn constraint_sets_sumcheck_provers_metas<'a, U, FW, FDomain, Backend>(
190	constraint_sets: Vec<ConstraintSet<FW>>,
191	witness: &MultilinearExtensionIndex<'a, U, FW>,
192	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
193	switchover_fn: impl Fn(usize) -> usize,
194	backend: &'a Backend,
195) -> Result<SumcheckProversWithMetas<'a, U, FW, FDomain, Backend>, Error>
196where
197	U: UnderlierType + PackScalar<FW> + PackScalar<FDomain>,
198	FW: TowerField + ExtensionField<FDomain>,
199	FDomain: Field,
200	Backend: ComputationBackend,
201{
202	let mut provers = Vec::with_capacity(constraint_sets.len());
203	let mut metas = Vec::with_capacity(constraint_sets.len());
204
205	for constraint_set in constraint_sets {
206		let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
207		let prover = constraint_set_sumcheck_prover(
208			EvaluationOrder::LowToHigh,
209			constraint_set,
210			witness,
211			evaluation_domain_factory.clone(),
212			&switchover_fn,
213			backend,
214		)?;
215		metas.push(meta);
216		provers.push(prover);
217	}
218	Ok(SumcheckProversWithMetas { provers, metas })
219}