binius_core/protocols/sumcheck/prove/
oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::bail;
7
8use super::{RegularSumcheckProver, ZerocheckProverImpl};
9use crate::{
10	oracle::{Constraint, ConstraintPredicate, ConstraintSet},
11	polynomial::ArithCircuitPoly,
12	protocols::sumcheck::{
13		constraint_set_sumcheck_claim, CompositeSumClaim, Error, OracleClaimMeta,
14	},
15	witness::{MultilinearExtensionIndex, MultilinearWitness},
16};
17
18pub type OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend> = ZerocheckProverImpl<
19	'a,
20	FDomain,
21	FBase,
22	P,
23	ArithCircuitPoly<FBase>,
24	ArithCircuitPoly<<P as PackedField>::Scalar>,
25	MultilinearWitness<'a, P>,
26	DomainFactory,
27	Backend,
28>;
29
30pub type OracleSumcheckProver<'a, FDomain, P, Backend> = RegularSumcheckProver<
31	'a,
32	FDomain,
33	P,
34	ArithCircuitPoly<<P as PackedField>::Scalar>,
35	MultilinearWitness<'a, P>,
36	Backend,
37>;
38
39/// Construct zerocheck prover from the constraint set. Fails when constraint set contains regular sumchecks.
40pub fn constraint_set_zerocheck_prover<'a, P, F, FBase, FDomain, DomainFactory, Backend>(
41	constraints: Vec<Constraint<P::Scalar>>,
42	multilinears: Vec<MultilinearWitness<'a, P>>,
43	domain_factory: DomainFactory,
44	zerocheck_challenges: &[F],
45	backend: &'a Backend,
46) -> Result<OracleZerocheckProver<'a, P, FBase, FDomain, DomainFactory, Backend>, Error>
47where
48	P: PackedField<Scalar = F>
49		+ PackedExtension<F, PackedSubfield = P>
50		+ PackedExtension<FDomain>
51		+ PackedExtension<FBase>,
52	F: TowerField,
53	FBase: TowerField + ExtensionField<FDomain> + TryFrom<P::Scalar>,
54	FDomain: Field,
55	DomainFactory: EvaluationDomainFactory<FDomain>,
56	Backend: ComputationBackend,
57{
58	let mut zeros = Vec::with_capacity(constraints.len());
59
60	for Constraint {
61		composition,
62		predicate,
63		name,
64	} in constraints
65	{
66		let composition_base = composition
67			.try_convert_field::<FBase>()
68			.map_err(|_| Error::CircuitFieldDowncastFailed)?;
69		match predicate {
70			ConstraintPredicate::Zero => {
71				zeros.push((
72					name,
73					ArithCircuitPoly::with_n_vars(multilinears.len(), composition_base)?,
74					ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
75				));
76			}
77			_ => bail!(Error::MixedBatchingNotSupported),
78		}
79	}
80
81	let prover = OracleZerocheckProver::<_, _, FDomain, _, _>::new(
82		multilinears,
83		zeros,
84		zerocheck_challenges,
85		domain_factory,
86		backend,
87	)?;
88
89	Ok(prover)
90}
91
92/// Construct regular sumcheck prover from the constraint set. Fails when constraint set contains zerochecks.
93pub fn constraint_set_sumcheck_prover<'a, FW, PW, FDomain, Backend>(
94	evaluation_order: EvaluationOrder,
95	constraint_set: ConstraintSet<FW>,
96	witness: &MultilinearExtensionIndex<'a, PW>,
97	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
98	switchover_fn: impl Fn(usize) -> usize + Clone,
99	backend: &'a Backend,
100) -> Result<OracleSumcheckProver<'a, FDomain, PW, Backend>, Error>
101where
102	PW: PackedField<Scalar = FW> + PackedExtension<FDomain>,
103	FW: TowerField + ExtensionField<FDomain>,
104	FDomain: Field,
105	Backend: ComputationBackend,
106{
107	let (constraints, multilinears) = split_constraint_set::<FW, PW>(constraint_set, witness)?;
108
109	let mut sums = Vec::new();
110
111	for Constraint {
112		composition,
113		predicate,
114		..
115	} in constraints
116	{
117		match predicate {
118			ConstraintPredicate::Sum(sum) => sums.push(CompositeSumClaim {
119				composition: ArithCircuitPoly::with_n_vars(multilinears.len(), composition)?,
120				sum,
121			}),
122			_ => bail!(Error::MixedBatchingNotSupported),
123		}
124	}
125
126	let prover = RegularSumcheckProver::new(
127		evaluation_order,
128		multilinears,
129		sums,
130		evaluation_domain_factory,
131		switchover_fn,
132		backend,
133	)?;
134
135	Ok(prover)
136}
137
138type ConstraintsAndMultilinears<'a, F, PW> = (Vec<Constraint<F>>, Vec<MultilinearWitness<'a, PW>>);
139
140#[allow(clippy::type_complexity)]
141pub fn split_constraint_set<'a, F, PW>(
142	constraint_set: ConstraintSet<F>,
143	witness: &MultilinearExtensionIndex<'a, PW>,
144) -> Result<ConstraintsAndMultilinears<'a, F, PW>, Error>
145where
146	F: Field,
147	PW: PackedField,
148	PW::Scalar: ExtensionField<F>,
149{
150	let ConstraintSet {
151		oracle_ids,
152		constraints,
153		n_vars,
154	} = constraint_set;
155
156	let multilinears = oracle_ids
157		.iter()
158		.map(|&oracle_id| witness.get_multilin_poly(oracle_id))
159		.collect::<Result<Vec<_>, _>>()?;
160
161	if multilinears
162		.iter()
163		.any(|multilin| multilin.n_vars() != n_vars)
164	{
165		bail!(Error::ConstraintSetNumberOfVariablesMismatch);
166	}
167
168	Ok((constraints, multilinears))
169}
170
171pub struct SumcheckProversWithMetas<'a, PW, FDomain, Backend>
172where
173	PW: PackedField,
174	FDomain: Field,
175	Backend: ComputationBackend,
176{
177	pub provers: Vec<OracleSumcheckProver<'a, FDomain, PW, Backend>>,
178	pub metas: Vec<OracleClaimMeta>,
179}
180
181/// Constructs sumcheck provers and metas from the vector of [`ConstraintSet`]
182pub fn constraint_sets_sumcheck_provers_metas<'a, PW, FDomain, Backend>(
183	evaluation_order: EvaluationOrder,
184	constraint_sets: Vec<ConstraintSet<PW::Scalar>>,
185	witness: &MultilinearExtensionIndex<'a, PW>,
186	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
187	switchover_fn: impl Fn(usize) -> usize,
188	backend: &'a Backend,
189) -> Result<SumcheckProversWithMetas<'a, PW, FDomain, Backend>, Error>
190where
191	PW: PackedExtension<FDomain>,
192	PW::Scalar: TowerField + ExtensionField<FDomain>,
193	FDomain: Field,
194	Backend: ComputationBackend,
195{
196	let mut provers = Vec::with_capacity(constraint_sets.len());
197	let mut metas = Vec::with_capacity(constraint_sets.len());
198
199	for constraint_set in constraint_sets {
200		let (_, meta) = constraint_set_sumcheck_claim(constraint_set.clone())?;
201		let prover = constraint_set_sumcheck_prover(
202			evaluation_order,
203			constraint_set,
204			witness,
205			evaluation_domain_factory.clone(),
206			&switchover_fn,
207			backend,
208		)?;
209		metas.push(meta);
210		provers.push(prover);
211	}
212	Ok(SumcheckProversWithMetas { provers, metas })
213}