1use std::collections::hash_map::Entry;
4
5use binius_field::{
6 BinaryField, PackedField, TowerField,
7 tower::{PackedTop, TowerFamily, TowerUnderlier},
8};
9use binius_hash::PseudoCompressionFunction;
10use binius_math::{ArithExpr, CompositionPoly, EvaluationOrder};
11use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
12use digest::{Digest, Output, OutputSizeUser, core_api::BlockSizeUser};
13use itertools::{Itertools, chain, izip};
14use tracing::instrument;
15
16use super::{
17 ConstraintSystem, Proof,
18 channel::{Boundary, OracleOrConst},
19 error::{Error, VerificationError},
20 exp::{self, reorder_exponents},
21};
22use crate::{
23 constraint_system::{
24 TableSizeSpec,
25 channel::{Flush, FlushDirection},
26 common::{FDomain, FEncode, FExt},
27 },
28 fiat_shamir::{CanSample, Challenger},
29 merkle_tree::BinaryMerkleTreeScheme,
30 oracle::{
31 ConstraintSetBuilder, MultilinearOracleSet, MultilinearPolyVariant, OracleId,
32 SizedConstraintSet,
33 },
34 piop,
35 protocols::{
36 evalcheck::{EvalPoint, EvalcheckMultilinearClaim},
37 gkr_exp,
38 gkr_gpa::{self},
39 greedy_evalcheck,
40 sumcheck::{
41 self, MLEcheckClaimsWithMeta, ZerocheckClaim, constraint_set_mlecheck_claims,
42 constraint_set_zerocheck_claim,
43 eq_ind::{self, ClaimsSortingOrder, reduce_to_regular_sumchecks},
44 front_loaded,
45 },
46 },
47 ring_switch,
48 transcript::VerifierTranscript,
49 transparent::step_down::StepDown,
50};
51
52#[instrument("constraint_system::verify", skip_all, level = "debug")]
54#[allow(clippy::too_many_arguments)]
55pub fn verify<U, Tower, Hash, Compress, Challenger_>(
56 constraint_system: &ConstraintSystem<FExt<Tower>>,
57 log_inv_rate: usize,
58 security_bits: usize,
59 constraint_system_digest: &Output<Hash>,
60 boundaries: &[Boundary<FExt<Tower>>],
61 proof: Proof,
62) -> Result<(), Error>
63where
64 U: TowerUnderlier<Tower>,
65 Tower: TowerFamily,
66 Tower::B128: binius_math::TowerTop + binius_math::PackedTop + PackedTop<Tower>,
67 Hash: Digest + BlockSizeUser + OutputSizeUser,
68 Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
69 Challenger_: Challenger + Default,
70{
71 let ConstraintSystem {
72 oracles,
73 table_constraints,
74 mut flushes,
75 mut non_zero_oracle_ids,
76 channel_count,
77 mut exponents,
78 table_size_specs,
79 } = constraint_system.clone();
80
81 let Proof { transcript } = proof;
82
83 let mut transcript = VerifierTranscript::<Challenger_>::new(transcript);
84 transcript
85 .observe()
86 .write_slice(constraint_system_digest.as_ref());
87 transcript.observe().write_slice(boundaries);
88
89 let table_count = table_size_specs.len();
90 let mut reader = transcript.message();
91 let table_sizes: Vec<usize> = reader.read_vec(table_count)?;
92
93 constraint_system.check_table_sizes(&table_sizes)?;
94 let mut oracles = oracles.instantiate(&table_sizes)?;
95
96 flushes.retain(|flush| table_sizes[flush.table_id] > 0);
102 flushes.sort_by_key(|flush| flush.channel_id);
103
104 non_zero_oracle_ids.retain(|oracle| !oracles.is_zero_sized(*oracle));
105 exponents.retain(|exp| !oracles.is_zero_sized(exp.exp_result_id));
106
107 let mut table_constraints = table_constraints
108 .into_iter()
109 .filter_map(|u| {
110 if table_sizes[u.table_id] == 0 {
111 None
112 } else {
113 let n_vars = u.log_values_per_row + log2_ceil_usize(table_sizes[u.table_id]);
114 Some(SizedConstraintSet::new(n_vars, u))
115 }
116 })
117 .collect::<Vec<_>>();
118 table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
120
121 reorder_exponents(&mut exponents, &oracles);
123
124 let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default());
125 let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
126 let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
127 &commit_meta,
128 &merkle_scheme,
129 security_bits,
130 log_inv_rate,
131 )?;
132
133 let mut reader = transcript.message();
135 let commitment = reader.read::<Output<Hash>>()?;
136
137 let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
138
139 let mut reader = transcript.message();
140 let exp_evals = reader.read_scalar_slice(exponents.len())?;
141
142 let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
143 .into_iter()
144 .collect::<Vec<_>>();
145
146 let base_exp_output =
147 gkr_exp::batch_verify(EvaluationOrder::HighToLow, &exp_claims, &mut transcript)?;
148
149 let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
150
151 let mut reader = transcript.message();
154 let non_zero_products = reader.read_scalar_slice(non_zero_oracle_ids.len())?;
155 if non_zero_products
156 .iter()
157 .any(|count| *count == Tower::B128::zero())
158 {
159 bail!(Error::Zeros);
160 }
161
162 let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
163 &non_zero_oracle_ids,
164 &oracles,
165 &non_zero_products,
166 )?;
167
168 let mixing_challenge = transcript.sample();
170 let permutation_challenges = transcript.sample_vec(channel_count);
172
173 flushes.retain(|flush| table_sizes[flush.table_id] > 0);
174 flushes.sort_by_key(|flush| flush.channel_id);
175 let _ =
176 augment_flush_po2_step_down(&mut oracles, &mut flushes, &table_size_specs, &table_sizes)?;
177 let flush_oracle_ids =
178 make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
179
180 let flush_products = transcript
181 .message()
182 .read_scalar_slice(flush_oracle_ids.len())?;
183 verify_channels_balance(
184 &flushes,
185 &flush_products,
186 boundaries,
187 mixing_challenge,
188 &permutation_challenges,
189 )?;
190
191 let flush_prodcheck_claims =
192 gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
193
194 let final_layer_claims = gkr_gpa::batch_verify(
196 EvaluationOrder::HighToLow,
197 [flush_prodcheck_claims, non_zero_prodcheck_claims].concat(),
198 &mut transcript,
199 )?;
200
201 let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
203 chain!(flush_oracle_ids.clone(), non_zero_oracle_ids),
204 final_layer_claims,
205 )?;
206
207 let mut flush_prodcheck_eval_claims = prodcheck_eval_claims;
208
209 let prodcheck_eval_claims = flush_prodcheck_eval_claims.split_off(flush_oracle_ids.len());
210
211 let flush_eval_claims = reduce_flush_evalcheck_claims::<Tower, Challenger_>(
212 flush_prodcheck_eval_claims,
213 &oracles,
214 &mut transcript,
215 )?;
216
217 let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
219 .iter()
220 .cloned()
221 .map(constraint_set_zerocheck_claim)
222 .collect::<Result<Vec<_>, _>>()?
223 .into_iter()
224 .unzip::<_, _, Vec<_>, Vec<_>>();
225
226 let (_max_n_vars, skip_rounds) =
227 max_n_vars_and_skip_rounds(&zerocheck_claims, <FDomain<Tower>>::N_BITS);
228
229 let zerocheck_output =
230 sumcheck::batch_verify_zerocheck(&zerocheck_claims, skip_rounds, &mut transcript)?;
231
232 let zerocheck_eval_claims =
233 sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
234
235 let eval_claims = greedy_evalcheck::verify(
237 &mut oracles,
238 chain!(flush_eval_claims, prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
239 &mut transcript,
240 )?;
241
242 let system = ring_switch::EvalClaimSystem::new(
244 &oracles,
245 &commit_meta,
246 &oracle_to_commit_index,
247 &eval_claims,
248 )?;
249
250 let ring_switch::ReducedClaim {
251 transparents,
252 sumcheck_claims: piop_sumcheck_claims,
253 } = ring_switch::verify(&system, &mut transcript)?;
254
255 piop::verify(
257 &commit_meta,
258 &merkle_scheme,
259 &fri_params,
260 &commitment,
261 &transparents,
262 &piop_sumcheck_claims,
263 &mut transcript,
264 )?;
265
266 transcript.finalize()?;
267
268 Ok(())
269}
270
271pub fn max_n_vars_and_skip_rounds<F, Composition>(
272 zerocheck_claims: &[ZerocheckClaim<F, Composition>],
273 domain_bits: usize,
274) -> (usize, usize)
275where
276 F: TowerField,
277 Composition: CompositionPoly<F>,
278{
279 let max_n_vars = max_n_vars(zerocheck_claims);
280
281 let domain_max_skip_rounds = zerocheck_claims
284 .iter()
285 .map(|claim| {
286 let log_degree = log2_ceil_usize(claim.max_individual_degree());
287 domain_bits.saturating_sub(log_degree)
288 })
289 .min()
290 .unwrap_or(0);
291
292 let max_skip_rounds = domain_max_skip_rounds.min(max_n_vars);
293 (max_n_vars, max_skip_rounds)
294}
295
296fn max_n_vars<F, Composition>(zerocheck_claims: &[ZerocheckClaim<F, Composition>]) -> usize
297where
298 F: TowerField,
299 Composition: CompositionPoly<F>,
300{
301 zerocheck_claims
302 .iter()
303 .map(|claim| claim.n_vars())
304 .max()
305 .unwrap_or(0)
306}
307
308fn verify_channels_balance<F: TowerField>(
309 flushes: &[Flush<F>],
310 flush_products: &[F],
311 boundaries: &[Boundary<F>],
312 mixing_challenge: F,
313 permutation_challenges: &[F],
314) -> Result<(), Error> {
315 if flush_products.len() != flushes.len() {
316 return Err(VerificationError::IncorrectNumberOfFlushProducts.into());
317 }
318
319 let mut flush_iter = flushes
320 .iter()
321 .zip(flush_products.iter().copied())
322 .peekable();
323 while let Some((flush, _)) = flush_iter.peek() {
324 let channel_id = flush.channel_id;
325
326 let boundary_products =
327 boundaries
328 .iter()
329 .fold((F::ONE, F::ONE), |(pull_product, push_product), boundary| {
330 let Boundary {
331 channel_id: boundary_channel_id,
332 direction,
333 multiplicity,
334 values,
335 ..
336 } = boundary;
337
338 if *boundary_channel_id == channel_id {
339 let (mixed_values, _) = values.iter().fold(
340 (permutation_challenges[channel_id], F::ONE),
341 |(sum, mixing), values| {
342 (sum + mixing * values, mixing * mixing_challenge)
343 },
344 );
345
346 let mixed_values_with_multiplicity =
347 mixed_values.pow_vartime([*multiplicity]);
348
349 return match direction {
350 FlushDirection::Pull => {
351 (pull_product * mixed_values_with_multiplicity, push_product)
352 }
353 FlushDirection::Push => {
354 (pull_product, push_product * mixed_values_with_multiplicity)
355 }
356 };
357 }
358
359 (pull_product, push_product)
360 });
361
362 let (pull_product, push_product) = flush_iter
363 .peeking_take_while(|(flush, _)| flush.channel_id == channel_id)
364 .fold(boundary_products, |(pull_product, push_product), (flush, flush_product)| {
365 let flush_product_with_multiplicity =
366 flush_product.pow_vartime([flush.multiplicity]);
367 match flush.direction {
368 FlushDirection::Pull => {
369 (pull_product * flush_product_with_multiplicity, push_product)
370 }
371 FlushDirection::Push => {
372 (pull_product, push_product * flush_product_with_multiplicity)
373 }
374 }
375 });
376 if pull_product != push_product {
377 return Err(VerificationError::ChannelUnbalanced { id: channel_id }.into());
378 }
379 }
380
381 Ok(())
382}
383
384pub fn augment_flush_po2_step_down<F: TowerField>(
391 oracles: &mut MultilinearOracleSet<F>,
392 flushes: &mut [Flush<F>],
393 table_size_specs: &[TableSizeSpec],
394 table_sizes: &[usize],
395) -> Result<Vec<(OracleId, StepDown)>, Error> {
396 use std::collections::HashMap;
397
398 use crate::transparent::step_down::StepDown;
399
400 let mut step_down_oracles = HashMap::<(usize, usize), OracleId>::new();
402 let mut step_down_polys = Vec::new();
403
404 for flush in flushes.iter() {
406 let table_id = flush.table_id;
407 let table_size = table_sizes[table_id];
408 let table_spec = &table_size_specs[table_id];
409
410 if matches!(table_spec, TableSizeSpec::Arbitrary) {
412 let log_values_per_row = flush.log_values_per_row;
413 let key = (table_id, log_values_per_row);
414
415 if let Entry::Vacant(e) = step_down_oracles.entry(key) {
417 let log_capacity = log2_ceil_usize(table_size);
418 let n_vars = log_capacity + log_values_per_row;
419 let size = table_size << log_values_per_row;
420
421 let step_down_poly = StepDown::new(n_vars, size)?;
422 let oracle_id = oracles
423 .add_named(format!("stepdown_table_{table_id}_log_values_{log_values_per_row}"))
424 .transparent(step_down_poly.clone())?;
425
426 step_down_polys.push((oracle_id, step_down_poly));
427 e.insert(oracle_id);
428 }
429 }
430 }
431
432 for flush in flushes.iter_mut() {
434 let table_id = flush.table_id;
435 let table_spec = &table_size_specs[table_id];
436
437 if matches!(table_spec, TableSizeSpec::Arbitrary) {
438 let key = (table_id, flush.log_values_per_row);
439 if let Some(&oracle_id) = step_down_oracles.get(&key) {
440 flush.selectors.push(oracle_id);
441 }
442 }
443 }
444
445 Ok(step_down_polys)
446}
447
448pub fn make_flush_oracles<F: TowerField>(
453 oracles: &mut MultilinearOracleSet<F>,
454 flushes: &[Flush<F>],
455 mixing_challenge: F,
456 permutation_challenges: &[F],
457) -> Result<Vec<OracleId>, Error> {
458 let mut mixing_powers = vec![F::ONE];
459 let mut flush_iter = flushes.iter();
460
461 permutation_challenges
462 .iter()
463 .enumerate()
464 .flat_map(|(channel_id, permutation_challenge)| {
465 flush_iter
466 .peeking_take_while(|flush| flush.channel_id == channel_id)
467 .map(|flush| {
468 let mut non_const_oracles =
470 flush.oracles.iter().copied().filter_map(|id| match id {
471 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
472 _ => None,
473 });
474
475 let first_oracle = non_const_oracles.next().ok_or(Error::EmptyFlushOracles)?;
476 let n_vars = oracles.n_vars(first_oracle);
477
478 for selector_id in &flush.selectors {
479 let got_tower_level = oracles[*selector_id].tower_level;
480 if got_tower_level != 0 {
481 return Err(Error::FlushSelectorTowerLevel {
482 oracle: *selector_id,
483 got_tower_level,
484 });
485 }
486 }
487
488 for oracle_id in non_const_oracles {
489 let oracle_n_vars = oracles.n_vars(oracle_id);
490 if oracle_n_vars != n_vars {
491 return Err(Error::ChannelFlushNvarsMismatch {
492 expected: n_vars,
493 got: oracle_n_vars,
494 });
495 }
496 }
497
498 while mixing_powers.len() < flush.oracles.len() {
500 let last_power = *mixing_powers.last().expect(
501 "mixing_powers is initialized with one element; \
502 mixing_powers never shrinks; \
503 thus, it must not be empty",
504 );
505 mixing_powers.push(last_power * mixing_challenge);
506 }
507
508 let const_linear_combination = flush
509 .oracles
510 .iter()
511 .copied()
512 .zip(mixing_powers.iter())
513 .filter_map(|(id, coeff)| match id {
514 OracleOrConst::Const { base, .. } => Some(base * coeff),
515 _ => None,
516 })
517 .sum::<F>();
518
519 let poly = if flush.selectors.is_empty() {
520 oracles
521 .add_named(format!("flush channel_id={channel_id} linear combination"))
522 .linear_combination_with_offset(
523 n_vars,
524 *permutation_challenge + const_linear_combination,
525 flush
526 .oracles
527 .iter()
528 .zip(mixing_powers.iter().copied())
529 .filter_map(|(id, coeff)| match id {
530 OracleOrConst::Oracle(oracle_id) => {
531 Some((*oracle_id, coeff))
532 }
533 _ => None,
534 }),
535 )?
536 } else {
537 let offset = *permutation_challenge + const_linear_combination + F::ONE;
538 let arith_expr_linear = ArithExpr::Const(offset);
539 let var_offset = flush.selectors.len(); let (non_const_oracles, coeffs): (Vec<_>, Vec<_>) = flush
541 .oracles
542 .iter()
543 .zip(mixing_powers.iter().copied())
544 .filter_map(|(id, coeff)| match id {
545 OracleOrConst::Oracle(id) => Some((*id, coeff)),
546 _ => None,
547 })
548 .unzip();
549
550 let arith_expr_linear = coeffs.into_iter().enumerate().fold(
552 arith_expr_linear,
553 |linear, (offset, coeff)| {
554 linear
555 + ArithExpr::Var(offset + var_offset) * ArithExpr::Const(coeff)
556 },
557 );
558
559 let selector = (0..var_offset)
560 .map(ArithExpr::Var)
561 .product::<ArithExpr<F>>();
562
563 oracles
565 .add_named(format!("flush channel_id={channel_id} composite"))
566 .composite_mle(
567 n_vars,
568 flush.selectors.iter().copied().chain(non_const_oracles),
569 (ArithExpr::Const(F::ONE) + selector * arith_expr_linear).into(),
570 )?
571 };
572 Ok(poly)
573 })
574 .collect::<Vec<_>>()
575 })
576 .collect()
577}
578
579fn reduce_flush_evalcheck_claims<Tower: TowerFamily, Challenger_>(
580 claims: Vec<EvalcheckMultilinearClaim<FExt<Tower>>>,
581 oracles: &MultilinearOracleSet<FExt<Tower>>,
582 transcript: &mut VerifierTranscript<Challenger_>,
583) -> Result<Vec<EvalcheckMultilinearClaim<FExt<Tower>>>, Error>
584where
585 Challenger_: Challenger + Default,
586{
587 let mut linear_claims = Vec::new();
588
589 #[allow(clippy::type_complexity)]
590 let mut new_mlechecks_constraints: Vec<(
591 EvalPoint<FExt<Tower>>,
592 ConstraintSetBuilder<FExt<Tower>>,
593 )> = Vec::new();
594
595 for claim in &claims {
596 match &oracles[claim.id].variant {
597 MultilinearPolyVariant::LinearCombination(_) => linear_claims.push(claim.clone()),
598 MultilinearPolyVariant::Composite(composite) => {
599 let eval_point = claim.eval_point.clone();
600 let eval = claim.eval;
601
602 let position = new_mlechecks_constraints
603 .iter()
604 .position(|(ep, _)| *ep == eval_point)
605 .unwrap_or(new_mlechecks_constraints.len());
606
607 let oracle_ids = composite.inner().clone();
608
609 let exp = <_ as CompositionPoly<FExt<Tower>>>::expression(composite.c());
610 if let Some((_, constraint_builder)) = new_mlechecks_constraints.get_mut(position) {
611 constraint_builder.add_sumcheck(oracle_ids, exp, eval);
612 } else {
613 let mut new_builder = ConstraintSetBuilder::new();
614 new_builder.add_sumcheck(oracle_ids, exp, eval);
615 new_mlechecks_constraints.push((eval_point.clone(), new_builder));
616 }
617 }
618 _ => unreachable!(),
619 }
620 }
621
622 let new_mlechecks_constraints = new_mlechecks_constraints;
623
624 let mut eq_ind_challenges = Vec::with_capacity(new_mlechecks_constraints.len());
625 let mut constraint_sets = Vec::with_capacity(new_mlechecks_constraints.len());
626
627 for (ep, builder) in new_mlechecks_constraints {
628 eq_ind_challenges.push(ep.to_vec());
629 constraint_sets.push(builder.build_one(oracles)?)
630 }
631
632 let MLEcheckClaimsWithMeta {
633 claims: mlecheck_claims,
634 metas,
635 } = constraint_set_mlecheck_claims(constraint_sets)?;
636
637 let mut new_evalcheck_claims = Vec::new();
638
639 for (eq_ind_challenges, mlecheck_claim, meta) in
640 izip!(&eq_ind_challenges, mlecheck_claims, metas)
641 {
642 let mlecheck_claim = vec![mlecheck_claim];
643
644 let batch_sumcheck_verifier = front_loaded::BatchVerifier::new(
645 &reduce_to_regular_sumchecks(&mlecheck_claim)?,
646 transcript,
647 )?;
648 let mut sumcheck_output = batch_sumcheck_verifier.run(transcript)?;
649
650 sumcheck_output.challenges.reverse();
652
653 let eq_ind_output = eq_ind::verify_sumcheck_outputs(
654 ClaimsSortingOrder::AscendingVars,
655 &mlecheck_claim,
656 eq_ind_challenges,
657 sumcheck_output,
658 )?;
659
660 let evalcheck_claims =
661 sumcheck::make_eval_claims(EvaluationOrder::HighToLow, vec![meta], eq_ind_output)?;
662 new_evalcheck_claims.extend(evalcheck_claims)
663 }
664
665 Ok(chain!(new_evalcheck_claims.into_iter(), linear_claims.into_iter()).collect::<Vec<_>>())
666}