binius_core/ring_switch/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{Field, TowerField};
6use binius_math::{MultilinearExtension, MultilinearQuery};
7use binius_utils::checked_arithmetics::log2_ceil_usize;
8use bytes::Buf;
9use itertools::izip;
10
11use super::eq_ind::RowBatchCoeffs;
12use crate::{
13	fiat_shamir::{CanSample, Challenger},
14	piop::PIOPSumcheckClaim,
15	polynomial::MultivariatePoly,
16	ring_switch::{
17		eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra, Error,
18		EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
19	},
20	tower::{PackedTop, TowerFamily},
21	transcript::{TranscriptReader, VerifierTranscript},
22};
23
24type FExt<Tower> = <Tower as TowerFamily>::B128;
25
26#[derive(Debug)]
27pub struct ReducedClaim<'a, F: Field> {
28	pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
29	pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
30}
31
32pub fn verify<'a, F, Tower, Challenger_>(
33	system: &'a EvalClaimSystem<F>,
34	transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<ReducedClaim<'a, F>, Error>
36where
37	F: TowerField + PackedTop<Tower>,
38	Tower: TowerFamily<B128 = F>,
39	Challenger_: Challenger,
40{
41	// Sample enough randomness to batch tensor elements corresponding to claims that share an
42	// evaluation point prefix.
43	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
44	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
45	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
46
47	// For each evaluation point prefix, receive one batched tensor algebra element and verify
48	// that it is consistent with the evaluation claims.
49	let tensor_elems =
50		verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
51
52	// Sample the row-batching randomness.
53	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
54	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
55		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
56	));
57
58	// For each original evaluation claim, receive the row-batched evaluation claim.
59	let row_batched_evals = transcript
60		.message()
61		.read_scalar_slice(system.sumcheck_claim_descs.len())?;
62
63	// Check that the row-batched evaluation claims sent by the prover are consistent with the
64	// tensor algebra sum elements previously sent.
65	let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
66		row_batched_evals.iter().copied(),
67		system.prefix_descs.len(),
68		&system.eval_claim_to_prefix_desc_index,
69	);
70	for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
71		if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
72			return Err(VerificationError::IncorrectRowBatchedSum.into());
73		}
74	}
75
76	// Create the reduced PIOP sumcheck claims.
77	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>(
78		&system.sumcheck_claim_descs,
79		&system.suffix_descs,
80		&row_batch_coeffs,
81		&mixing_coeffs,
82	)?;
83	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
84		.enumerate()
85		.map(|(idx, (claim_desc, eval))| {
86			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
87			PIOPSumcheckClaim {
88				n_vars: suffix_desc.suffix.len(),
89				committed: claim_desc.committed_idx,
90				transparent: idx,
91				sum: eval,
92			}
93		})
94		.collect::<Vec<_>>();
95
96	Ok(ReducedClaim {
97		transparents: ring_switch_eq_inds,
98		sumcheck_claims,
99	})
100}
101
102fn verify_receive_tensor_elems<F, Tower, B>(
103	system: &EvalClaimSystem<F>,
104	mixing_coeffs: &[F],
105	transcript: &mut TranscriptReader<B>,
106) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
107where
108	F: TowerField + PackedTop<Tower>,
109	Tower: TowerFamily<B128 = F>,
110	B: Buf,
111{
112	let expected_tensor_elem_evals = compute_mixed_evaluations(
113		system
114			.sumcheck_claim_descs
115			.iter()
116			.map(|desc| desc.eval_claim.eval),
117		system.prefix_descs.len(),
118		&system.eval_claim_to_prefix_desc_index,
119		mixing_coeffs,
120	);
121
122	let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
123	for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
124		let kappa = desc.kappa();
125		let tensor_elem =
126			TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
127
128		let query = MultilinearQuery::<F>::expand(&desc.prefix);
129		let tensor_elem_eval =
130			MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
131				.expect("tensor_elem has length 1 << kappa")
132				.evaluate(&query)
133				.expect("query has kappa variables");
134		if tensor_elem_eval != expected_eval {
135			return Err(VerificationError::IncorrectEvaluation.into());
136		}
137
138		tensor_elems.push(tensor_elem);
139	}
140
141	Ok(tensor_elems)
142}
143
144fn compute_mixed_evaluations<F: TowerField>(
145	evals: impl ExactSizeIterator<Item = F>,
146	n_prefixes: usize,
147	eval_claim_to_prefix_desc_index: &[usize],
148	mixing_coeffs: &[F],
149) -> Vec<F> {
150	// Pre-condition
151	debug_assert!(evals.len() <= mixing_coeffs.len());
152
153	accumulate_evaluations_by_prefixes(
154		iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
155		n_prefixes,
156		eval_claim_to_prefix_desc_index,
157	)
158}
159
160fn accumulate_evaluations_by_prefixes<F: TowerField>(
161	evals: impl ExactSizeIterator<Item = F>,
162	n_prefixes: usize,
163	eval_claim_to_prefix_desc_index: &[usize],
164) -> Vec<F> {
165	// Pre-condition
166	debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
167
168	let mut batched_evals = vec![F::ZERO; n_prefixes];
169	for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
170		batched_evals[desc_index] += eval;
171	}
172	batched_evals
173}
174
175fn make_ring_switch_eq_inds<F, Tower>(
176	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
177	suffix_descs: &[EvalClaimSuffixDesc<F>],
178	row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
179	mixing_coeffs: &[F],
180) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
181where
182	F: TowerField + PackedTop<Tower>,
183	Tower: TowerFamily<B128 = F>,
184{
185	iter::zip(sumcheck_claim_descs, mixing_coeffs)
186		.map(|(claim_desc, &mixing_coeff)| {
187			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
188			make_ring_switch_eq_ind::<Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
189		})
190		.collect()
191}
192
193fn make_ring_switch_eq_ind<Tower>(
194	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
195	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
196	mixing_coeff: FExt<Tower>,
197) -> Result<Box<dyn MultivariatePoly<FExt<Tower>>>, Error>
198where
199	Tower: TowerFamily,
200	FExt<Tower>: PackedTop<Tower>,
201{
202	let eq_ind = match suffix_desc.kappa {
203		7 => Box::new(RingSwitchEqInd::<Tower::B1, _>::new(
204			suffix_desc.suffix.clone(),
205			row_batch_coeffs,
206			mixing_coeff,
207		)?) as Box<dyn MultivariatePoly<_>>,
208		4 => Box::new(RingSwitchEqInd::<Tower::B8, _>::new(
209			suffix_desc.suffix.clone(),
210			row_batch_coeffs,
211			mixing_coeff,
212		)?) as Box<dyn MultivariatePoly<_>>,
213		3 => Box::new(RingSwitchEqInd::<Tower::B16, _>::new(
214			suffix_desc.suffix.clone(),
215			row_batch_coeffs,
216			mixing_coeff,
217		)?) as Box<dyn MultivariatePoly<_>>,
218		2 => Box::new(RingSwitchEqInd::<Tower::B32, _>::new(
219			suffix_desc.suffix.clone(),
220			row_batch_coeffs,
221			mixing_coeff,
222		)?) as Box<dyn MultivariatePoly<_>>,
223		1 => Box::new(RingSwitchEqInd::<Tower::B64, _>::new(
224			suffix_desc.suffix.clone(),
225			row_batch_coeffs,
226			mixing_coeff,
227		)?) as Box<dyn MultivariatePoly<_>>,
228		0 => Box::new(RingSwitchEqInd::<Tower::B128, _>::new(
229			suffix_desc.suffix.clone(),
230			row_batch_coeffs,
231			mixing_coeff,
232		)?) as Box<dyn MultivariatePoly<_>>,
233		_ => {
234			return Err(Error::PackingDegreeNotSupported {
235				kappa: suffix_desc.kappa,
236			})
237		}
238	};
239	Ok(eq_ind)
240}