1use std::iter;
4
5use binius_field::{BinaryField, ExtensionField, TowerField};
6use binius_hal::{ComputationBackend, make_portable_backend};
7use binius_ntt::{
8 SingleThreadedNTT,
9 fri::{fold_chunk, fold_interleaved_chunk},
10};
11use binius_utils::{DeserializeBytes, bail};
12use bytes::Buf;
13use itertools::izip;
14use tracing::instrument;
15
16use super::{VerificationError, common::vcs_optimal_layers_depths_iter, error::Error};
17use crate::{
18 fiat_shamir::{CanSampleBits, Challenger},
19 merkle_tree::MerkleTreeScheme,
20 protocols::fri::common::FRIParams,
21 transcript::{TranscriptReader, VerifierTranscript},
22};
23
24#[derive(Debug)]
29pub struct FRIVerifier<'a, F, FA, VCS>
30where
31 F: BinaryField + ExtensionField<FA>,
32 FA: BinaryField,
33 VCS: MerkleTreeScheme<F>,
34{
35 vcs: &'a VCS,
36 params: &'a FRIParams<F, FA>,
37 codeword_commitment: &'a VCS::Digest,
39 round_commitments: &'a [VCS::Digest],
41 interleave_tensor: Vec<F>,
43 fold_challenges: &'a [F],
45}
46
47impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS>
48where
49 F: TowerField + ExtensionField<FA>,
50 FA: BinaryField,
51 VCS: MerkleTreeScheme<F, Digest: DeserializeBytes>,
52{
53 #[allow(clippy::too_many_arguments)]
54 pub fn new(
55 params: &'a FRIParams<F, FA>,
56 vcs: &'a VCS,
57 codeword_commitment: &'a VCS::Digest,
58 round_commitments: &'a [VCS::Digest],
59 challenges: &'a [F],
60 ) -> Result<Self, Error> {
61 if round_commitments.len() != params.n_oracles() {
62 bail!(Error::InvalidArgs(format!(
63 "got {} round commitments, expected {}",
64 round_commitments.len(),
65 params.n_oracles(),
66 )));
67 }
68
69 if challenges.len() != params.n_fold_rounds() {
70 bail!(Error::InvalidArgs(format!(
71 "got {} folding challenges, expected {}",
72 challenges.len(),
73 params.n_fold_rounds(),
74 )));
75 }
76
77 let (interleave_challenges, fold_challenges) = challenges.split_at(params.log_batch_size());
78
79 let backend = make_portable_backend();
80 let interleave_tensor = backend
81 .tensor_product_full_query(interleave_challenges)
82 .expect("number of challenges is less than 32");
83
84 Ok(Self {
85 params,
86 vcs,
87 codeword_commitment,
88 round_commitments,
89 interleave_tensor,
90 fold_challenges,
91 })
92 }
93
94 pub fn n_oracles(&self) -> usize {
96 self.params.n_oracles()
97 }
98
99 pub fn verify<Challenger_>(
100 &self,
101 transcript: &mut VerifierTranscript<Challenger_>,
102 ) -> Result<F, Error>
103 where
104 Challenger_: Challenger,
105 {
106 let ntt = SingleThreadedNTT::with_subspace(self.params.rs_code().subspace())?;
107
108 let terminate_codeword_len =
110 1 << (self.params.n_final_challenges() + self.params.rs_code().log_inv_rate());
111 let mut advice = transcript.decommitment();
112 let terminate_codeword = advice
113 .read_scalar_slice(terminate_codeword_len)
114 .map_err(Error::TranscriptError)?;
115 let final_value = self.verify_last_oracle(&ntt, &terminate_codeword)?;
116
117 let layers = vcs_optimal_layers_depths_iter(self.params, self.vcs)
119 .map(|layer_depth| advice.read_vec(1 << layer_depth))
120 .collect::<Result<Vec<_>, _>>()?;
121 for (commitment, layer_depth, layer) in izip!(
122 iter::once(self.codeword_commitment).chain(self.round_commitments),
123 vcs_optimal_layers_depths_iter(self.params, self.vcs),
124 &layers
125 ) {
126 self.vcs
127 .verify_layer(commitment, layer_depth, layer)
128 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
129 }
130
131 let mut scratch_buffer = self.create_scratch_buffer();
134 for _ in 0..self.params.n_test_queries() {
135 let index = transcript.sample_bits(self.params.index_bits()) as usize;
136 self.verify_query_internal(
137 index,
138 &ntt,
139 &terminate_codeword,
140 &layers,
141 &mut transcript.decommitment(),
142 &mut scratch_buffer,
143 )?
144 }
145
146 Ok(final_value)
147 }
148
149 pub fn verify_last_oracle(
153 &self,
154 ntt: &SingleThreadedNTT<FA>,
155 terminate_codeword: &[F],
156 ) -> Result<F, Error> {
157 let n_final_challenges = self.params.n_final_challenges();
158
159 self.vcs
160 .verify_vector(
161 self.round_commitments
162 .last()
163 .unwrap_or(self.codeword_commitment),
164 terminate_codeword,
165 1 << n_final_challenges,
166 )
167 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
168
169 let repetition_codeword = if self.n_oracles() != 0 {
170 let n_prior_challenges = self.fold_challenges.len() - n_final_challenges;
171 let final_challenges = &self.fold_challenges[n_prior_challenges..];
172 let mut scratch_buffer = vec![F::default(); 1 << n_final_challenges];
173
174 terminate_codeword
175 .chunks(1 << n_final_challenges)
176 .enumerate()
177 .map(|(i, coset_values)| {
178 scratch_buffer.copy_from_slice(coset_values);
179 fold_chunk(
180 ntt,
181 n_final_challenges + self.params.rs_code().log_inv_rate(),
182 i,
183 &mut scratch_buffer,
184 final_challenges,
185 )
186 })
187 .collect::<Vec<_>>()
188 } else {
189 let fold_arity = self.params.rs_code().log_dim() + self.params.log_batch_size();
193 let mut scratch_buffer = vec![F::default(); 1 << self.params.rs_code().log_dim()];
194 terminate_codeword
195 .chunks(1 << fold_arity)
196 .enumerate()
197 .map(|(i, chunk)| {
198 fold_interleaved_chunk(
199 ntt,
200 self.params.rs_code().log_len(),
201 self.params.log_batch_size(),
202 i,
203 chunk,
204 &self.interleave_tensor,
205 self.fold_challenges,
206 &mut scratch_buffer,
207 )
208 })
209 .collect::<Vec<_>>()
210 };
211
212 let final_value = repetition_codeword[0];
213
214 if repetition_codeword[1..]
216 .iter()
217 .any(|&entry| entry != final_value)
218 {
219 return Err(VerificationError::IncorrectDegree.into());
220 }
221
222 Ok(final_value)
223 }
224
225 pub fn verify_query<B: Buf>(
236 &self,
237 index: usize,
238 ntt: &SingleThreadedNTT<FA>,
239 terminate_codeword: &[F],
240 layers: &[Vec<VCS::Digest>],
241 advice: &mut TranscriptReader<B>,
242 ) -> Result<(), Error> {
243 self.verify_query_internal(
244 index,
245 ntt,
246 terminate_codeword,
247 layers,
248 advice,
249 &mut self.create_scratch_buffer(),
250 )
251 }
252
253 #[instrument(skip_all, name = "fri::FRIVerifier::verify_query", level = "debug")]
254 fn verify_query_internal<B: Buf>(
255 &self,
256 mut index: usize,
257 ntt: &SingleThreadedNTT<FA>,
258 terminate_codeword: &[F],
259 layers: &[Vec<VCS::Digest>],
260 advice: &mut TranscriptReader<B>,
261 scratch_buffer: &mut [F],
262 ) -> Result<(), Error> {
263 let mut arities_iter = self.params.fold_arities().iter().copied();
264
265 let mut layer_digest_and_optimal_layer_depth =
266 iter::zip(layers, vcs_optimal_layers_depths_iter(self.params, self.vcs));
267
268 let Some(first_fold_arity) = arities_iter.next() else {
269 return Ok(());
273 };
274
275 let (first_layer, first_optimal_layer_depth) = layer_digest_and_optimal_layer_depth
276 .next()
277 .expect("The length should be the same as the amount of proofs.");
278
279 let mut fold_round = 0;
281 let mut log_n_cosets = self.params.index_bits();
282
283 let log_coset_size = first_fold_arity - self.params.log_batch_size();
286 let values = verify_coset_opening(
287 self.vcs,
288 index,
289 first_fold_arity,
290 first_optimal_layer_depth,
291 log_n_cosets,
292 first_layer,
293 advice,
294 )?;
295 let mut next_value = fold_interleaved_chunk(
296 ntt,
297 self.params.rs_code().log_len(),
298 self.params.log_batch_size(),
299 index,
300 &values,
301 &self.interleave_tensor,
302 &self.fold_challenges[fold_round..fold_round + log_coset_size],
303 scratch_buffer,
304 );
305 fold_round += log_coset_size;
306
307 for (i, (arity, (layer, optimal_layer_depth))) in
308 izip!(arities_iter, layer_digest_and_optimal_layer_depth).enumerate()
309 {
310 let coset_index = index >> arity;
311
312 log_n_cosets -= arity;
313
314 let mut values = verify_coset_opening(
315 self.vcs,
316 coset_index,
317 arity,
318 optimal_layer_depth,
319 log_n_cosets,
320 layer,
321 advice,
322 )?;
323
324 if next_value != values[index % (1 << arity)] {
325 return Err(VerificationError::IncorrectFold {
326 query_round: i,
327 index,
328 }
329 .into());
330 }
331
332 next_value = fold_chunk(
333 ntt,
334 self.params.rs_code().log_len() - fold_round,
335 coset_index,
336 &mut values,
337 &self.fold_challenges[fold_round..fold_round + arity],
338 );
339 index = coset_index;
340 fold_round += arity;
341 }
342
343 if next_value != terminate_codeword[index] {
344 return Err(VerificationError::IncorrectFold {
345 query_round: self.n_oracles() - 1,
346 index,
347 }
348 .into());
349 }
350
351 Ok(())
352 }
353
354 fn create_scratch_buffer(&self) -> Vec<F> {
356 let max_arity = self
357 .params
358 .fold_arities()
359 .iter()
360 .copied()
361 .max()
362 .unwrap_or_default();
363 let max_buffer_size = 2 * (1 << max_arity);
364 vec![F::default(); max_buffer_size]
365 }
366}
367
368#[allow(clippy::too_many_arguments)]
370fn verify_coset_opening<F, MTScheme, B>(
371 vcs: &MTScheme,
372 coset_index: usize,
373 log_coset_size: usize,
374 optimal_layer_depth: usize,
375 tree_depth: usize,
376 layer_digests: &[MTScheme::Digest],
377 advice: &mut TranscriptReader<B>,
378) -> Result<Vec<F>, Error>
379where
380 F: TowerField,
381 MTScheme: MerkleTreeScheme<F>,
382 B: Buf,
383{
384 let values = advice.read_scalar_slice::<F>(1 << log_coset_size)?;
385 vcs.verify_opening(
386 coset_index,
387 &values,
388 optimal_layer_depth,
389 tree_depth,
390 layer_digests,
391 advice,
392 )
393 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
394
395 Ok(values)
396}