binius_core/ring_switch/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{
6	tower::{PackedTop, TowerFamily},
7	Field, TowerField,
8};
9use binius_math::{MultilinearExtension, MultilinearQuery};
10use binius_utils::checked_arithmetics::log2_ceil_usize;
11use bytes::Buf;
12use itertools::izip;
13
14use super::eq_ind::RowBatchCoeffs;
15use crate::{
16	fiat_shamir::{CanSample, Challenger},
17	piop::PIOPSumcheckClaim,
18	polynomial::MultivariatePoly,
19	ring_switch::{
20		eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra, Error,
21		EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
22	},
23	transcript::{TranscriptReader, VerifierTranscript},
24};
25
/// Alias for the `B128` associated type of the tower family `Tower`.
type FExt<Tower> = <Tower as TowerFamily>::B128;
27
28#[derive(Debug)]
29pub struct ReducedClaim<'a, F: Field> {
30	pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
31	pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
32}
33
34pub fn verify<'a, F, Tower, Challenger_>(
35	system: &'a EvalClaimSystem<F>,
36	transcript: &mut VerifierTranscript<Challenger_>,
37) -> Result<ReducedClaim<'a, F>, Error>
38where
39	F: TowerField + PackedTop<Tower>,
40	Tower: TowerFamily<B128 = F>,
41	Challenger_: Challenger,
42{
43	// Sample enough randomness to batch tensor elements corresponding to claims that share an
44	// evaluation point prefix.
45	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
46	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
47	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
48
49	// For each evaluation point prefix, receive one batched tensor algebra element and verify
50	// that it is consistent with the evaluation claims.
51	let tensor_elems =
52		verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
53
54	// Sample the row-batching randomness.
55	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
56	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
57		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
58	));
59
60	// For each original evaluation claim, receive the row-batched evaluation claim.
61	let row_batched_evals = transcript
62		.message()
63		.read_scalar_slice(system.sumcheck_claim_descs.len())?;
64
65	// Check that the row-batched evaluation claims sent by the prover are consistent with the
66	// tensor algebra sum elements previously sent.
67	let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
68		row_batched_evals.iter().copied(),
69		system.prefix_descs.len(),
70		&system.eval_claim_to_prefix_desc_index,
71	);
72	for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
73		if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
74			return Err(VerificationError::IncorrectRowBatchedSum.into());
75		}
76	}
77
78	// Create the reduced PIOP sumcheck claims.
79	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>(
80		&system.sumcheck_claim_descs,
81		&system.suffix_descs,
82		&row_batch_coeffs,
83		&mixing_coeffs,
84	)?;
85	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
86		.enumerate()
87		.map(|(idx, (claim_desc, eval))| {
88			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
89			PIOPSumcheckClaim {
90				n_vars: suffix_desc.suffix.len(),
91				committed: claim_desc.committed_idx,
92				transparent: idx,
93				sum: eval,
94			}
95		})
96		.collect::<Vec<_>>();
97
98	Ok(ReducedClaim {
99		transparents: ring_switch_eq_inds,
100		sumcheck_claims,
101	})
102}
103
104fn verify_receive_tensor_elems<F, Tower, B>(
105	system: &EvalClaimSystem<F>,
106	mixing_coeffs: &[F],
107	transcript: &mut TranscriptReader<B>,
108) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
109where
110	F: TowerField + PackedTop<Tower>,
111	Tower: TowerFamily<B128 = F>,
112	B: Buf,
113{
114	let expected_tensor_elem_evals = compute_mixed_evaluations(
115		system
116			.sumcheck_claim_descs
117			.iter()
118			.map(|desc| desc.eval_claim.eval),
119		system.prefix_descs.len(),
120		&system.eval_claim_to_prefix_desc_index,
121		mixing_coeffs,
122	);
123
124	let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
125	for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
126		let kappa = desc.kappa();
127		let tensor_elem =
128			TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
129
130		let query = MultilinearQuery::<F>::expand(&desc.prefix);
131		let tensor_elem_eval =
132			MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
133				.expect("tensor_elem has length 1 << kappa")
134				.evaluate(&query)
135				.expect("query has kappa variables");
136		if tensor_elem_eval != expected_eval {
137			return Err(VerificationError::IncorrectEvaluation.into());
138		}
139
140		tensor_elems.push(tensor_elem);
141	}
142
143	Ok(tensor_elems)
144}
145
146fn compute_mixed_evaluations<F: TowerField>(
147	evals: impl ExactSizeIterator<Item = F>,
148	n_prefixes: usize,
149	eval_claim_to_prefix_desc_index: &[usize],
150	mixing_coeffs: &[F],
151) -> Vec<F> {
152	// Pre-condition
153	debug_assert!(evals.len() <= mixing_coeffs.len());
154
155	accumulate_evaluations_by_prefixes(
156		iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
157		n_prefixes,
158		eval_claim_to_prefix_desc_index,
159	)
160}
161
162fn accumulate_evaluations_by_prefixes<F: TowerField>(
163	evals: impl ExactSizeIterator<Item = F>,
164	n_prefixes: usize,
165	eval_claim_to_prefix_desc_index: &[usize],
166) -> Vec<F> {
167	// Pre-condition
168	debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
169
170	let mut batched_evals = vec![F::ZERO; n_prefixes];
171	for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
172		batched_evals[desc_index] += eval;
173	}
174	batched_evals
175}
176
177fn make_ring_switch_eq_inds<F, Tower>(
178	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
179	suffix_descs: &[EvalClaimSuffixDesc<F>],
180	row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
181	mixing_coeffs: &[F],
182) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
183where
184	F: TowerField + PackedTop<Tower>,
185	Tower: TowerFamily<B128 = F>,
186{
187	iter::zip(sumcheck_claim_descs, mixing_coeffs)
188		.map(|(claim_desc, &mixing_coeff)| {
189			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
190			make_ring_switch_eq_ind::<Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
191		})
192		.collect()
193}
194
195fn make_ring_switch_eq_ind<Tower>(
196	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
197	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
198	mixing_coeff: FExt<Tower>,
199) -> Result<Box<dyn MultivariatePoly<FExt<Tower>>>, Error>
200where
201	Tower: TowerFamily,
202	FExt<Tower>: PackedTop<Tower>,
203{
204	let eq_ind = match suffix_desc.kappa {
205		7 => Box::new(RingSwitchEqInd::<Tower::B1, _>::new(
206			suffix_desc.suffix.clone(),
207			row_batch_coeffs,
208			mixing_coeff,
209		)?) as Box<dyn MultivariatePoly<_>>,
210		4 => Box::new(RingSwitchEqInd::<Tower::B8, _>::new(
211			suffix_desc.suffix.clone(),
212			row_batch_coeffs,
213			mixing_coeff,
214		)?) as Box<dyn MultivariatePoly<_>>,
215		3 => Box::new(RingSwitchEqInd::<Tower::B16, _>::new(
216			suffix_desc.suffix.clone(),
217			row_batch_coeffs,
218			mixing_coeff,
219		)?) as Box<dyn MultivariatePoly<_>>,
220		2 => Box::new(RingSwitchEqInd::<Tower::B32, _>::new(
221			suffix_desc.suffix.clone(),
222			row_batch_coeffs,
223			mixing_coeff,
224		)?) as Box<dyn MultivariatePoly<_>>,
225		1 => Box::new(RingSwitchEqInd::<Tower::B64, _>::new(
226			suffix_desc.suffix.clone(),
227			row_batch_coeffs,
228			mixing_coeff,
229		)?) as Box<dyn MultivariatePoly<_>>,
230		0 => Box::new(RingSwitchEqInd::<Tower::B128, _>::new(
231			suffix_desc.suffix.clone(),
232			row_batch_coeffs,
233			mixing_coeff,
234		)?) as Box<dyn MultivariatePoly<_>>,
235		_ => {
236			return Err(Error::PackingDegreeNotSupported {
237				kappa: suffix_desc.kappa,
238			})
239		}
240	};
241	Ok(eq_ind)
242}