1use std::{iter, sync::Arc};
4
5use binius_field::{
6 tower::{PackedTop, TowerFamily},
7 Field, TowerField,
8};
9use binius_math::{MultilinearExtension, MultilinearQuery};
10use binius_utils::checked_arithmetics::log2_ceil_usize;
11use bytes::Buf;
12use itertools::izip;
13
14use super::eq_ind::RowBatchCoeffs;
15use crate::{
16 fiat_shamir::{CanSample, Challenger},
17 piop::PIOPSumcheckClaim,
18 polynomial::MultivariatePoly,
19 ring_switch::{
20 eq_ind::RingSwitchEqInd, tower_tensor_algebra::TowerTensorAlgebra, Error,
21 EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError,
22 },
23 transcript::{TranscriptReader, VerifierTranscript},
24};
25
26type FExt<Tower> = <Tower as TowerFamily>::B128;
27
28#[derive(Debug)]
29pub struct ReducedClaim<'a, F: Field> {
30 pub transparents: Vec<Box<dyn MultivariatePoly<F> + 'a>>,
31 pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
32}
33
34pub fn verify<'a, F, Tower, Challenger_>(
35 system: &'a EvalClaimSystem<F>,
36 transcript: &mut VerifierTranscript<Challenger_>,
37) -> Result<ReducedClaim<'a, F>, Error>
38where
39 F: TowerField + PackedTop<Tower>,
40 Tower: TowerFamily<B128 = F>,
41 Challenger_: Challenger,
42{
43 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
46 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
47 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
48
49 let tensor_elems =
52 verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?;
53
54 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
56 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
57 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
58 ));
59
60 let row_batched_evals = transcript
62 .message()
63 .read_scalar_slice(system.sumcheck_claim_descs.len())?;
64
65 let mixed_row_batched_evals = accumulate_evaluations_by_prefixes(
68 row_batched_evals.iter().copied(),
69 system.prefix_descs.len(),
70 &system.eval_claim_to_prefix_desc_index,
71 );
72 for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) {
73 if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected {
74 return Err(VerificationError::IncorrectRowBatchedSum.into());
75 }
76 }
77
78 let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>(
80 &system.sumcheck_claim_descs,
81 &system.suffix_descs,
82 &row_batch_coeffs,
83 &mixing_coeffs,
84 )?;
85 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
86 .enumerate()
87 .map(|(idx, (claim_desc, eval))| {
88 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
89 PIOPSumcheckClaim {
90 n_vars: suffix_desc.suffix.len(),
91 committed: claim_desc.committed_idx,
92 transparent: idx,
93 sum: eval,
94 }
95 })
96 .collect::<Vec<_>>();
97
98 Ok(ReducedClaim {
99 transparents: ring_switch_eq_inds,
100 sumcheck_claims,
101 })
102}
103
104fn verify_receive_tensor_elems<F, Tower, B>(
105 system: &EvalClaimSystem<F>,
106 mixing_coeffs: &[F],
107 transcript: &mut TranscriptReader<B>,
108) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
109where
110 F: TowerField + PackedTop<Tower>,
111 Tower: TowerFamily<B128 = F>,
112 B: Buf,
113{
114 let expected_tensor_elem_evals = compute_mixed_evaluations(
115 system
116 .sumcheck_claim_descs
117 .iter()
118 .map(|desc| desc.eval_claim.eval),
119 system.prefix_descs.len(),
120 &system.eval_claim_to_prefix_desc_index,
121 mixing_coeffs,
122 );
123
124 let mut tensor_elems = Vec::with_capacity(system.prefix_descs.len());
125 for (desc, expected_eval) in iter::zip(&system.prefix_descs, expected_tensor_elem_evals) {
126 let kappa = desc.kappa();
127 let tensor_elem =
128 TowerTensorAlgebra::new(kappa, transcript.read_scalar_slice(1 << kappa)?)?;
129
130 let query = MultilinearQuery::<F>::expand(&desc.prefix);
131 let tensor_elem_eval =
132 MultilinearExtension::<F, _>::new(kappa, tensor_elem.vertical_elems())
133 .expect("tensor_elem has length 1 << kappa")
134 .evaluate(&query)
135 .expect("query has kappa variables");
136 if tensor_elem_eval != expected_eval {
137 return Err(VerificationError::IncorrectEvaluation.into());
138 }
139
140 tensor_elems.push(tensor_elem);
141 }
142
143 Ok(tensor_elems)
144}
145
146fn compute_mixed_evaluations<F: TowerField>(
147 evals: impl ExactSizeIterator<Item = F>,
148 n_prefixes: usize,
149 eval_claim_to_prefix_desc_index: &[usize],
150 mixing_coeffs: &[F],
151) -> Vec<F> {
152 debug_assert!(evals.len() <= mixing_coeffs.len());
154
155 accumulate_evaluations_by_prefixes(
156 iter::zip(evals, mixing_coeffs).map(|(eval, &mixing_coeff)| eval * mixing_coeff),
157 n_prefixes,
158 eval_claim_to_prefix_desc_index,
159 )
160}
161
162fn accumulate_evaluations_by_prefixes<F: TowerField>(
163 evals: impl ExactSizeIterator<Item = F>,
164 n_prefixes: usize,
165 eval_claim_to_prefix_desc_index: &[usize],
166) -> Vec<F> {
167 debug_assert_eq!(evals.len(), eval_claim_to_prefix_desc_index.len());
169
170 let mut batched_evals = vec![F::ZERO; n_prefixes];
171 for (eval, &desc_index) in izip!(evals, eval_claim_to_prefix_desc_index) {
172 batched_evals[desc_index] += eval;
173 }
174 batched_evals
175}
176
177fn make_ring_switch_eq_inds<F, Tower>(
178 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
179 suffix_descs: &[EvalClaimSuffixDesc<F>],
180 row_batch_coeffs: &Arc<RowBatchCoeffs<F>>,
181 mixing_coeffs: &[F],
182) -> Result<Vec<Box<dyn MultivariatePoly<F>>>, Error>
183where
184 F: TowerField + PackedTop<Tower>,
185 Tower: TowerFamily<B128 = F>,
186{
187 iter::zip(sumcheck_claim_descs, mixing_coeffs)
188 .map(|(claim_desc, &mixing_coeff)| {
189 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
190 make_ring_switch_eq_ind::<Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
191 })
192 .collect()
193}
194
195fn make_ring_switch_eq_ind<Tower>(
196 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
197 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
198 mixing_coeff: FExt<Tower>,
199) -> Result<Box<dyn MultivariatePoly<FExt<Tower>>>, Error>
200where
201 Tower: TowerFamily,
202 FExt<Tower>: PackedTop<Tower>,
203{
204 let eq_ind = match suffix_desc.kappa {
205 7 => Box::new(RingSwitchEqInd::<Tower::B1, _>::new(
206 suffix_desc.suffix.clone(),
207 row_batch_coeffs,
208 mixing_coeff,
209 )?) as Box<dyn MultivariatePoly<_>>,
210 4 => Box::new(RingSwitchEqInd::<Tower::B8, _>::new(
211 suffix_desc.suffix.clone(),
212 row_batch_coeffs,
213 mixing_coeff,
214 )?) as Box<dyn MultivariatePoly<_>>,
215 3 => Box::new(RingSwitchEqInd::<Tower::B16, _>::new(
216 suffix_desc.suffix.clone(),
217 row_batch_coeffs,
218 mixing_coeff,
219 )?) as Box<dyn MultivariatePoly<_>>,
220 2 => Box::new(RingSwitchEqInd::<Tower::B32, _>::new(
221 suffix_desc.suffix.clone(),
222 row_batch_coeffs,
223 mixing_coeff,
224 )?) as Box<dyn MultivariatePoly<_>>,
225 1 => Box::new(RingSwitchEqInd::<Tower::B64, _>::new(
226 suffix_desc.suffix.clone(),
227 row_batch_coeffs,
228 mixing_coeff,
229 )?) as Box<dyn MultivariatePoly<_>>,
230 0 => Box::new(RingSwitchEqInd::<Tower::B128, _>::new(
231 suffix_desc.suffix.clone(),
232 row_batch_coeffs,
233 mixing_coeff,
234 )?) as Box<dyn MultivariatePoly<_>>,
235 _ => {
236 return Err(Error::PackingDegreeNotSupported {
237 kappa: suffix_desc.kappa,
238 })
239 }
240 };
241 Ok(eq_ind)
242}