1use std::{iter, sync::Arc};
4
5use binius_field::{Field, TowerField};
6use binius_math::{MultilinearExtension, MultilinearQuery};
7use binius_utils::checked_arithmetics::log2_ceil_usize;
8use bytes::Buf;
9use itertools::izip;
10
11use super::eq_ind::RowBatchCoeffs;
12use crate::{
13 fiat_shamir::{CanSample, Challenger},
14 piop::PIOPSumcheckClaim,
15 polynomial::MultivariatePoly,
16 ring_switch::{
17 eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra, Error,
18 EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
19 },
20 tower::{PackedTop, TowerFamily},
21 transcript::{TranscriptReader, VerifierTranscript},
22};
23
24type FExt<Tower> = <Tower as TowerFamily>::B128;
25
26#[derive(Debug)]
27pub struct ReducedClaim<'a, F: Field> {
28 pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
29 pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
30}
31
32pub fn verify<'a, F, Tower, Challenger_>(
33 system: &'a EvalClaimSystem<F>,
34 transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<ReducedClaim<'a, F>, Error>
36where
37 F: TowerField + PackedTop<Tower>,
38 Tower: TowerFamily<B128 = F>,
39 Challenger_: Challenger,
40{
41 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
44 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
45 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
46
47 let tensor_elems =
50 verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
51
52 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
54 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
55 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
56 ));
57
58 let row_batched_evals = transcript
60 .message()
61 .read_scalar_slice(system.sumcheck_claim_descs.len())?;
62
63 let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
66 row_batched_evals.iter().copied(),
67 system.prefix_descs.len(),
68 &system.eval_claim_to_prefix_desc_index,
69 );
70 for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
71 if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
72 return Err(VerificationError::IncorrectRowBatchedSum.into());
73 }
74 }
75
76 let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>(
78 &system.sumcheck_claim_descs,
79 &system.suffix_descs,
80 &row_batch_coeffs,
81 &mixing_coeffs,
82 )?;
83 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
84 .enumerate()
85 .map(|(idx, (claim_desc, eval))| {
86 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
87 PIOPSumcheckClaim {
88 n_vars: suffix_desc.suffix.len(),
89 committed: claim_desc.committed_idx,
90 transparent: idx,
91 sum: eval,
92 }
93 })
94 .collect::<Vec<_>>();
95
96 Ok(ReducedClaim {
97 transparents: ring_switch_eq_inds,
98 sumcheck_claims,
99 })
100}
101
102fn verify_receive_tensor_elems<F, Tower, B>(
103 system: &EvalClaimSystem<F>,
104 mixing_coeffs: &[F],
105 transcript: &mut TranscriptReader<B>,
106) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
107where
108 F: TowerField + PackedTop<Tower>,
109 Tower: TowerFamily<B128 = F>,
110 B: Buf,
111{
112 let expected_tensor_elem_evals = compute_mixed_evaluations(
113 system
114 .sumcheck_claim_descs
115 .iter()
116 .map(|desc| desc.eval_claim.eval),
117 system.prefix_descs.len(),
118 &system.eval_claim_to_prefix_desc_index,
119 mixing_coeffs,
120 );
121
122 let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
123 for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
124 let kappa = desc.kappa();
125 let tensor_elem =
126 TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
127
128 let query = MultilinearQuery::<F>::expand(&desc.prefix);
129 let tensor_elem_eval =
130 MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
131 .expect("tensor_elem has length 1 << kappa")
132 .evaluate(&query)
133 .expect("query has kappa variables");
134 if tensor_elem_eval != expected_eval {
135 return Err(VerificationError::IncorrectEvaluation.into());
136 }
137
138 tensor_elems.push(tensor_elem);
139 }
140
141 Ok(tensor_elems)
142}
143
144fn compute_mixed_evaluations<F: TowerField>(
145 evals: impl ExactSizeIterator<Item = F>,
146 n_prefixes: usize,
147 eval_claim_to_prefix_desc_index: &[usize],
148 mixing_coeffs: &[F],
149) -> Vec<F> {
150 debug_assert!(evals.len() <= mixing_coeffs.len());
152
153 accumulate_evaluations_by_prefixes(
154 iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
155 n_prefixes,
156 eval_claim_to_prefix_desc_index,
157 )
158}
159
160fn accumulate_evaluations_by_prefixes<F: TowerField>(
161 evals: impl ExactSizeIterator<Item = F>,
162 n_prefixes: usize,
163 eval_claim_to_prefix_desc_index: &[usize],
164) -> Vec<F> {
165 debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
167
168 let mut batched_evals = vec![F::ZERO; n_prefixes];
169 for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
170 batched_evals[desc_index] += eval;
171 }
172 batched_evals
173}
174
175fn make_ring_switch_eq_inds<F, Tower>(
176 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
177 suffix_descs: &[EvalClaimSuffixDesc<F>],
178 row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
179 mixing_coeffs: &[F],
180) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
181where
182 F: TowerField + PackedTop<Tower>,
183 Tower: TowerFamily<B128 = F>,
184{
185 iter::zip(sumcheck_claim_descs, mixing_coeffs)
186 .map(|(claim_desc, &mixing_coeff)| {
187 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
188 make_ring_switch_eq_ind::<Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
189 })
190 .collect()
191}
192
193fn make_ring_switch_eq_ind<Tower>(
194 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
195 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
196 mixing_coeff: FExt<Tower>,
197) -> Result<Box<dyn MultivariatePoly<FExt<Tower>>>, Error>
198where
199 Tower: TowerFamily,
200 FExt<Tower>: PackedTop<Tower>,
201{
202 let eq_ind = match suffix_desc.kappa {
203 7 => Box::new(RingSwitchEqInd::<Tower::B1, _>::new(
204 suffix_desc.suffix.clone(),
205 row_batch_coeffs,
206 mixing_coeff,
207 )?) as Box<dyn MultivariatePoly<_>>,
208 4 => Box::new(RingSwitchEqInd::<Tower::B8, _>::new(
209 suffix_desc.suffix.clone(),
210 row_batch_coeffs,
211 mixing_coeff,
212 )?) as Box<dyn MultivariatePoly<_>>,
213 3 => Box::new(RingSwitchEqInd::<Tower::B16, _>::new(
214 suffix_desc.suffix.clone(),
215 row_batch_coeffs,
216 mixing_coeff,
217 )?) as Box<dyn MultivariatePoly<_>>,
218 2 => Box::new(RingSwitchEqInd::<Tower::B32, _>::new(
219 suffix_desc.suffix.clone(),
220 row_batch_coeffs,
221 mixing_coeff,
222 )?) as Box<dyn MultivariatePoly<_>>,
223 1 => Box::new(RingSwitchEqInd::<Tower::B64, _>::new(
224 suffix_desc.suffix.clone(),
225 row_batch_coeffs,
226 mixing_coeff,
227 )?) as Box<dyn MultivariatePoly<_>>,
228 0 => Box::new(RingSwitchEqInd::<Tower::B128, _>::new(
229 suffix_desc.suffix.clone(),
230 row_batch_coeffs,
231 mixing_coeff,
232 )?) as Box<dyn MultivariatePoly<_>>,
233 _ => {
234 return Err(Error::PackingDegreeNotSupported {
235 kappa: suffix_desc.kappa,
236 })
237 }
238 };
239 Ok(eq_ind)
240}