1use std::iter;
4
5use binius_field::{BinaryField, ExtensionField, TowerField};
6use binius_hal::{make_portable_backend, ComputationBackend};
7use binius_utils::{bail, DeserializeBytes};
8use bytes::Buf;
9use itertools::izip;
10use tracing::instrument;
11
12use super::{common::vcs_optimal_layers_depths_iter, error::Error, VerificationError};
13use crate::{
14 fiat_shamir::{CanSampleBits, Challenger},
15 merkle_tree::MerkleTreeScheme,
16 protocols::fri::common::{fold_chunk, fold_interleaved_chunk, FRIParams},
17 transcript::{TranscriptReader, VerifierTranscript},
18};
19
20#[derive(Debug)]
25pub struct FRIVerifier<'a, F, FA, VCS>
26where
27 F: BinaryField + ExtensionField<FA>,
28 FA: BinaryField,
29 VCS: MerkleTreeScheme<F>,
30{
31 vcs: &'a VCS,
32 params: &'a FRIParams<F, FA>,
33 codeword_commitment: &'a VCS::Digest,
35 round_commitments: &'a [VCS::Digest],
37 interleave_tensor: Vec<F>,
39 fold_challenges: &'a [F],
41}
42
43impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS>
44where
45 F: TowerField + ExtensionField<FA>,
46 FA: BinaryField,
47 VCS: MerkleTreeScheme<F, Digest: DeserializeBytes>,
48{
49 #[allow(clippy::too_many_arguments)]
50 pub fn new(
51 params: &'a FRIParams<F, FA>,
52 vcs: &'a VCS,
53 codeword_commitment: &'a VCS::Digest,
54 round_commitments: &'a [VCS::Digest],
55 challenges: &'a [F],
56 ) -> Result<Self, Error> {
57 if round_commitments.len() != params.n_oracles() {
58 bail!(Error::InvalidArgs(format!(
59 "got {} round commitments, expected {}",
60 round_commitments.len(),
61 params.n_oracles(),
62 )));
63 }
64
65 if challenges.len() != params.n_fold_rounds() {
66 bail!(Error::InvalidArgs(format!(
67 "got {} folding challenges, expected {}",
68 challenges.len(),
69 params.n_fold_rounds(),
70 )));
71 }
72
73 let (interleave_challenges, fold_challenges) = challenges.split_at(params.log_batch_size());
74
75 let backend = make_portable_backend();
76 let interleave_tensor = backend
77 .tensor_product_full_query(interleave_challenges)
78 .expect("number of challenges is less than 32");
79
80 Ok(Self {
81 params,
82 vcs,
83 codeword_commitment,
84 round_commitments,
85 interleave_tensor,
86 fold_challenges,
87 })
88 }
89
90 pub fn n_oracles(&self) -> usize {
92 self.params.n_oracles()
93 }
94
95 pub fn verify<Challenger_>(
96 &self,
97 transcript: &mut VerifierTranscript<Challenger_>,
98 ) -> Result<F, Error>
99 where
100 Challenger_: Challenger,
101 {
102 let terminate_codeword_len =
104 1 << (self.params.n_final_challenges() + self.params.rs_code().log_inv_rate());
105 let mut advice = transcript.decommitment();
106 let terminate_codeword = advice
107 .read_scalar_slice(terminate_codeword_len)
108 .map_err(Error::TranscriptError)?;
109 let final_value = self.verify_last_oracle(&terminate_codeword)?;
110
111 let layers = vcs_optimal_layers_depths_iter(self.params, self.vcs)
113 .map(|layer_depth| advice.read_vec(1 << layer_depth))
114 .collect::<Result<Vec<_>, _>>()?;
115 for (commitment, layer_depth, layer) in izip!(
116 iter::once(self.codeword_commitment).chain(self.round_commitments),
117 vcs_optimal_layers_depths_iter(self.params, self.vcs),
118 &layers
119 ) {
120 self.vcs
121 .verify_layer(commitment, layer_depth, layer)
122 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
123 }
124
125 let mut scratch_buffer = self.create_scratch_buffer();
128 for _ in 0..self.params.n_test_queries() {
129 let index = transcript.sample_bits(self.params.index_bits());
130 self.verify_query_internal(
131 index,
132 &terminate_codeword,
133 &layers,
134 &mut transcript.decommitment(),
135 &mut scratch_buffer,
136 )?
137 }
138
139 Ok(final_value)
140 }
141
142 pub fn verify_last_oracle(&self, terminate_codeword: &[F]) -> Result<F, Error> {
146 self.vcs
147 .verify_vector(
148 self.round_commitments
149 .last()
150 .unwrap_or(self.codeword_commitment),
151 terminate_codeword,
152 1 << self.params.rs_code().log_inv_rate(),
153 )
154 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
155
156 let repetition_codeword = if self.n_oracles() != 0 {
157 let n_final_challenges = self.params.n_final_challenges();
158 let n_prior_challenges = self.fold_challenges.len() - n_final_challenges;
159 let final_challenges = &self.fold_challenges[n_prior_challenges..];
160 let mut scratch_buffer = vec![F::default(); 1 << n_final_challenges];
161
162 terminate_codeword
163 .chunks(1 << n_final_challenges)
164 .enumerate()
165 .map(|(i, coset_values)| {
166 fold_chunk(
167 self.params.rs_code(),
168 n_prior_challenges,
169 i,
170 coset_values,
171 final_challenges,
172 &mut scratch_buffer,
173 )
174 })
175 .collect::<Vec<_>>()
176 } else {
177 let fold_arity = self.params.rs_code().log_dim() + self.params.log_batch_size();
181 let mut scratch_buffer = vec![F::default(); 2 * (1 << fold_arity)];
182 terminate_codeword
183 .chunks(1 << fold_arity)
184 .enumerate()
185 .map(|(i, chunk)| {
186 fold_interleaved_chunk(
187 self.params.rs_code(),
188 self.params.log_batch_size(),
189 i,
190 chunk,
191 &self.interleave_tensor,
192 self.fold_challenges,
193 &mut scratch_buffer,
194 )
195 })
196 .collect::<Vec<_>>()
197 };
198
199 let final_value = repetition_codeword[0];
200
201 if repetition_codeword[1..]
203 .iter()
204 .any(|&entry| entry != final_value)
205 {
206 return Err(VerificationError::IncorrectDegree.into());
207 }
208
209 Ok(final_value)
210 }
211
212 pub fn verify_query<B: Buf>(
223 &self,
224 index: usize,
225 terminate_codeword: &[F],
226 layers: &[Vec<VCS::Digest>],
227 advice: &mut TranscriptReader<B>,
228 ) -> Result<(), Error> {
229 self.verify_query_internal(
230 index,
231 terminate_codeword,
232 layers,
233 advice,
234 &mut self.create_scratch_buffer(),
235 )
236 }
237
238 #[instrument(skip_all, name = "fri::FRIVerifier::verify_query", level = "debug")]
239 fn verify_query_internal<B: Buf>(
240 &self,
241 mut index: usize,
242 terminate_codeword: &[F],
243 layers: &[Vec<VCS::Digest>],
244 advice: &mut TranscriptReader<B>,
245 scratch_buffer: &mut [F],
246 ) -> Result<(), Error> {
247 let mut arities_iter = self.params.fold_arities().iter().copied();
248
249 let mut layer_digest_and_optimal_layer_depth =
250 iter::zip(layers, vcs_optimal_layers_depths_iter(self.params, self.vcs));
251
252 let Some(first_fold_arity) = arities_iter.next() else {
253 return Ok(());
257 };
258
259 let (first_layer, first_optimal_layer_depth) = layer_digest_and_optimal_layer_depth
260 .next()
261 .expect("The length should be the same as the amount of proofs.");
262
263 let mut fold_round = 0;
265 let mut log_n_cosets = self.params.index_bits();
266
267 let log_coset_size = first_fold_arity - self.params.log_batch_size();
270 let values = verify_coset_opening(
271 self.vcs,
272 index,
273 first_fold_arity,
274 first_optimal_layer_depth,
275 log_n_cosets,
276 first_layer,
277 advice,
278 )?;
279 let mut next_value = fold_interleaved_chunk(
280 self.params.rs_code(),
281 self.params.log_batch_size(),
282 index,
283 &values,
284 &self.interleave_tensor,
285 &self.fold_challenges[fold_round..fold_round + log_coset_size],
286 scratch_buffer,
287 );
288 fold_round += log_coset_size;
289
290 for (i, (arity, (layer, optimal_layer_depth))) in
291 izip!(arities_iter, layer_digest_and_optimal_layer_depth).enumerate()
292 {
293 let coset_index = index >> arity;
294
295 log_n_cosets -= arity;
296
297 let values = verify_coset_opening(
298 self.vcs,
299 coset_index,
300 arity,
301 optimal_layer_depth,
302 log_n_cosets,
303 layer,
304 advice,
305 )?;
306
307 if next_value != values[index % (1 << arity)] {
308 return Err(VerificationError::IncorrectFold {
309 query_round: i,
310 index,
311 }
312 .into());
313 }
314
315 next_value = fold_chunk(
316 self.params.rs_code(),
317 fold_round,
318 coset_index,
319 &values,
320 &self.fold_challenges[fold_round..fold_round + arity],
321 scratch_buffer,
322 );
323 index = coset_index;
324 fold_round += arity;
325 }
326
327 if next_value != terminate_codeword[index] {
328 return Err(VerificationError::IncorrectFold {
329 query_round: self.n_oracles() - 1,
330 index,
331 }
332 .into());
333 }
334
335 Ok(())
336 }
337
338 fn create_scratch_buffer(&self) -> Vec<F> {
340 let max_arity = self
341 .params
342 .fold_arities()
343 .iter()
344 .copied()
345 .max()
346 .unwrap_or_default();
347 let max_buffer_size = 2 * (1 << max_arity);
348 vec![F::default(); max_buffer_size]
349 }
350}
351
352#[allow(clippy::too_many_arguments)]
354fn verify_coset_opening<F, MTScheme, B>(
355 vcs: &MTScheme,
356 coset_index: usize,
357 log_coset_size: usize,
358 optimal_layer_depth: usize,
359 tree_depth: usize,
360 layer_digests: &[MTScheme::Digest],
361 advice: &mut TranscriptReader<B>,
362) -> Result<Vec<F>, Error>
363where
364 F: TowerField,
365 MTScheme: MerkleTreeScheme<F>,
366 B: Buf,
367{
368 let values = advice.read_scalar_slice::<F>(1 << log_coset_size)?;
369 vcs.verify_opening(
370 coset_index,
371 &values,
372 optimal_layer_depth,
373 tree_depth,
374 layer_digests,
375 advice,
376 )
377 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
378
379 Ok(values)
380}