1use std::iter;
4
5use binius_field::{BinaryField, ExtensionField, TowerField};
6use binius_hal::{make_portable_backend, ComputationBackend};
7use binius_ntt::SingleThreadedNTT;
8use binius_utils::{bail, DeserializeBytes};
9use bytes::Buf;
10use itertools::izip;
11use tracing::instrument;
12
13use super::{common::vcs_optimal_layers_depths_iter, error::Error, VerificationError};
14use crate::{
15 fiat_shamir::{CanSampleBits, Challenger},
16 merkle_tree::MerkleTreeScheme,
17 protocols::fri::common::{fold_chunk, fold_interleaved_chunk, FRIParams},
18 transcript::{TranscriptReader, VerifierTranscript},
19};
20
/// Verifier for the FRI protocol over an interleaved codeword of `params.rs_code()`.
///
/// Holds references to the prover's commitments and to the verifier's challenges for a single
/// proof; the decommitment data itself is read from a transcript by `FRIVerifier::verify`.
#[derive(Debug)]
pub struct FRIVerifier<'a, F, FA, VCS>
where
    F: BinaryField + ExtensionField<FA>,
    FA: BinaryField,
    VCS: MerkleTreeScheme<F>,
{
    /// Vector commitment scheme used to check layer digests, the terminate codeword, and the
    /// per-query coset openings.
    vcs: &'a VCS,
    /// FRI protocol parameters (code, batch size, fold arities, number of test queries).
    params: &'a FRIParams<F, FA>,
    /// Commitment to the original interleaved codeword.
    codeword_commitment: &'a VCS::Digest,
    /// Commitments to the oracles sent during the fold rounds; `FRIVerifier::new` checks there
    /// are exactly `params.n_oracles()` of them.
    round_commitments: &'a [VCS::Digest],
    /// Full tensor expansion of the first `params.log_batch_size()` challenges, used to combine
    /// the interleaved codeword into a single one.
    interleave_tensor: Vec<F>,
    /// The challenges remaining after the interleave challenges, consumed by the fold rounds.
    fold_challenges: &'a [F],
}
43
44impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS>
45where
46 F: TowerField + ExtensionField<FA>,
47 FA: BinaryField,
48 VCS: MerkleTreeScheme<F, Digest: DeserializeBytes>,
49{
50 #[allow(clippy::too_many_arguments)]
51 pub fn new(
52 params: &'a FRIParams<F, FA>,
53 vcs: &'a VCS,
54 codeword_commitment: &'a VCS::Digest,
55 round_commitments: &'a [VCS::Digest],
56 challenges: &'a [F],
57 ) -> Result<Self, Error> {
58 if round_commitments.len() != params.n_oracles() {
59 bail!(Error::InvalidArgs(format!(
60 "got {} round commitments, expected {}",
61 round_commitments.len(),
62 params.n_oracles(),
63 )));
64 }
65
66 if challenges.len() != params.n_fold_rounds() {
67 bail!(Error::InvalidArgs(format!(
68 "got {} folding challenges, expected {}",
69 challenges.len(),
70 params.n_fold_rounds(),
71 )));
72 }
73
74 let (interleave_challenges, fold_challenges) = challenges.split_at(params.log_batch_size());
75
76 let backend = make_portable_backend();
77 let interleave_tensor = backend
78 .tensor_product_full_query(interleave_challenges)
79 .expect("number of challenges is less than 32");
80
81 Ok(Self {
82 params,
83 vcs,
84 codeword_commitment,
85 round_commitments,
86 interleave_tensor,
87 fold_challenges,
88 })
89 }
90
91 pub fn n_oracles(&self) -> usize {
93 self.params.n_oracles()
94 }
95
96 pub fn verify<Challenger_>(
97 &self,
98 transcript: &mut VerifierTranscript<Challenger_>,
99 ) -> Result<F, Error>
100 where
101 Challenger_: Challenger,
102 {
103 let ntt = SingleThreadedNTT::with_subspace(self.params.rs_code().subspace())?;
104
105 let terminate_codeword_len =
107 1 << (self.params.n_final_challenges() + self.params.rs_code().log_inv_rate());
108 let mut advice = transcript.decommitment();
109 let terminate_codeword = advice
110 .read_scalar_slice(terminate_codeword_len)
111 .map_err(Error::TranscriptError)?;
112 let final_value = self.verify_last_oracle(&ntt, &terminate_codeword)?;
113
114 let layers = vcs_optimal_layers_depths_iter(self.params, self.vcs)
116 .map(|layer_depth| advice.read_vec(1 << layer_depth))
117 .collect::<Result<Vec<_>, _>>()?;
118 for (commitment, layer_depth, layer) in izip!(
119 iter::once(self.codeword_commitment).chain(self.round_commitments),
120 vcs_optimal_layers_depths_iter(self.params, self.vcs),
121 &layers
122 ) {
123 self.vcs
124 .verify_layer(commitment, layer_depth, layer)
125 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
126 }
127
128 let mut scratch_buffer = self.create_scratch_buffer();
131 for _ in 0..self.params.n_test_queries() {
132 let index = transcript.sample_bits(self.params.index_bits()) as usize;
133 self.verify_query_internal(
134 index,
135 &ntt,
136 &terminate_codeword,
137 &layers,
138 &mut transcript.decommitment(),
139 &mut scratch_buffer,
140 )?
141 }
142
143 Ok(final_value)
144 }
145
146 pub fn verify_last_oracle(
150 &self,
151 ntt: &SingleThreadedNTT<FA>,
152 terminate_codeword: &[F],
153 ) -> Result<F, Error> {
154 let n_final_challenges = self.params.n_final_challenges();
155
156 self.vcs
157 .verify_vector(
158 self.round_commitments
159 .last()
160 .unwrap_or(self.codeword_commitment),
161 terminate_codeword,
162 1 << n_final_challenges,
163 )
164 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
165
166 let repetition_codeword = if self.n_oracles() != 0 {
167 let n_prior_challenges = self.fold_challenges.len() - n_final_challenges;
168 let final_challenges = &self.fold_challenges[n_prior_challenges..];
169 let mut scratch_buffer = vec![F::default(); 1 << n_final_challenges];
170
171 terminate_codeword
172 .chunks(1 << n_final_challenges)
173 .enumerate()
174 .map(|(i, coset_values)| {
175 scratch_buffer.copy_from_slice(coset_values);
176 fold_chunk(
177 ntt,
178 n_final_challenges + self.params.rs_code().log_inv_rate(),
179 i,
180 &mut scratch_buffer,
181 final_challenges,
182 )
183 })
184 .collect::<Vec<_>>()
185 } else {
186 let fold_arity = self.params.rs_code().log_dim() + self.params.log_batch_size();
190 let mut scratch_buffer = vec![F::default(); 1 << self.params.rs_code().log_dim()];
191 terminate_codeword
192 .chunks(1 << fold_arity)
193 .enumerate()
194 .map(|(i, chunk)| {
195 fold_interleaved_chunk(
196 ntt,
197 self.params.rs_code().log_len(),
198 self.params.log_batch_size(),
199 i,
200 chunk,
201 &self.interleave_tensor,
202 self.fold_challenges,
203 &mut scratch_buffer,
204 )
205 })
206 .collect::<Vec<_>>()
207 };
208
209 let final_value = repetition_codeword[0];
210
211 if repetition_codeword[1..]
213 .iter()
214 .any(|&entry| entry != final_value)
215 {
216 return Err(VerificationError::IncorrectDegree.into());
217 }
218
219 Ok(final_value)
220 }
221
222 pub fn verify_query<B: Buf>(
233 &self,
234 index: usize,
235 ntt: &SingleThreadedNTT<FA>,
236 terminate_codeword: &[F],
237 layers: &[Vec<VCS::Digest>],
238 advice: &mut TranscriptReader<B>,
239 ) -> Result<(), Error> {
240 self.verify_query_internal(
241 index,
242 ntt,
243 terminate_codeword,
244 layers,
245 advice,
246 &mut self.create_scratch_buffer(),
247 )
248 }
249
250 #[instrument(skip_all, name = "fri::FRIVerifier::verify_query", level = "debug")]
251 fn verify_query_internal<B: Buf>(
252 &self,
253 mut index: usize,
254 ntt: &SingleThreadedNTT<FA>,
255 terminate_codeword: &[F],
256 layers: &[Vec<VCS::Digest>],
257 advice: &mut TranscriptReader<B>,
258 scratch_buffer: &mut [F],
259 ) -> Result<(), Error> {
260 let mut arities_iter = self.params.fold_arities().iter().copied();
261
262 let mut layer_digest_and_optimal_layer_depth =
263 iter::zip(layers, vcs_optimal_layers_depths_iter(self.params, self.vcs));
264
265 let Some(first_fold_arity) = arities_iter.next() else {
266 return Ok(());
270 };
271
272 let (first_layer, first_optimal_layer_depth) = layer_digest_and_optimal_layer_depth
273 .next()
274 .expect("The length should be the same as the amount of proofs.");
275
276 let mut fold_round = 0;
278 let mut log_n_cosets = self.params.index_bits();
279
280 let log_coset_size = first_fold_arity - self.params.log_batch_size();
283 let values = verify_coset_opening(
284 self.vcs,
285 index,
286 first_fold_arity,
287 first_optimal_layer_depth,
288 log_n_cosets,
289 first_layer,
290 advice,
291 )?;
292 let mut next_value = fold_interleaved_chunk(
293 ntt,
294 self.params.rs_code().log_len(),
295 self.params.log_batch_size(),
296 index,
297 &values,
298 &self.interleave_tensor,
299 &self.fold_challenges[fold_round..fold_round + log_coset_size],
300 scratch_buffer,
301 );
302 fold_round += log_coset_size;
303
304 for (i, (arity, (layer, optimal_layer_depth))) in
305 izip!(arities_iter, layer_digest_and_optimal_layer_depth).enumerate()
306 {
307 let coset_index = index >> arity;
308
309 log_n_cosets -= arity;
310
311 let mut values = verify_coset_opening(
312 self.vcs,
313 coset_index,
314 arity,
315 optimal_layer_depth,
316 log_n_cosets,
317 layer,
318 advice,
319 )?;
320
321 if next_value != values[index % (1 << arity)] {
322 return Err(VerificationError::IncorrectFold {
323 query_round: i,
324 index,
325 }
326 .into());
327 }
328
329 next_value = fold_chunk(
330 ntt,
331 self.params.rs_code().log_len() - fold_round,
332 coset_index,
333 &mut values,
334 &self.fold_challenges[fold_round..fold_round + arity],
335 );
336 index = coset_index;
337 fold_round += arity;
338 }
339
340 if next_value != terminate_codeword[index] {
341 return Err(VerificationError::IncorrectFold {
342 query_round: self.n_oracles() - 1,
343 index,
344 }
345 .into());
346 }
347
348 Ok(())
349 }
350
351 fn create_scratch_buffer(&self) -> Vec<F> {
353 let max_arity = self
354 .params
355 .fold_arities()
356 .iter()
357 .copied()
358 .max()
359 .unwrap_or_default();
360 let max_buffer_size = 2 * (1 << max_arity);
361 vec![F::default(); max_buffer_size]
362 }
363}
364
365#[allow(clippy::too_many_arguments)]
367fn verify_coset_opening<F, MTScheme, B>(
368 vcs: &MTScheme,
369 coset_index: usize,
370 log_coset_size: usize,
371 optimal_layer_depth: usize,
372 tree_depth: usize,
373 layer_digests: &[MTScheme::Digest],
374 advice: &mut TranscriptReader<B>,
375) -> Result<Vec<F>, Error>
376where
377 F: TowerField,
378 MTScheme: MerkleTreeScheme<F>,
379 B: Buf,
380{
381 let values = advice.read_scalar_slice::<F>(1 << log_coset_size)?;
382 vcs.verify_opening(
383 coset_index,
384 &values,
385 optimal_layer_depth,
386 tree_depth,
387 layer_digests,
388 advice,
389 )
390 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
391
392 Ok(values)
393}