// binius_core/ring_switch/verify.rs

// Copyright 2024-2025 Irreducible Inc.

3use std::{iter, sync::Arc};
4
5use binius_field::{Field, TowerField};
6use binius_math::{
7	B1, B8, B16, B32, B64, B128, MultilinearExtension, MultilinearQuery, PackedTop, TowerTop,
8};
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use bytes::Buf;
11use itertools::izip;
12
13use super::eq_ind::RowBatchCoeffs;
14use crate::{
15	fiat_shamir::{CanSample, Challenger},
16	piop::PIOPSumcheckClaim,
17	polynomial::MultivariatePoly,
18	ring_switch::{
19		Error, EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
20		eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra,
21	},
22	transcript::{TranscriptReader, VerifierTranscript},
23};
24
25#[derive(Debug)]
26pub struct ReducedClaim<'a, F: Field> {
27	pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
28	pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
29}
30
31pub fn verify<'a, F, Challenger_>(
32	system: &'a EvalClaimSystem<F>,
33	transcript: &mut VerifierTranscript<Challenger_>,
34) -> Result<ReducedClaim<'a, F>, Error>
35where
36	F: TowerTop + PackedTop<Scalar = F>,
37	Challenger_: Challenger,
38{
39	// Sample enough randomness to batch tensor elements corresponding to claims that share an
40	// evaluation point prefix.
41	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
42	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
43	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
44
45	// For each evaluation point prefix, receive one batched tensor algebra element and verify
46	// that it is consistent with the evaluation claims.
47	let tensor_elems =
48		verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
49
50	// Sample the row-batching randomness.
51	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
52	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
53		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
54	));
55
56	// For each original evaluation claim, receive the row-batched evaluation claim.
57	let row_batched_evals = transcript
58		.message()
59		.read_scalar_slice(system.sumcheck_claim_descs.len())?;
60
61	// Check that the row-batched evaluation claims sent by the prover are consistent with the
62	// tensor algebra sum elements previously sent.
63	let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
64		row_batched_evals.iter().copied(),
65		system.prefix_descs.len(),
66		&system.eval_claim_to_prefix_desc_index,
67	);
68	for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
69		if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
70			return Err(VerificationError::IncorrectRowBatchedSum.into());
71		}
72	}
73
74	// Create the reduced PIOP sumcheck claims.
75	let ring_switch_eq_inds = make_ring_switch_eq_inds(
76		&system.sumcheck_claim_descs,
77		&system.suffix_descs,
78		&row_batch_coeffs,
79		&mixing_coeffs,
80	)?;
81	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
82		.enumerate()
83		.map(|(idx, (claim_desc, eval))| {
84			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
85			PIOPSumcheckClaim {
86				n_vars: suffix_desc.suffix.len(),
87				committed: claim_desc.committed_idx,
88				transparent: idx,
89				sum: eval,
90			}
91		})
92		.collect::<Vec<_>>();
93
94	Ok(ReducedClaim {
95		transparents: ring_switch_eq_inds,
96		sumcheck_claims,
97	})
98}
99
100fn verify_receive_tensor_elems<F, B>(
101	system: &EvalClaimSystem<F>,
102	mixing_coeffs: &[F],
103	transcript: &mut TranscriptReader<B>,
104) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
105where
106	F: TowerTop + PackedTop<Scalar = F>,
107	B: Buf,
108{
109	let expected_tensor_elem_evals = compute_mixed_evaluations(
110		system
111			.sumcheck_claim_descs
112			.iter()
113			.map(|desc| desc.eval_claim.eval),
114		system.prefix_descs.len(),
115		&system.eval_claim_to_prefix_desc_index,
116		mixing_coeffs,
117	);
118
119	let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
120	for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
121		let kappa = desc.kappa();
122		let tensor_elem =
123			TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
124
125		let query = MultilinearQuery::<F>::expand(&desc.prefix);
126		let tensor_elem_eval =
127			MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
128				.expect("tensor_elem has length 1 << kappa")
129				.evaluate(&query)
130				.expect("query has kappa variables");
131		if tensor_elem_eval != expected_eval {
132			return Err(VerificationError::IncorrectEvaluation.into());
133		}
134
135		tensor_elems.push(tensor_elem);
136	}
137
138	Ok(tensor_elems)
139}
140
141fn compute_mixed_evaluations<F: TowerField>(
142	evals: impl ExactSizeIterator<Item = F>,
143	n_prefixes: usize,
144	eval_claim_to_prefix_desc_index: &[usize],
145	mixing_coeffs: &[F],
146) -> Vec<F> {
147	// Pre-condition
148	debug_assert!(evals.len() <= mixing_coeffs.len());
149
150	accumulate_evaluations_by_prefixes(
151		iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
152		n_prefixes,
153		eval_claim_to_prefix_desc_index,
154	)
155}
156
157fn accumulate_evaluations_by_prefixes<F: TowerField>(
158	evals: impl ExactSizeIterator<Item = F>,
159	n_prefixes: usize,
160	eval_claim_to_prefix_desc_index: &[usize],
161) -> Vec<F> {
162	// Pre-condition
163	debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
164
165	let mut batched_evals = vec![F::ZERO; n_prefixes];
166	for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
167		batched_evals[desc_index] += eval;
168	}
169	batched_evals
170}
171
172fn make_ring_switch_eq_inds<F>(
173	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
174	suffix_descs: &[EvalClaimSuffixDesc<F>],
175	row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
176	mixing_coeffs: &[F],
177) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
178where
179	F: TowerTop + PackedTop<Scalar = F>,
180{
181	iter::zip(sumcheck_claim_descs, mixing_coeffs)
182		.map(|(claim_desc, &mixing_coeff)| {
183			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
184			make_ring_switch_eq_ind(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
185		})
186		.collect()
187}
188
189fn make_ring_switch_eq_ind<F>(
190	suffix_desc: &EvalClaimSuffixDesc<F>,
191	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
192	mixing_coeff: F,
193) -> Result<Box<dyn MultivariatePoly<F>>, Error>
194where
195	F: TowerTop + PackedTop<Scalar = F>,
196{
197	let eq_ind = match F::TOWER_LEVEL - suffix_desc.kappa {
198		0 => Box::new(RingSwitchEqInd::<B1, _>::new(
199			suffix_desc.suffix.clone(),
200			row_batch_coeffs,
201			mixing_coeff,
202		)?) as Box<dyn MultivariatePoly<_>>,
203		3 => Box::new(RingSwitchEqInd::<B8, _>::new(
204			suffix_desc.suffix.clone(),
205			row_batch_coeffs,
206			mixing_coeff,
207		)?) as Box<dyn MultivariatePoly<_>>,
208		4 => Box::new(RingSwitchEqInd::<B16, _>::new(
209			suffix_desc.suffix.clone(),
210			row_batch_coeffs,
211			mixing_coeff,
212		)?) as Box<dyn MultivariatePoly<_>>,
213		5 => Box::new(RingSwitchEqInd::<B32, _>::new(
214			suffix_desc.suffix.clone(),
215			row_batch_coeffs,
216			mixing_coeff,
217		)?) as Box<dyn MultivariatePoly<_>>,
218		6 => Box::new(RingSwitchEqInd::<B64, _>::new(
219			suffix_desc.suffix.clone(),
220			row_batch_coeffs,
221			mixing_coeff,
222		)?) as Box<dyn MultivariatePoly<_>>,
223		7 => Box::new(RingSwitchEqInd::<B128, _>::new(
224			suffix_desc.suffix.clone(),
225			row_batch_coeffs,
226			mixing_coeff,
227		)?) as Box<dyn MultivariatePoly<_>>,
228		_ => {
229			return Err(Error::PackingDegreeNotSupported {
230				kappa: suffix_desc.kappa,
231			});
232		}
233	};
234	Ok(eq_ind)
235}