binius_core/constraint_system/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	BinaryField, PackedField, TowerField,
5	tower::{PackedTop, TowerFamily, TowerUnderlier},
6};
7use binius_hash::PseudoCompressionFunction;
8use binius_math::{ArithExpr, CompositionPoly, EvaluationOrder};
9use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
10use digest::{Digest, Output, core_api::BlockSizeUser};
11use itertools::{Itertools, chain};
12use tracing::instrument;
13
14use super::{
15	ConstraintSystem, Proof,
16	channel::{Boundary, OracleOrConst},
17	error::{Error, VerificationError},
18	exp::{self, reorder_exponents},
19};
20use crate::{
21	constraint_system::{
22		channel::{Flush, FlushDirection},
23		common::{FDomain, FEncode, FExt},
24	},
25	fiat_shamir::{CanSample, Challenger},
26	merkle_tree::BinaryMerkleTreeScheme,
27	oracle::{MultilinearOracleSet, OracleId},
28	piop,
29	protocols::{
30		gkr_exp,
31		gkr_gpa::{self},
32		greedy_evalcheck,
33		sumcheck::{self, ZerocheckClaim, constraint_set_zerocheck_claim},
34	},
35	ring_switch,
36	transcript::VerifierTranscript,
37};
38
39/// Verifies a proof against a constraint system.
40#[instrument("constraint_system::verify", skip_all, level = "debug")]
41pub fn verify<U, Tower, Hash, Compress, Challenger_>(
42	constraint_system: &ConstraintSystem<FExt<Tower>>,
43	log_inv_rate: usize,
44	security_bits: usize,
45	boundaries: &[Boundary<FExt<Tower>>],
46	proof: Proof,
47) -> Result<(), Error>
48where
49	U: TowerUnderlier<Tower>,
50	Tower: TowerFamily,
51	Tower::B128: PackedTop<Tower>,
52	Hash: Digest + BlockSizeUser,
53	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
54	Challenger_: Challenger + Default,
55{
56	let ConstraintSystem {
57		mut oracles,
58		mut table_constraints,
59		mut flushes,
60		non_zero_oracle_ids,
61		max_channel_id,
62		mut exponents,
63		..
64	} = constraint_system.clone();
65
66	// Stable sort constraint sets in ascending order by number of variables.
67	table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
68
69	let Proof { transcript } = proof;
70
71	let mut transcript = VerifierTranscript::<Challenger_>::new(transcript);
72	transcript.observe().write_slice(boundaries);
73
74	let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default());
75	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
76	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
77		&commit_meta,
78		&merkle_scheme,
79		security_bits,
80		log_inv_rate,
81	)?;
82
83	// Read polynomial commitment polynomials
84	let mut reader = transcript.message();
85	let commitment = reader.read::<Output<Hash>>()?;
86
87	// GKR exp multiplication
88	reorder_exponents(&mut exponents, &oracles);
89
90	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
91
92	let mut reader = transcript.message();
93	let exp_evals = reader.read_scalar_slice(exponents.len())?;
94
95	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
96		.into_iter()
97		.collect::<Vec<_>>();
98
99	let base_exp_output =
100		gkr_exp::batch_verify(EvaluationOrder::HighToLow, &exp_claims, &mut transcript)?;
101
102	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
103
104	// Grand product arguments
105	// Grand products for non-zero checks
106	let mut reader = transcript.message();
107	let non_zero_products = reader.read_scalar_slice(non_zero_oracle_ids.len())?;
108	if non_zero_products
109		.iter()
110		.any(|count| *count == Tower::B128::zero())
111	{
112		bail!(Error::Zeros);
113	}
114
115	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
116		&non_zero_oracle_ids,
117		&oracles,
118		&non_zero_products,
119	)?;
120
121	// Grand products for flushing
122	let mixing_challenge = transcript.sample();
123	// TODO(cryptographers): Find a way to sample less randomness
124	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
125
126	flushes.sort_by_key(|flush| flush.channel_id);
127	let flush_oracle_ids =
128		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
129
130	let flush_products = transcript
131		.message()
132		.read_scalar_slice(flush_oracle_ids.len())?;
133	verify_channels_balance(
134		&flushes,
135		&flush_products,
136		boundaries,
137		mixing_challenge,
138		&permutation_challenges,
139	)?;
140
141	let flush_prodcheck_claims =
142		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
143
144	// Verify grand products
145	let final_layer_claims = gkr_gpa::batch_verify(
146		EvaluationOrder::HighToLow,
147		[flush_prodcheck_claims, non_zero_prodcheck_claims].concat(),
148		&mut transcript,
149	)?;
150
151	// Reduce non_zero_final_layer_claims to evalcheck claims
152	let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
153		chain!(flush_oracle_ids, non_zero_oracle_ids),
154		final_layer_claims,
155	)?;
156
157	// Zerocheck
158	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
159		.iter()
160		.cloned()
161		.map(constraint_set_zerocheck_claim)
162		.collect::<Result<Vec<_>, _>>()?
163		.into_iter()
164		.unzip::<_, _, Vec<_>, Vec<_>>();
165
166	let (_max_n_vars, skip_rounds) =
167		max_n_vars_and_skip_rounds(&zerocheck_claims, <FDomain<Tower>>::N_BITS);
168
169	let zerocheck_output =
170		sumcheck::batch_verify_zerocheck(&zerocheck_claims, skip_rounds, &mut transcript)?;
171
172	let zerocheck_eval_claims =
173		sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
174
175	// Evalcheck
176	let eval_claims = greedy_evalcheck::verify(
177		&mut oracles,
178		chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
179		&mut transcript,
180	)?;
181
182	// Reduce committed evaluation claims to PIOP sumcheck claims
183	let system = ring_switch::EvalClaimSystem::new(
184		&oracles,
185		&commit_meta,
186		&oracle_to_commit_index,
187		&eval_claims,
188	)?;
189
190	let ring_switch::ReducedClaim {
191		transparents,
192		sumcheck_claims: piop_sumcheck_claims,
193	} = ring_switch::verify::<_, Tower, _>(&system, &mut transcript)?;
194
195	// Prove evaluation claims using PIOP compiler
196	piop::verify(
197		&commit_meta,
198		&merkle_scheme,
199		&fri_params,
200		&commitment,
201		&transparents,
202		&piop_sumcheck_claims,
203		&mut transcript,
204	)?;
205
206	transcript.finalize()?;
207
208	Ok(())
209}
210
211pub fn max_n_vars_and_skip_rounds<F, Composition>(
212	zerocheck_claims: &[ZerocheckClaim<F, Composition>],
213	domain_bits: usize,
214) -> (usize, usize)
215where
216	F: TowerField,
217	Composition: CompositionPoly<F>,
218{
219	let max_n_vars = max_n_vars(zerocheck_claims);
220
221	// Univariate skip zerocheck domain size is degree * 2^skip_rounds, which
222	// limits skip_rounds to ceil(log2(degree)) less than domain field bits.
223	let domain_max_skip_rounds = zerocheck_claims
224		.iter()
225		.map(|claim| {
226			let log_degree = log2_ceil_usize(claim.max_individual_degree());
227			domain_bits.saturating_sub(log_degree)
228		})
229		.min()
230		.unwrap_or(0);
231
232	let max_skip_rounds = domain_max_skip_rounds.min(max_n_vars);
233	(max_n_vars, max_skip_rounds)
234}
235
236fn max_n_vars<F, Composition>(zerocheck_claims: &[ZerocheckClaim<F, Composition>]) -> usize
237where
238	F: TowerField,
239	Composition: CompositionPoly<F>,
240{
241	zerocheck_claims
242		.iter()
243		.map(|claim| claim.n_vars())
244		.max()
245		.unwrap_or(0)
246}
247
248fn verify_channels_balance<F: TowerField>(
249	flushes: &[Flush<F>],
250	flush_products: &[F],
251	boundaries: &[Boundary<F>],
252	mixing_challenge: F,
253	permutation_challenges: &[F],
254) -> Result<(), Error> {
255	if flush_products.len() != flushes.len() {
256		return Err(VerificationError::IncorrectNumberOfFlushProducts.into());
257	}
258
259	let mut flush_iter = flushes
260		.iter()
261		.zip(flush_products.iter().copied())
262		.peekable();
263	while let Some((flush, _)) = flush_iter.peek() {
264		let channel_id = flush.channel_id;
265
266		let boundary_products =
267			boundaries
268				.iter()
269				.fold((F::ONE, F::ONE), |(pull_product, push_product), boundary| {
270					let Boundary {
271						channel_id: boundary_channel_id,
272						direction,
273						multiplicity,
274						values,
275						..
276					} = boundary;
277
278					if *boundary_channel_id == channel_id {
279						let (mixed_values, _) = values.iter().fold(
280							(permutation_challenges[channel_id], F::ONE),
281							|(sum, mixing), values| {
282								(sum + mixing * values, mixing * mixing_challenge)
283							},
284						);
285
286						let mixed_values_with_multiplicity =
287							mixed_values.pow_vartime([*multiplicity]);
288
289						return match direction {
290							FlushDirection::Pull => {
291								(pull_product * mixed_values_with_multiplicity, push_product)
292							}
293							FlushDirection::Push => {
294								(pull_product, push_product * mixed_values_with_multiplicity)
295							}
296						};
297					}
298
299					(pull_product, push_product)
300				});
301
302		let (pull_product, push_product) = flush_iter
303			.peeking_take_while(|(flush, _)| flush.channel_id == channel_id)
304			.fold(boundary_products, |(pull_product, push_product), (flush, flush_product)| {
305				let flush_product_with_multiplicity =
306					flush_product.pow_vartime([flush.multiplicity]);
307				match flush.direction {
308					FlushDirection::Pull => {
309						(pull_product * flush_product_with_multiplicity, push_product)
310					}
311					FlushDirection::Push => {
312						(pull_product, push_product * flush_product_with_multiplicity)
313					}
314				}
315			});
316		if pull_product != push_product {
317			return Err(VerificationError::ChannelUnbalanced { id: channel_id }.into());
318		}
319	}
320
321	Ok(())
322}
323
324/// For each flush,
325/// - if there is a selector $S$, we are taking the Grand product of the composite $1 + S * (-1 + r
326///   + F_0 + F_1 s + F_2 s^1 + …)$
327/// - otherwise the product is over the linear combination $r + F_0 + F_1 s + F_2 s^1 + …$
328pub fn make_flush_oracles<F: TowerField>(
329	oracles: &mut MultilinearOracleSet<F>,
330	flushes: &[Flush<F>],
331	mixing_challenge: F,
332	permutation_challenges: &[F],
333) -> Result<Vec<OracleId>, Error> {
334	let mut mixing_powers = vec![F::ONE];
335	let mut flush_iter = flushes.iter();
336
337	permutation_challenges
338		.iter()
339		.enumerate()
340		.flat_map(|(channel_id, permutation_challenge)| {
341			flush_iter
342				.peeking_take_while(|flush| flush.channel_id == channel_id)
343				.map(|flush| {
344					// Check that all flushed oracles have the same number of variables
345					let mut non_const_oracles =
346						flush.oracles.iter().copied().filter_map(|id| match id {
347							OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
348							_ => None,
349						});
350
351					let first_oracle = non_const_oracles.next().ok_or(Error::EmptyFlushOracles)?;
352					let n_vars = oracles.n_vars(first_oracle);
353
354					for selector_id in &flush.selectors {
355						let got_tower_level = oracles[*selector_id].tower_level;
356						if got_tower_level != 0 {
357							return Err(Error::FlushSelectorTowerLevel {
358								oracle: *selector_id,
359								got_tower_level,
360							});
361						}
362					}
363
364					for oracle_id in non_const_oracles {
365						let oracle_n_vars = oracles.n_vars(oracle_id);
366						if oracle_n_vars != n_vars {
367							return Err(Error::ChannelFlushNvarsMismatch {
368								expected: n_vars,
369								got: oracle_n_vars,
370							});
371						}
372					}
373
374					// Compute powers of the mixing challenge
375					while mixing_powers.len() < flush.oracles.len() {
376						let last_power = *mixing_powers.last().expect(
377							"mixing_powers is initialized with one element; \
378								mixing_powers never shrinks; \
379								thus, it must not be empty",
380						);
381						mixing_powers.push(last_power * mixing_challenge);
382					}
383
384					let const_linear_combination = flush
385						.oracles
386						.iter()
387						.copied()
388						.zip(mixing_powers.iter())
389						.filter_map(|(id, coeff)| match id {
390							OracleOrConst::Const { base, .. } => Some(base * coeff),
391							_ => None,
392						})
393						.sum::<F>();
394
395					let poly = if flush.selectors.is_empty() {
396						oracles
397							.add_named(format!("flush channel_id={channel_id} linear combination"))
398							.linear_combination_with_offset(
399								n_vars,
400								*permutation_challenge + const_linear_combination,
401								flush
402									.oracles
403									.iter()
404									.zip(mixing_powers.iter().copied())
405									.filter_map(|(id, coeff)| match id {
406										OracleOrConst::Oracle(oracle_id) => {
407											Some((*oracle_id, coeff))
408										}
409										_ => None,
410									}),
411							)?
412					} else {
413						let offset = *permutation_challenge + const_linear_combination + F::ONE;
414						let arith_expr_linear = ArithExpr::Const(offset);
415						let var_offset = flush.selectors.len(); // Var's represents the selector columns.
416						let (non_const_oracles, coeffs): (Vec<_>, Vec<_>) = flush
417							.oracles
418							.iter()
419							.zip(mixing_powers.iter().copied())
420							.filter_map(|(id, coeff)| match id {
421								OracleOrConst::Oracle(id) => Some((*id, coeff)),
422								_ => None,
423							})
424							.unzip();
425
426						// Build the linear combination of the non-constant oracles.
427						let arith_expr_linear = coeffs.into_iter().enumerate().fold(
428							arith_expr_linear,
429							|linear, (offset, coeff)| {
430								linear
431									+ ArithExpr::Var(offset + var_offset) * ArithExpr::Const(coeff)
432							},
433						);
434
435						let selector = (0..var_offset)
436							.map(ArithExpr::Var)
437							.product::<ArithExpr<F>>();
438
439						// The ArithExpr is of the form 1 + S * linear_factors
440						oracles
441							.add_named(format!("flush channel_id={channel_id} composite"))
442							.composite_mle(
443								n_vars,
444								flush.selectors.iter().copied().chain(non_const_oracles),
445								(ArithExpr::Const(F::ONE) + selector * arith_expr_linear).into(),
446							)?
447					};
448					Ok(poly)
449				})
450				.collect::<Vec<_>>()
451		})
452		.collect()
453}