1use std::{iter, sync::Arc};
4
5use binius_field::{Field, TowerField};
6use binius_math::{
7 B1, B8, B16, B32, B64, B128, MultilinearExtension, MultilinearQuery, PackedTop, TowerTop,
8};
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use bytes::Buf;
11use itertools::izip;
12
13use super::eq_ind::RowBatchCoeffs;
14use crate::{
15 fiat_shamir::{CanSample, Challenger},
16 piop::PIOPSumcheckClaim,
17 polynomial::MultivariatePoly,
18 ring_switch::{
19 Error, EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
20 eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra,
21 },
22 transcript::{TranscriptReader, VerifierTranscript},
23};
24
/// Result of the ring-switching verifier [`verify`]: transparent polynomials together with
/// the sumcheck claims that refer to them.
#[derive(Debug)]
pub struct ReducedClaim<'a, F: Field> {
    /// Transparent multivariate polynomials. As built by `verify`, this holds one
    /// ring-switching equality indicator per sumcheck claim, in claim order.
    pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
    /// Reduced sumcheck claims. Each claim's `transparent` field is an index into
    /// `transparents`.
    pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
}
30
31pub fn verify<'a, F, Challenger_>(
32 system: &'a EvalClaimSystem<F>,
33 transcript: &mut VerifierTranscript<Challenger_>,
34) -> Result<ReducedClaim<'a, F>, Error>
35where
36 F: TowerTop + PackedTop<Scalar = F>,
37 Challenger_: Challenger,
38{
39 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
42 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
43 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
44
45 let tensor_elems =
48 verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
49
50 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
52 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
53 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
54 ));
55
56 let row_batched_evals = transcript
58 .message()
59 .read_scalar_slice(system.sumcheck_claim_descs.len())?;
60
61 let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
64 row_batched_evals.iter().copied(),
65 system.prefix_descs.len(),
66 &system.eval_claim_to_prefix_desc_index,
67 );
68 for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
69 if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
70 return Err(VerificationError::IncorrectRowBatchedSum.into());
71 }
72 }
73
74 let ring_switch_eq_inds = make_ring_switch_eq_inds(
76 &system.sumcheck_claim_descs,
77 &system.suffix_descs,
78 &row_batch_coeffs,
79 &mixing_coeffs,
80 )?;
81 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
82 .enumerate()
83 .map(|(idx, (claim_desc, eval))| {
84 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
85 PIOPSumcheckClaim {
86 n_vars: suffix_desc.suffix.len(),
87 committed: claim_desc.committed_idx,
88 transparent: idx,
89 sum: eval,
90 }
91 })
92 .collect::<Vec<_>>();
93
94 Ok(ReducedClaim {
95 transparents: ring_switch_eq_inds,
96 sumcheck_claims,
97 })
98}
99
100fn verify_receive_tensor_elems<F, B>(
101 system: &EvalClaimSystem<F>,
102 mixing_coeffs: &[F],
103 transcript: &mut TranscriptReader<B>,
104) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
105where
106 F: TowerTop + PackedTop<Scalar = F>,
107 B: Buf,
108{
109 let expected_tensor_elem_evals = compute_mixed_evaluations(
110 system
111 .sumcheck_claim_descs
112 .iter()
113 .map(|desc| desc.eval_claim.eval),
114 system.prefix_descs.len(),
115 &system.eval_claim_to_prefix_desc_index,
116 mixing_coeffs,
117 );
118
119 let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
120 for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
121 let kappa = desc.kappa();
122 let tensor_elem =
123 TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
124
125 let query = MultilinearQuery::<F>::expand(&desc.prefix);
126 let tensor_elem_eval =
127 MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
128 .expect("tensor_elem has length 1 << kappa")
129 .evaluate(&query)
130 .expect("query has kappa variables");
131 if tensor_elem_eval != expected_eval {
132 return Err(VerificationError::IncorrectEvaluation.into());
133 }
134
135 tensor_elems.push(tensor_elem);
136 }
137
138 Ok(tensor_elems)
139}
140
141fn compute_mixed_evaluations<F: TowerField>(
142 evals: impl ExactSizeIterator<Item = F>,
143 n_prefixes: usize,
144 eval_claim_to_prefix_desc_index: &[usize],
145 mixing_coeffs: &[F],
146) -> Vec<F> {
147 debug_assert!(evals.len() <= mixing_coeffs.len());
149
150 accumulate_evaluations_by_prefixes(
151 iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
152 n_prefixes,
153 eval_claim_to_prefix_desc_index,
154 )
155}
156
157fn accumulate_evaluations_by_prefixes<F: TowerField>(
158 evals: impl ExactSizeIterator<Item = F>,
159 n_prefixes: usize,
160 eval_claim_to_prefix_desc_index: &[usize],
161) -> Vec<F> {
162 debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
164
165 let mut batched_evals = vec![F::ZERO; n_prefixes];
166 for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
167 batched_evals[desc_index] += eval;
168 }
169 batched_evals
170}
171
172fn make_ring_switch_eq_inds<F>(
173 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
174 suffix_descs: &[EvalClaimSuffixDesc<F>],
175 row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
176 mixing_coeffs: &[F],
177) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
178where
179 F: TowerTop + PackedTop<Scalar = F>,
180{
181 iter::zip(sumcheck_claim_descs, mixing_coeffs)
182 .map(|(claim_desc, &mixing_coeff)| {
183 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
184 make_ring_switch_eq_ind(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
185 })
186 .collect()
187}
188
189fn make_ring_switch_eq_ind<F>(
190 suffix_desc: &EvalClaimSuffixDesc<F>,
191 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
192 mixing_coeff: F,
193) -> Result<Box<dyn MultivariatePoly<F>>, Error>
194where
195 F: TowerTop + PackedTop<Scalar = F>,
196{
197 let eq_ind = match F::TOWER_LEVEL - suffix_desc.kappa {
198 0 => Box::new(RingSwitchEqInd::<B1, _>::new(
199 suffix_desc.suffix.clone(),
200 row_batch_coeffs,
201 mixing_coeff,
202 )?) as Box<dyn MultivariatePoly<_>>,
203 3 => Box::new(RingSwitchEqInd::<B8, _>::new(
204 suffix_desc.suffix.clone(),
205 row_batch_coeffs,
206 mixing_coeff,
207 )?) as Box<dyn MultivariatePoly<_>>,
208 4 => Box::new(RingSwitchEqInd::<B16, _>::new(
209 suffix_desc.suffix.clone(),
210 row_batch_coeffs,
211 mixing_coeff,
212 )?) as Box<dyn MultivariatePoly<_>>,
213 5 => Box::new(RingSwitchEqInd::<B32, _>::new(
214 suffix_desc.suffix.clone(),
215 row_batch_coeffs,
216 mixing_coeff,
217 )?) as Box<dyn MultivariatePoly<_>>,
218 6 => Box::new(RingSwitchEqInd::<B64, _>::new(
219 suffix_desc.suffix.clone(),
220 row_batch_coeffs,
221 mixing_coeff,
222 )?) as Box<dyn MultivariatePoly<_>>,
223 7 => Box::new(RingSwitchEqInd::<B128, _>::new(
224 suffix_desc.suffix.clone(),
225 row_batch_coeffs,
226 mixing_coeff,
227 )?) as Box<dyn MultivariatePoly<_>>,
228 _ => {
229 return Err(Error::PackingDegreeNotSupported {
230 kappa: suffix_desc.kappa,
231 });
232 }
233 };
234 Ok(eq_ind)
235}