binius_core/protocols/sumcheck/prove/
oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
5};
6use binius_hal::ComputationBackend;
7use binius_math::{EvaluationDomainFactory, EvaluationOrder};
8use binius_utils::bail;
9
10use super::{RegularSumcheckProver, UnivariateZerocheck};
11use crate::{
12	oracle::{Constraint, ConstraintPredicate, ConstraintSet},
13	polynomial::ArithCircuitPoly,
14	protocols::sumcheck::{
15		constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
16	},
17	witness::{MultilinearExtensionIndex, MultilinearWitness},
18};
19
20pub type OracleZerocheckProver<
21	'a,
22	P,
23	FBase,
24	FDomain,
25	InterpolationDomainFactory,
26	SwitchoverFn,
27	Backend,
28> = UnivariateZerocheck<
29	'a,
30	FDomain,
31	FBase,
32	P,
33	ArithCircuitPoly<FBase>,
34	ArithCircuitPoly<<P as PackedField>::Scalar>,
35	MultilinearWitness<'a, P>,
36	InterpolationDomainFactory,
37	SwitchoverFn,
38	Backend,
39>;
40
41pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
42	'a,
43	FDomain,
44	P,
45	ArithCircuitPoly<<P as PackedField>::Scalar>,
46	MultilinearWitness<'a, P>,
47	Backend,
48>;
49
50/// Construct zerocheck prover from the constraint set. Fails when constraint set contains regular sumchecks.
51pub fn constraint_set_zerocheck_prover<
52	'a,
53	P,
54	F,
55	FBase,
56	FDomain,
57	InterpolationDomainFactory,
58	SwitchoverFn,
59	Backend,
60>(
61	constraints: Vec<Constraint<P::Scalar>>,
62	multilinears: Vec<MultilinearWitness<'a, P>>,
63	interpolation_domain_factory: InterpolationDomainFactory,
64	switchover_fn: SwitchoverFn,
65	zerocheck_challenges: &[F],
66	backend: &'a Backend,
67) -> Result<
68	OracleZerocheckProver<'a, P, FBase, FDomain, InterpolationDomainFactory, SwitchoverFn, Backend>,
69	Error,
70>
71where
72	P: PackedFieldIndexable<Scalar = F>
73		+ PackedExtension<F, PackedSubfield = P>
74		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
75		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
76	F: TowerField,
77	FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
78	FDomain: Field,
79	InterpolationDomainFactory: EvaluationDomainFactory<FDomain>,
80	SwitchoverFn: Fn(usize) -> usize + Clone,
81	Backend: ComputationBackend,
82{
83	let mut zeros = Vec::with_capacity(constraints.len());
84
85	for Constraint {
86		composition,
87		predicate,
88		name,
89	} in constraints
90	{
91		let composition_base = composition
92			.clone()
93			.try_convert_field::<FBase>()
94			.map_err(|_| Error::CircuitFieldDowncastFailed)?;
95		match predicate {
96			ConstraintPredicate::Zero => {
97				zeros.push((
98					name,
99					ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
100					ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
101				));
102			}
103			_ => bail!(Error::MixedBatchingNotSupported),
104		}
105	}
106
107	let prover = OracleZerocheckProver::<_, _, FDomain, _, _, _>::new(
108		multilinears,
109		zeros,
110		zerocheck_challenges,
111		interpolation_domain_factory,
112		switchover_fn,
113		backend,
114	)?;
115
116	Ok(prover)
117}
118
119/// Construct regular sumcheck prover from the constraint set. Fails when constraint set contains zerochecks.
120pub fn constraint_set_sumcheck_prover<'a, FW, PW, FDomain, Backend>(
121	evaluation_order: EvaluationOrder,
122	constraint_set: ConstraintSet<FW>,
123	witness: &MultilinearExtensionIndex<'a, PW>,
124	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
125	switchover_fn: impl Fn(usize) -> usize + Clone,
126	backend: &'a Backend,
127) -> Result<OracleSumcheckProver<'a, FDomain, PW, Backend>, Error>
128where
129	PW: PackedField<Scalar = FW> + PackedExtension<FDomain>,
130	FW: TowerField + ExtensionField<FDomain>,
131	FDomain: Field,
132	Backend: ComputationBackend,
133{
134	let (constraints, multilinears) = split_constraint_set::<FW, PW>(constraint_set, witness)?;
135
136	let mut sums = Vec::new();
137
138	for Constraint {
139		composition,
140		predicate,
141		..
142	} in constraints
143	{
144		match predicate {
145			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
146				composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
147				sum,
148			}),
149			_ => bail!(Error::MixedBatchingNotSupported),
150		}
151	}
152
153	let prover = RegularSumcheckProver::new(
154		evaluation_order,
155		multilinears,
156		sums,
157		evaluation_domain_factory,
158		switchover_fn,
159		backend,
160	)?;
161
162	Ok(prover)
163}
164
165type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
166
167#[allow(clippy::type_complexity)]
168pub fn split_constraint_set<'a, F, PW>(
169	constraint_set: ConstraintSet<F>,
170	witness: &MultilinearExtensionIndex<'a, PW>,
171) -> Result<ConstraintsAndMultilinears<'a, F, PW>, Error>
172where
173	F: Field,
174	PW: PackedField,
175	PW::Scalar: ExtensionField<F>,
176{
177	let ConstraintSet {
178		oracle_ids,
179		constraints,
180		n_vars,
181	} = constraint_set;
182
183	let multilinears = oracle_ids
184		.iter()
185		.map(|&oracle_id| witness.get_multilin_poly(oracle_id))
186		.collect::<Result<Vec<_>, _>>()?;
187
188	if multilinears
189		.iter()
190		.any(|multilin| multilin.n_vars() != n_vars)
191	{
192		bail!(Error::ConstraintSetNumberOfVariablesMismatch);
193	}
194
195	Ok((constraints, multilinears))
196}
197
198pub struct SumcheckProversWithMetas<'a, PW, FDomain, Backend>
199where
200	PW: PackedField,
201	FDomain: Field,
202	Backend: ComputationBackend,
203{
204	pub provers: Vec<OracleSumcheckProver<'a, FDomain, PW, Backend>>,
205	pub metas: Vec<OracleClaimMeta>,
206}
207
208/// Constructs sumcheck provers and metas from the vector of [`ConstraintSet`]
209pub fn constraint_sets_sumcheck_provers_metas<'a, PW, FDomain, Backend>(
210	evaluation_order: EvaluationOrder,
211	constraint_sets: Vec<ConstraintSet<PW::Scalar>>,
212	witness: &MultilinearExtensionIndex<'a, PW>,
213	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
214	switchover_fn: impl Fn(usize) -> usize,
215	backend: &'a Backend,
216) -> Result<SumcheckProversWithMetas<'a, PW, FDomain, Backend>, Error>
217where
218	PW: PackedExtension<FDomain>,
219	PW::Scalar: TowerField + ExtensionField<FDomain>,
220	FDomain: Field,
221	Backend: ComputationBackend,
222{
223	let mut provers = Vec::with_capacity(constraint_sets.len());
224	let mut metas = Vec::with_capacity(constraint_sets.len());
225
226	for constraint_set in constraint_sets {
227		let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
228		let prover = constraint_set_sumcheck_prover(
229			evaluation_order,
230			constraint_set,
231			witness,
232			evaluation_domain_factory.clone(),
233			&switchover_fn,
234			backend,
235		)?;
236		metas.push(meta);
237		provers.push(prover);
238	}
239	Ok(SumcheckProversWithMetas { provers, metas })
240}