1use binius_field::{
4 tower::{PackedTop, TowerFamily, TowerUnderlier},
5 BinaryField, PackedField, TowerField,
6};
7use binius_hash::PseudoCompressionFunction;
8use binius_math::{ArithExpr, CompositionPoly, EvaluationOrder};
9use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
10use digest::{core_api::BlockSizeUser, Digest, Output};
11use itertools::{chain, Itertools};
12use tracing::instrument;
13
14use super::{
15 channel::{Boundary, OracleOrConst},
16 error::{Error, VerificationError},
17 exp, ConstraintSystem, Proof,
18};
19use crate::{
20 constraint_system::{
21 channel::{Flush, FlushDirection},
22 common::{FDomain, FEncode, FExt},
23 },
24 fiat_shamir::{CanSample, Challenger},
25 merkle_tree::BinaryMerkleTreeScheme,
26 oracle::{MultilinearOracleSet, OracleId},
27 piop,
28 protocols::{
29 gkr_exp,
30 gkr_gpa::{self},
31 greedy_evalcheck,
32 sumcheck::{self, constraint_set_zerocheck_claim, ZerocheckClaim},
33 },
34 ring_switch,
35 transcript::VerifierTranscript,
36};
37
38#[instrument("constraint_system::verify", skip_all, level = "debug")]
40pub fn verify<U, Tower, Hash, Compress, Challenger_>(
41 constraint_system: &ConstraintSystem<FExt<Tower>>,
42 log_inv_rate: usize,
43 security_bits: usize,
44 boundaries: &[Boundary<FExt<Tower>>],
45 proof: Proof,
46) -> Result<(), Error>
47where
48 U: TowerUnderlier<Tower>,
49 Tower: TowerFamily,
50 Tower::B128: PackedTop<Tower>,
51 Hash: Digest + BlockSizeUser,
52 Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
53 Challenger_: Challenger + Default,
54{
55 let ConstraintSystem {
56 mut oracles,
57 mut table_constraints,
58 mut flushes,
59 non_zero_oracle_ids,
60 max_channel_id,
61 mut exponents,
62 ..
63 } = constraint_system.clone();
64
65 table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
67
68 let Proof { transcript } = proof;
69
70 let mut transcript = VerifierTranscript::<Challenger_>::new(transcript);
71 transcript.observe().write_slice(boundaries);
72
73 let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default());
74 let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
75 let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
76 &commit_meta,
77 &merkle_scheme,
78 security_bits,
79 log_inv_rate,
80 )?;
81
82 let mut reader = transcript.message();
84 let commitment = reader.read::<Output<Hash>>()?;
85
86 exponents.sort_by_key(|b| std::cmp::Reverse(b.n_vars(&oracles)));
88
89 let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
90
91 let mut reader = transcript.message();
92 let exp_evals = reader.read_scalar_slice(exponents.len())?;
93
94 let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
95 .into_iter()
96 .collect::<Vec<_>>();
97
98 let base_exp_output =
99 gkr_exp::batch_verify(EvaluationOrder::HighToLow, &exp_claims, &mut transcript)?;
100
101 let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
102
103 let mut reader = transcript.message();
106 let non_zero_products = reader.read_scalar_slice(non_zero_oracle_ids.len())?;
107 if non_zero_products
108 .iter()
109 .any(|count| *count == Tower::B128::zero())
110 {
111 bail!(Error::Zeros);
112 }
113
114 let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
115 &non_zero_oracle_ids,
116 &oracles,
117 &non_zero_products,
118 )?;
119
120 let mixing_challenge = transcript.sample();
122 let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
124
125 flushes.sort_by_key(|flush| flush.channel_id);
126 let flush_oracle_ids =
127 make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
128
129 let flush_products = transcript
130 .message()
131 .read_scalar_slice(flush_oracle_ids.len())?;
132 verify_channels_balance(
133 &flushes,
134 &flush_products,
135 boundaries,
136 mixing_challenge,
137 &permutation_challenges,
138 )?;
139
140 let flush_prodcheck_claims =
141 gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
142
143 let final_layer_claims = gkr_gpa::batch_verify(
145 EvaluationOrder::LowToHigh,
146 [flush_prodcheck_claims, non_zero_prodcheck_claims].concat(),
147 &mut transcript,
148 )?;
149
150 let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
152 chain!(flush_oracle_ids, non_zero_oracle_ids),
153 final_layer_claims,
154 )?;
155
156 let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
158 .iter()
159 .cloned()
160 .map(constraint_set_zerocheck_claim)
161 .collect::<Result<Vec<_>, _>>()?
162 .into_iter()
163 .unzip::<_, _, Vec<_>, Vec<_>>();
164
165 let (_max_n_vars, skip_rounds) =
166 max_n_vars_and_skip_rounds(&zerocheck_claims, <FDomain<Tower>>::N_BITS);
167
168 let zerocheck_output =
169 sumcheck::batch_verify_zerocheck(&zerocheck_claims, skip_rounds, &mut transcript)?;
170
171 let zerocheck_eval_claims =
172 sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
173
174 let eval_claims = greedy_evalcheck::verify(
176 &mut oracles,
177 chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
178 &mut transcript,
179 )?;
180
181 let system = ring_switch::EvalClaimSystem::new(
183 &oracles,
184 &commit_meta,
185 &oracle_to_commit_index,
186 &eval_claims,
187 )?;
188
189 let ring_switch::ReducedClaim {
190 transparents,
191 sumcheck_claims: piop_sumcheck_claims,
192 } = ring_switch::verify::<_, Tower, _>(&system, &mut transcript)?;
193
194 piop::verify(
196 &commit_meta,
197 &merkle_scheme,
198 &fri_params,
199 &commitment,
200 &transparents,
201 &piop_sumcheck_claims,
202 &mut transcript,
203 )?;
204
205 transcript.finalize()?;
206
207 Ok(())
208}
209
210pub fn max_n_vars_and_skip_rounds<F, Composition>(
211 zerocheck_claims: &[ZerocheckClaim<F, Composition>],
212 domain_bits: usize,
213) -> (usize, usize)
214where
215 F: TowerField,
216 Composition: CompositionPoly<F>,
217{
218 let max_n_vars = max_n_vars(zerocheck_claims);
219
220 let domain_max_skip_rounds = zerocheck_claims
223 .iter()
224 .map(|claim| {
225 let log_degree = log2_ceil_usize(claim.max_individual_degree());
226 domain_bits.saturating_sub(log_degree)
227 })
228 .min()
229 .unwrap_or(0);
230
231 let max_skip_rounds = domain_max_skip_rounds.min(max_n_vars);
232 (max_n_vars, max_skip_rounds)
233}
234
235fn max_n_vars<F, Composition>(zerocheck_claims: &[ZerocheckClaim<F, Composition>]) -> usize
236where
237 F: TowerField,
238 Composition: CompositionPoly<F>,
239{
240 zerocheck_claims
241 .iter()
242 .map(|claim| claim.n_vars())
243 .max()
244 .unwrap_or(0)
245}
246
247fn verify_channels_balance<F: TowerField>(
248 flushes: &[Flush<F>],
249 flush_products: &[F],
250 boundaries: &[Boundary<F>],
251 mixing_challenge: F,
252 permutation_challenges: &[F],
253) -> Result<(), Error> {
254 if flush_products.len() != flushes.len() {
255 return Err(VerificationError::IncorrectNumberOfFlushProducts.into());
256 }
257
258 let mut flush_iter = flushes
259 .iter()
260 .zip(flush_products.iter().copied())
261 .peekable();
262 while let Some((flush, _)) = flush_iter.peek() {
263 let channel_id = flush.channel_id;
264
265 let boundary_products =
266 boundaries
267 .iter()
268 .fold((F::ONE, F::ONE), |(pull_product, push_product), boundary| {
269 let Boundary {
270 channel_id: boundary_channel_id,
271 direction,
272 multiplicity,
273 values,
274 ..
275 } = boundary;
276
277 if *boundary_channel_id == channel_id {
278 let (mixed_values, _) = values.iter().fold(
279 (permutation_challenges[channel_id], F::ONE),
280 |(sum, mixing), values| {
281 (sum + mixing * values, mixing * mixing_challenge)
282 },
283 );
284
285 let mixed_values_with_multiplicity =
286 mixed_values.pow_vartime([*multiplicity]);
287
288 return match direction {
289 FlushDirection::Pull => {
290 (pull_product * mixed_values_with_multiplicity, push_product)
291 }
292 FlushDirection::Push => {
293 (pull_product, push_product * mixed_values_with_multiplicity)
294 }
295 };
296 }
297
298 (pull_product, push_product)
299 });
300
301 let (pull_product, push_product) = flush_iter
302 .peeking_take_while(|(flush, _)| flush.channel_id == channel_id)
303 .fold(boundary_products, |(pull_product, push_product), (flush, flush_product)| {
304 let flush_product_with_multiplicity =
305 flush_product.pow_vartime([flush.multiplicity]);
306 match flush.direction {
307 FlushDirection::Pull => {
308 (pull_product * flush_product_with_multiplicity, push_product)
309 }
310 FlushDirection::Push => {
311 (pull_product, push_product * flush_product_with_multiplicity)
312 }
313 }
314 });
315 if pull_product != push_product {
316 return Err(VerificationError::ChannelUnbalanced { id: channel_id }.into());
317 }
318 }
319
320 Ok(())
321}
322
323pub fn make_flush_oracles<F: TowerField>(
327 oracles: &mut MultilinearOracleSet<F>,
328 flushes: &[Flush<F>],
329 mixing_challenge: F,
330 permutation_challenges: &[F],
331) -> Result<Vec<OracleId>, Error> {
332 let mut mixing_powers = vec![F::ONE];
333 let mut flush_iter = flushes.iter();
334
335 permutation_challenges
336 .iter()
337 .enumerate()
338 .flat_map(|(channel_id, permutation_challenge)| {
339 flush_iter
340 .peeking_take_while(|flush| flush.channel_id == channel_id)
341 .map(|flush| {
342 let mut non_const_oracles =
344 flush.oracles.iter().copied().filter_map(|id| match id {
345 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
346 _ => None,
347 });
348
349 let first_oracle = non_const_oracles.next().ok_or(Error::EmptyFlushOracles)?;
350 let n_vars = oracles.n_vars(first_oracle);
351
352 for selector_id in &flush.selectors {
353 let got_tower_level = oracles.oracle(*selector_id).tower_level;
354 if got_tower_level != 0 {
355 return Err(Error::FlushSelectorTowerLevel {
356 oracle: *selector_id,
357 got_tower_level,
358 });
359 }
360 }
361
362 for oracle_id in non_const_oracles {
363 let oracle_n_vars = oracles.n_vars(oracle_id);
364 if oracle_n_vars != n_vars {
365 return Err(Error::ChannelFlushNvarsMismatch {
366 expected: n_vars,
367 got: oracle_n_vars,
368 });
369 }
370 }
371
372 while mixing_powers.len() < flush.oracles.len() {
374 let last_power = *mixing_powers.last().expect(
375 "mixing_powers is initialized with one element; \
376 mixing_powers never shrinks; \
377 thus, it must not be empty",
378 );
379 mixing_powers.push(last_power * mixing_challenge);
380 }
381
382 let const_linear_combination = flush
383 .oracles
384 .iter()
385 .copied()
386 .zip(mixing_powers.iter())
387 .filter_map(|(id, coeff)| match id {
388 OracleOrConst::Const { base, .. } => Some(base * coeff),
389 _ => None,
390 })
391 .sum::<F>();
392
393 let poly = if flush.selectors.is_empty() {
394 oracles
395 .add_named(format!("flush channel_id={channel_id} linear combination"))
396 .linear_combination_with_offset(
397 n_vars,
398 *permutation_challenge + const_linear_combination,
399 flush
400 .oracles
401 .iter()
402 .zip(mixing_powers.iter().copied())
403 .filter_map(|(id, coeff)| match id {
404 OracleOrConst::Oracle(oracle_id) => {
405 Some((*oracle_id, coeff))
406 }
407 _ => None,
408 }),
409 )?
410 } else {
411 let offset = *permutation_challenge + const_linear_combination + F::ONE;
412 let arith_expr_linear = ArithExpr::Const(offset);
413 let var_offset = flush.selectors.len(); let (non_const_oracles, coeffs): (Vec<_>, Vec<_>) = flush
415 .oracles
416 .iter()
417 .zip(mixing_powers.iter().copied())
418 .filter_map(|(id, coeff)| match id {
419 OracleOrConst::Oracle(id) => Some((*id, coeff)),
420 _ => None,
421 })
422 .unzip();
423
424 let arith_expr_linear = coeffs.into_iter().enumerate().fold(
426 arith_expr_linear,
427 |linear, (offset, coeff)| {
428 linear
429 + ArithExpr::Var(offset + var_offset) * ArithExpr::Const(coeff)
430 },
431 );
432
433 let selector = (0..var_offset)
434 .map(ArithExpr::Var)
435 .product::<ArithExpr<F>>();
436
437 oracles
439 .add_named(format!("flush channel_id={channel_id} composite"))
440 .composite_mle(
441 n_vars,
442 flush.selectors.iter().copied().chain(non_const_oracles),
443 (ArithExpr::Const(F::ONE) + selector * arith_expr_linear).into(),
444 )?
445 };
446 Ok(poly)
447 })
448 .collect::<Vec<_>>()
449 })
450 .collect()
451}