binius_core/constraint_system/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	tower::{PackedTop, TowerFamily, TowerUnderlier},
5	BinaryField, PackedField, TowerField,
6};
7use binius_hash::PseudoCompressionFunction;
8use binius_math::{ArithExpr, CompositionPoly, EvaluationOrder};
9use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
10use digest::{core_api::BlockSizeUser, Digest, Output};
11use itertools::{chain, Itertools};
12use tracing::instrument;
13
14use super::{
15	channel::{Boundary, OracleOrConst},
16	error::{Error, VerificationError},
17	exp, ConstraintSystem, Proof,
18};
19use crate::{
20	constraint_system::{
21		channel::{Flush, FlushDirection},
22		common::{FDomain, FEncode, FExt},
23	},
24	fiat_shamir::{CanSample, Challenger},
25	merkle_tree::BinaryMerkleTreeScheme,
26	oracle::{MultilinearOracleSet, OracleId},
27	piop,
28	protocols::{
29		gkr_exp,
30		gkr_gpa::{self},
31		greedy_evalcheck,
32		sumcheck::{self, constraint_set_zerocheck_claim, ZerocheckClaim},
33	},
34	ring_switch,
35	transcript::VerifierTranscript,
36};
37
38/// Verifies a proof against a constraint system.
///
/// The proof transcript is replayed in a fixed order: the boundaries are observed, the
/// commitment is read, and then the `gkr_exp` claims, the grand-product claims (non-zero checks
/// and channel flushes), the zerocheck, the evalcheck, the ring-switch reduction and finally the
/// PIOP verification are processed, all against the same transcript. Because every step reads
/// from or samples from that transcript, the order of the statements below must not change.
///
/// # Arguments
///
/// * `constraint_system` - the system to verify against. It is cloned, so the sorting and the
///   extra flush oracles created here are local to this call.
/// * `log_inv_rate`, `security_bits` - forwarded to
///   `piop::make_commit_params_with_optimal_arity` to derive `fri_params`.
/// * `boundaries` - boundary values; written to the transcript as observed data and used by the
///   channel balance check.
/// * `proof` - the proof whose transcript is consumed.
///
/// # Errors
///
/// Returns `Error::Zeros` if any claimed non-zero product equals zero, and propagates every
/// error raised by the helper and sub-protocol calls, by the channel balance check, and by
/// `transcript.finalize()`.
39#[instrument("constraint_system::verify", skip_all, level = "debug")]
40pub fn verify<U, Tower, Hash, Compress, Challenger_>(
41	constraint_system: &ConstraintSystem<FExt<Tower>>,
42	log_inv_rate: usize,
43	security_bits: usize,
44	boundaries: &[Boundary<FExt<Tower>>],
45	proof: Proof,
46) -> Result<(), Error>
47where
48	U: TowerUnderlier<Tower>,
49	Tower: TowerFamily,
50	Tower::B128: PackedTop<Tower>,
51	Hash: Digest + BlockSizeUser,
52	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
53	Challenger_: Challenger + Default,
54{
	// The constraint system is only borrowed, so work on a clone: the collections are reordered
	// below and `oracles` is extended with the flush oracles.
55	let ConstraintSystem {
56		mut oracles,
57		mut table_constraints,
58		mut flushes,
59		non_zero_oracle_ids,
60		max_channel_id,
61		mut exponents,
62		..
63	} = constraint_system.clone();
64
65	// Stable sort constraint sets in ascending order by number of variables.
66	table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
67
68	let Proof { transcript } = proof;
69
	// The boundaries are written to the transcript as observed data before any prover message
	// is read.
70	let mut transcript = VerifierTranscript::<Challenger_>::new(transcript);
71	transcript.observe().write_slice(boundaries);
72
	// Derive the commitment metadata and `fri_params` from the oracle set and the requested
	// `security_bits` / `log_inv_rate`.
73	let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default());
74	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
75	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
76		&commit_meta,
77		&merkle_scheme,
78		security_bits,
79		log_inv_rate,
80	)?;
81
82	// Read the polynomial commitment
83	let mut reader = transcript.message();
84	let commitment = reader.read::<Output<Hash>>()?;
85
86	// GKR exp multiplication
	// Exponents are processed in descending order of their number of variables.
87	exponents.sort_by_key(|b| std::cmp::Reverse(b.n_vars(&oracles)));
88
	// Sample a challenge vector of length `exp::max_n_vars(..)`, then read one claimed
	// evaluation per exponent.
89	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
90
91	let mut reader = transcript.message();
92	let exp_evals = reader.read_scalar_slice(exponents.len())?;
93
94	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
95		.into_iter()
96		.collect::<Vec<_>>();
97
98	let base_exp_output =
99		gkr_exp::batch_verify(EvaluationOrder::HighToLow, &exp_claims, &mut transcript)?;
100
101	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
102
103	// Grand product arguments
104	// Grand products for non-zero checks
	// One claimed product is read per non-zero oracle; a claimed product of zero is rejected
	// immediately.
105	let mut reader = transcript.message();
106	let non_zero_products = reader.read_scalar_slice(non_zero_oracle_ids.len())?;
107	if non_zero_products
108		.iter()
109		.any(|count| *count == Tower::B128::zero())
110	{
111		bail!(Error::Zeros);
112	}
113
114	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
115		&non_zero_oracle_ids,
116		&oracles,
117		&non_zero_products,
118	)?;
119
120	// Grand products for flushing
121	let mixing_challenge = transcript.sample();
122	// TODO(cryptographers): Find a way to sample less randomness
	// One permutation challenge per channel id in `0..=max_channel_id`.
123	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
124
	// `make_flush_oracles` and `verify_channels_balance` both walk runs of consecutive flushes
	// sharing a channel id, so the flushes are grouped by channel first (stable sort).
125	flushes.sort_by_key(|flush| flush.channel_id);
126	let flush_oracle_ids =
127		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
128
	// Read one claimed product per flush oracle and check that every flushed channel balances.
129	let flush_products = transcript
130		.message()
131		.read_scalar_slice(flush_oracle_ids.len())?;
132	verify_channels_balance(
133		&flushes,
134		&flush_products,
135		boundaries,
136		mixing_challenge,
137		&permutation_challenges,
138	)?;
139
140	let flush_prodcheck_claims =
141		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
142
143	// Verify grand products
	// Flush claims come first, then non-zero claims; the oracle ids chained further down follow
	// the same order.
144	let final_layer_claims = gkr_gpa::batch_verify(
145		EvaluationOrder::LowToHigh,
146		[flush_prodcheck_claims, non_zero_prodcheck_claims].concat(),
147		&mut transcript,
148	)?;
149
150	// Reduce the final layer claims (flush and non-zero) to evalcheck claims
151	let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
152		chain!(flush_oracle_ids, non_zero_oracle_ids),
153		final_layer_claims,
154	)?;
155
156	// Zerocheck
	// Each constraint set yields a zerocheck claim plus the oracle metadata needed later to turn
	// the zerocheck output into evaluation claims.
157	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
158		.iter()
159		.cloned()
160		.map(constraint_set_zerocheck_claim)
161		.collect::<Result<Vec<_>, _>>()?
162		.into_iter()
163		.unzip::<_, _, Vec<_>, Vec<_>>();
164
	// Only the number of skip rounds is needed here; the maximum number of variables is unused.
165	let (_max_n_vars, skip_rounds) =
166		max_n_vars_and_skip_rounds(&zerocheck_claims, <FDomain<Tower>>::N_BITS);
167
168	let zerocheck_output =
169		sumcheck::batch_verify_zerocheck(&zerocheck_claims, skip_rounds, &mut transcript)?;
170
171	let zerocheck_eval_claims =
172		sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
173
174	// Evalcheck
	// All evaluation claims gathered so far (grand products, zerocheck, exp) are handed to the
	// evalcheck verifier together.
175	let eval_claims = greedy_evalcheck::verify(
176		&mut oracles,
177		chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
178		&mut transcript,
179	)?;
180
181	// Reduce committed evaluation claims to PIOP sumcheck claims
182	let system = ring_switch::EvalClaimSystem::new(
183		&oracles,
184		&commit_meta,
185		&oracle_to_commit_index,
186		&eval_claims,
187	)?;
188
189	let ring_switch::ReducedClaim {
190		transparents,
191		sumcheck_claims: piop_sumcheck_claims,
192	} = ring_switch::verify::<_, Tower, _>(&system, &mut transcript)?;
193
194	// Verify evaluation claims using PIOP compiler
195	piop::verify(
196		&commit_meta,
197		&merkle_scheme,
198		&fri_params,
199		&commitment,
200		&transparents,
201		&piop_sumcheck_claims,
202		&mut transcript,
203	)?;
204
	// NOTE(review): presumably this rejects a transcript with unread trailing data — confirm
	// against `VerifierTranscript::finalize`.
205	transcript.finalize()?;
206
207	Ok(())
208}
209
210pub fn max_n_vars_and_skip_rounds<F, Composition>(
211	zerocheck_claims: &[ZerocheckClaim<F, Composition>],
212	domain_bits: usize,
213) -> (usize, usize)
214where
215	F: TowerField,
216	Composition: CompositionPoly<F>,
217{
218	let max_n_vars = max_n_vars(zerocheck_claims);
219
220	// Univariate skip zerocheck domain size is degree * 2^skip_rounds, which
221	// limits skip_rounds to ceil(log2(degree)) less than domain field bits.
222	let domain_max_skip_rounds = zerocheck_claims
223		.iter()
224		.map(|claim| {
225			let log_degree = log2_ceil_usize(claim.max_individual_degree());
226			domain_bits.saturating_sub(log_degree)
227		})
228		.min()
229		.unwrap_or(0);
230
231	let max_skip_rounds = domain_max_skip_rounds.min(max_n_vars);
232	(max_n_vars, max_skip_rounds)
233}
234
235fn max_n_vars<F, Composition>(zerocheck_claims: &[ZerocheckClaim<F, Composition>]) -> usize
236where
237	F: TowerField,
238	Composition: CompositionPoly<F>,
239{
240	zerocheck_claims
241		.iter()
242		.map(|claim| claim.n_vars())
243		.max()
244		.unwrap_or(0)
245}
246
247fn verify_channels_balance<F: TowerField>(
248	flushes: &[Flush<F>],
249	flush_products: &[F],
250	boundaries: &[Boundary<F>],
251	mixing_challenge: F,
252	permutation_challenges: &[F],
253) -> Result<(), Error> {
254	if flush_products.len() != flushes.len() {
255		return Err(VerificationError::IncorrectNumberOfFlushProducts.into());
256	}
257
258	let mut flush_iter = flushes
259		.iter()
260		.zip(flush_products.iter().copied())
261		.peekable();
262	while let Some((flush, _)) = flush_iter.peek() {
263		let channel_id = flush.channel_id;
264
265		let boundary_products =
266			boundaries
267				.iter()
268				.fold((F::ONE, F::ONE), |(pull_product, push_product), boundary| {
269					let Boundary {
270						channel_id: boundary_channel_id,
271						direction,
272						multiplicity,
273						values,
274						..
275					} = boundary;
276
277					if *boundary_channel_id == channel_id {
278						let (mixed_values, _) = values.iter().fold(
279							(permutation_challenges[channel_id], F::ONE),
280							|(sum, mixing), values| {
281								(sum + mixing * values, mixing * mixing_challenge)
282							},
283						);
284
285						let mixed_values_with_multiplicity =
286							mixed_values.pow_vartime([*multiplicity]);
287
288						return match direction {
289							FlushDirection::Pull => {
290								(pull_product * mixed_values_with_multiplicity, push_product)
291							}
292							FlushDirection::Push => {
293								(pull_product, push_product * mixed_values_with_multiplicity)
294							}
295						};
296					}
297
298					(pull_product, push_product)
299				});
300
301		let (pull_product, push_product) = flush_iter
302			.peeking_take_while(|(flush, _)| flush.channel_id == channel_id)
303			.fold(boundary_products, |(pull_product, push_product), (flush, flush_product)| {
304				let flush_product_with_multiplicity =
305					flush_product.pow_vartime([flush.multiplicity]);
306				match flush.direction {
307					FlushDirection::Pull => {
308						(pull_product * flush_product_with_multiplicity, push_product)
309					}
310					FlushDirection::Push => {
311						(pull_product, push_product * flush_product_with_multiplicity)
312					}
313				}
314			});
315		if pull_product != push_product {
316			return Err(VerificationError::ChannelUnbalanced { id: channel_id }.into());
317		}
318	}
319
320	Ok(())
321}
322
/// Adds one oracle per flush to `oracles` and returns the new oracle ids in the order the
/// flushes are consumed.
///
323/// For each flush,
324/// - if there is a selector $S$, we are taking the Grand product of the composite $1 + S * (-1 + r + F_0 + F_1 s + F_2 s^2 + …)$
325/// - otherwise the product is over the linear combination $r + F_0 + F_1 s + F_2 s^2 + …$
///
/// Here $r$ is the entry of `permutation_challenges` for the flush's channel, $s$ is
/// `mixing_challenge` and $F_i$ is the i-th entry of `flush.oracles` (an oracle or a constant).
/// When a flush has several selectors, $S$ is the product of all of them.
///
/// `flushes` is expected to be sorted by `channel_id`: channels are enumerated in increasing
/// order and a flush is only consumed while its channel id equals the current channel. A flush
/// that is out of order, or whose channel id is not below `permutation_challenges.len()`, is
/// never consumed, and neither is any flush after it; no error is reported for them here.
///
/// # Errors
///
/// * `Error::EmptyFlushOracles` if a flush contains no non-constant oracle.
/// * `Error::FlushSelectorTowerLevel` if a selector oracle has a non-zero tower level.
/// * `Error::ChannelFlushNvarsMismatch` if the non-constant oracles of a flush do not all have
///   the same number of variables.
/// * Any error returned while adding the new oracle to `oracles`.
326pub fn make_flush_oracles<F: TowerField>(
327	oracles: &mut MultilinearOracleSet<F>,
328	flushes: &[Flush<F>],
329	mixing_challenge: F,
330	permutation_challenges: &[F],
331) -> Result<Vec<OracleId>, Error> {
	// Cache of 1, s, s^2, … shared by all flushes; grown on demand below.
332	let mut mixing_powers = vec![F::ONE];
	// Single cursor over all flushes; each channel takes its run of flushes from the front.
333	let mut flush_iter = flushes.iter();
334
335	permutation_challenges
336		.iter()
337		.enumerate()
338		.flat_map(|(channel_id, permutation_challenge)| {
339			flush_iter
340				.peeking_take_while(|flush| flush.channel_id == channel_id)
341				.map(|flush| {
342					// Check that all flushed oracles have the same number of variables
343					let mut non_const_oracles =
344						flush.oracles.iter().copied().filter_map(|id| match id {
345							OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
346							_ => None,
347						});
348
					// The first non-constant oracle fixes the expected number of variables.
349					let first_oracle = non_const_oracles.next().ok_or(Error::EmptyFlushOracles)?;
350					let n_vars = oracles.n_vars(first_oracle);
351
					// Selectors are only accepted at tower level 0.
352					for selector_id in &flush.selectors {
353						let got_tower_level = oracles.oracle(*selector_id).tower_level;
354						if got_tower_level != 0 {
355							return Err(Error::FlushSelectorTowerLevel {
356								oracle: *selector_id,
357								got_tower_level,
358							});
359						}
360					}
361
					// The remaining non-constant oracles must match `n_vars`.
362					for oracle_id in non_const_oracles {
363						let oracle_n_vars = oracles.n_vars(oracle_id);
364						if oracle_n_vars != n_vars {
365							return Err(Error::ChannelFlushNvarsMismatch {
366								expected: n_vars,
367								got: oracle_n_vars,
368							});
369						}
370					}
371
372					// Compute powers of the mixing challenge
373					while mixing_powers.len() < flush.oracles.len() {
374						let last_power = *mixing_powers.last().expect(
375							"mixing_powers is initialized with one element; \
376								mixing_powers never shrinks; \
377								thus, it must not be empty",
378						);
379						mixing_powers.push(last_power * mixing_challenge);
380					}
381
					// Constant entries are folded into the offset: each is multiplied by the
					// mixing power of its position in `flush.oracles`.
382					let const_linear_combination = flush
383						.oracles
384						.iter()
385						.copied()
386						.zip(mixing_powers.iter())
387						.filter_map(|(id, coeff)| match id {
388							OracleOrConst::Const { base, .. } => Some(base * coeff),
389							_ => None,
390						})
391						.sum::<F>();
392
393					let poly = if flush.selectors.is_empty() {
						// No selector: r + constants + sum of (mixing power * oracle).
394						oracles
395							.add_named(format!("flush channel_id={channel_id} linear combination"))
396							.linear_combination_with_offset(
397								n_vars,
398								*permutation_challenge + const_linear_combination,
399								flush
400									.oracles
401									.iter()
402									.zip(mixing_powers.iter().copied())
403									.filter_map(|(id, coeff)| match id {
404										OracleOrConst::Oracle(oracle_id) => {
405											Some((*oracle_id, coeff))
406										}
407										_ => None,
408									}),
409							)?
410					} else {
						// `+ F::ONE` plays the role of the `-1` in the documented formula
						// (presumably because the tower fields have characteristic 2 — confirm).
411						let offset = *permutation_challenge + const_linear_combination + F::ONE;
412						let arith_expr_linear = ArithExpr::Const(offset);
413						let var_offset = flush.selectors.len(); // Var's represents the selector columns.
414						let (non_const_oracles, coeffs): (Vec<_>, Vec<_>) = flush
415							.oracles
416							.iter()
417							.zip(mixing_powers.iter().copied())
418							.filter_map(|(id, coeff)| match id {
419								OracleOrConst::Oracle(id) => Some((*id, coeff)),
420								_ => None,
421							})
422							.unzip();
423
424						// Build the linear combination of the non-constant oracles.
						// Variable `var_offset + i` refers to the i-th non-constant oracle,
						// matching the order of the oracle list passed to `composite_mle` below.
425						let arith_expr_linear = coeffs.into_iter().enumerate().fold(
426							arith_expr_linear,
427							|linear, (offset, coeff)| {
428								linear
429									+ ArithExpr::Var(offset + var_offset) * ArithExpr::Const(coeff)
430							},
431						);
432
						// Product of all selector variables (indices 0..var_offset).
433						let selector = (0..var_offset)
434							.map(ArithExpr::Var)
435							.product::<ArithExpr<F>>();
436
437						// The ArithExpr is of the form 1 + S * linear_factors
438						oracles
439							.add_named(format!("flush channel_id={channel_id} composite"))
440							.composite_mle(
441								n_vars,
442								flush.selectors.iter().copied().chain(non_const_oracles),
443								(ArithExpr::Const(F::ONE) + selector * arith_expr_linear).into(),
444							)?
445					};
446					Ok(poly)
447				})
				// Results for this channel are collected eagerly before moving to the next one.
448				.collect::<Vec<_>>()
449		})
		// Gathers the per-flush results into `Result<Vec<_>, _>`, yielding the first error if any.
450		.collect()
451}