use std::{iter, iter::repeat_with};
use binius_field::{BinaryField, ExtensionField, TowerField};
use binius_hal::{make_portable_backend, ComputationBackend};
use binius_utils::{bail, serialization::DeserializeBytes};
use itertools::izip;
use tracing::instrument;
use super::{common::vcs_optimal_layers_depths_iter, error::Error, VerificationError};
use crate::{
fiat_shamir::CanSampleBits,
merkle_tree::MerkleTreeScheme,
protocols::fri::common::{fold_chunk, fold_interleaved_chunk, FRIParams},
transcript::CanRead,
};
/// Verifier-side state for checking a FRI proof against received commitments.
///
/// All inputs are borrowed for `'a`; the only owned data is `interleave_tensor`, which
/// [`FRIVerifier::new`] derives from the leading challenges.
#[derive(Debug)]
pub struct FRIVerifier<'a, F, FA, VCS>
where
F: BinaryField + ExtensionField<FA>,
FA: BinaryField,
VCS: MerkleTreeScheme<F>,
{
/// Merkle tree scheme used to check layers, vectors and coset openings.
vcs: &'a VCS,
/// FRI parameters queried for arities, round counts, the code and query counts.
params: &'a FRIParams<F, FA>,
/// Commitment checked against the first decommitted layer; it is also the commitment the
/// terminal codeword is checked against when `round_commitments` is empty.
codeword_commitment: &'a VCS::Digest,
/// Commitments for the fold-round oracles; `new` requires exactly `params.n_oracles()` of them.
round_commitments: &'a [VCS::Digest],
/// Output of `tensor_product_full_query` over the first `params.log_batch_size()` challenges.
interleave_tensor: Vec<F>,
/// The challenges that remain after the first `params.log_batch_size()` are split off.
fold_challenges: &'a [F],
}
impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS>
where
F: TowerField + ExtensionField<FA>,
FA: BinaryField,
VCS: MerkleTreeScheme<F, Digest: DeserializeBytes>,
{
/// Builds a verifier from the protocol parameters, the commitments received from the prover
/// and the full list of sampled challenges.
///
/// The first `params.log_batch_size()` entries of `challenges` are expanded into the
/// interleave tensor; the remainder are kept as the fold challenges.
///
/// # Errors
///
/// Returns `Error::InvalidArgs` when `round_commitments.len() != params.n_oracles()` or when
/// `challenges.len() != params.n_fold_rounds()`.
///
/// # Panics
///
/// Panics if the tensor expansion of the interleave challenges fails.
#[allow(clippy::too_many_arguments)]
pub fn new(
    params: &'a FRIParams<F, FA>,
    vcs: &'a VCS,
    codeword_commitment: &'a VCS::Digest,
    round_commitments: &'a [VCS::Digest],
    challenges: &'a [F],
) -> Result<Self, Error> {
    // One commitment is expected per oracle sent during folding.
    let n_commitments = round_commitments.len();
    let n_oracles = params.n_oracles();
    if n_commitments != n_oracles {
        bail!(Error::InvalidArgs(format!(
            "got {} round commitments, expected {}",
            n_commitments, n_oracles,
        )));
    }

    // One challenge is expected per fold round.
    let n_challenges = challenges.len();
    let n_fold_rounds = params.n_fold_rounds();
    if n_challenges != n_fold_rounds {
        bail!(Error::InvalidArgs(format!(
            "got {} folding challenges, expected {}",
            n_challenges, n_fold_rounds,
        )));
    }

    // The leading challenges mix the interleaved batch; the rest drive the folding itself.
    let (interleave_challenges, fold_challenges) = challenges.split_at(params.log_batch_size());

    let hal = make_portable_backend();
    let interleave_tensor = hal
        .tensor_product_full_query(interleave_challenges)
        .expect("number of challenges is less than 32");

    Ok(Self {
        vcs,
        params,
        codeword_commitment,
        round_commitments,
        interleave_tensor,
        fold_challenges,
    })
}
/// Returns the number of oracles, as reported by `params.n_oracles()`.
pub fn n_oracles(&self) -> usize {
    let params = self.params;
    params.n_oracles()
}
pub fn verify<Transcript, Advice>(
&self,
advice: &mut Advice,
transcript: &mut Transcript,
) -> Result<F, Error>
where
Transcript: CanSampleBits<usize>,
Advice: CanRead,
{
let terminate_codeword_len =
1 << (self.params.n_final_challenges() + self.params.rs_code().log_inv_rate());
let terminate_codeword = advice
.read_scalar_slice(terminate_codeword_len)
.map_err(Error::TranscriptError)?;
let final_value = self.verify_last_oracle(&terminate_codeword)?;
let layers = vcs_optimal_layers_depths_iter(self.params, self.vcs)
.map(|layer_depth| advice.read_vec(1 << layer_depth))
.collect::<Result<Vec<_>, _>>()?;
for (commitment, layer_depth, layer) in izip!(
iter::once(self.codeword_commitment).chain(self.round_commitments),
vcs_optimal_layers_depths_iter(self.params, self.vcs),
&layers
) {
self.vcs
.verify_layer(commitment, layer_depth, layer)
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
}
let indexes_iter = repeat_with(|| transcript.sample_bits(self.params.index_bits()))
.take(self.params.n_test_queries());
let mut scratch_buffer = self.create_scratch_buffer();
for index in indexes_iter {
self.verify_query_internal(
index,
&terminate_codeword,
&layers,
advice,
&mut scratch_buffer,
)?
}
Ok(final_value)
}
pub fn verify_last_oracle(&self, terminate_codeword: &[F]) -> Result<F, Error> {
self.vcs
.verify_vector(
self.round_commitments
.last()
.unwrap_or(self.codeword_commitment),
terminate_codeword,
1 << self.params.rs_code().log_inv_rate(),
)
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
let repetition_codeword = if self.n_oracles() != 0 {
let n_final_challenges = self.params.n_final_challenges();
let n_prior_challenges = self.fold_challenges.len() - n_final_challenges;
let final_challenges = &self.fold_challenges[n_prior_challenges..];
let mut scratch_buffer = vec![F::default(); 1 << n_final_challenges];
terminate_codeword
.chunks(1 << n_final_challenges)
.enumerate()
.map(|(i, coset_values)| {
fold_chunk(
self.params.rs_code(),
n_prior_challenges,
i,
coset_values,
final_challenges,
&mut scratch_buffer,
)
})
.collect::<Vec<_>>()
} else {
let fold_arity = self.params.rs_code().log_dim() + self.params.log_batch_size();
let mut scratch_buffer = vec![F::default(); 2 * (1 << fold_arity)];
terminate_codeword
.chunks(1 << fold_arity)
.enumerate()
.map(|(i, chunk)| {
fold_interleaved_chunk(
self.params.rs_code(),
self.params.log_batch_size(),
i,
chunk,
&self.interleave_tensor,
self.fold_challenges,
&mut scratch_buffer,
)
})
.collect::<Vec<_>>()
};
let final_value = repetition_codeword[0];
if repetition_codeword[1..]
.iter()
.any(|&entry| entry != final_value)
{
return Err(VerificationError::IncorrectDegree.into());
}
Ok(final_value)
}
/// Checks a single query at `index`, allocating a fresh scratch buffer for the folds.
///
/// This is the public entry point for the per-query check: it forwards every argument
/// unchanged to the internal routine that [`Self::verify`] also uses.
///
/// # Errors
///
/// Returns whatever the internal per-query check returns.
pub fn verify_query<Advice: CanRead>(
    &self,
    index: usize,
    terminate_codeword: &[F],
    layers: &[Vec<VCS::Digest>],
    advice: &mut Advice,
) -> Result<(), Error> {
    let mut scratch_buffer = self.create_scratch_buffer();
    self.verify_query_internal(
        index,
        terminate_codeword,
        layers,
        advice,
        &mut scratch_buffer,
    )
}
/// Checks one query across all decommitted layers, reusing a caller-provided scratch buffer.
///
/// The first layer is opened at coset `index` and folded with `fold_interleaved_chunk`. For
/// every remaining arity/layer pair, the coset `index >> arity` is opened, the previously
/// folded value is compared with the entry at position `index % (1 << arity)` of that coset,
/// and the coset is folded with `fold_chunk`. The last folded value is finally compared with
/// `terminate_codeword[index]` for the final, shifted `index`.
///
/// Returns `Ok(())` immediately when `params.fold_arities()` is empty.
///
/// # Errors
///
/// * `VerificationError::IncorrectFold` when a folded value disagrees with the next opened
///   coset or with `terminate_codeword`.
/// * Any error returned by `verify_coset_opening`.
///
/// # Panics
///
/// Panics if `layers` is empty while `fold_arities()` is not, or if the final index is out of
/// bounds for `terminate_codeword`.
/// NOTE(review): the challenge slicing and the arity subtractions are assumed to stay in
/// range because of invariants on `params`; confirm against `FRIParams`.
#[instrument(skip_all, name = "fri::FRIVerifier::verify_query", level = "debug")]
fn verify_query_internal<Advice: CanRead>(
&self,
mut index: usize,
terminate_codeword: &[F],
layers: &[Vec<VCS::Digest>],
advice: &mut Advice,
scratch_buffer: &mut [F],
) -> Result<(), Error> {
let mut arities_iter = self.params.fold_arities().iter().copied();
// Pair every decommitted layer with the depth it was decommitted at.
let mut layer_digest_and_optimal_layer_depth =
iter::zip(layers, vcs_optimal_layers_depths_iter(self.params, self.vcs));
// With no fold arities there is nothing to open; this function performs no checks.
let Some(first_fold_arity) = arities_iter.next() else {
return Ok(());
};
let (first_layer, first_optimal_layer_depth) = layer_digest_and_optimal_layer_depth
.next()
.expect("The length should be the same as the amount of proofs.");
// Offset into `fold_challenges` of the next challenge to apply.
let mut fold_round = 0;
// Passed to `verify_coset_opening` as the tree depth; starts at `index_bits()` and shrinks
// by each later arity.
let mut log_n_cosets = self.params.index_bits();
// The first opening holds `1 << first_fold_arity` values, but only
// `first_fold_arity - log_batch_size` fold challenges are applied to it.
let log_coset_size = first_fold_arity - self.params.log_batch_size();
let values = verify_coset_opening(
self.vcs,
index,
first_fold_arity,
first_optimal_layer_depth,
log_n_cosets,
first_layer,
advice,
)?;
// The first fold uses the interleave tensor in addition to the fold challenges.
let mut next_value = fold_interleaved_chunk(
self.params.rs_code(),
self.params.log_batch_size(),
index,
&values,
&self.interleave_tensor,
&self.fold_challenges[fold_round..fold_round + log_coset_size],
scratch_buffer,
);
fold_round += log_coset_size;
// Remaining arities and layers are consumed in lockstep; `i` counts from the second layer.
for (i, (arity, (layer, optimal_layer_depth))) in
izip!(arities_iter, layer_digest_and_optimal_layer_depth).enumerate()
{
let coset_index = index >> arity;
log_n_cosets -= arity;
let values = verify_coset_opening(
self.vcs,
coset_index,
arity,
optimal_layer_depth,
log_n_cosets,
layer,
advice,
)?;
// The value folded from the previous layer must appear at this query's position inside
// the newly opened coset.
if next_value != values[index % (1 << arity)] {
return Err(VerificationError::IncorrectFold {
query_round: i,
index,
}
.into());
}
next_value = fold_chunk(
self.params.rs_code(),
fold_round,
coset_index,
&values,
&self.fold_challenges[fold_round..fold_round + arity],
scratch_buffer,
);
index = coset_index;
fold_round += arity;
}
// The last folded value must match the terminal codeword at the final index.
if next_value != terminate_codeword[index] {
return Err(VerificationError::IncorrectFold {
query_round: self.n_oracles() - 1,
index,
}
.into());
}
Ok(())
}
/// Allocates a zero-initialised (`F::default()`) scratch buffer for the per-query folds.
///
/// The buffer holds `2 * 2^max_arity` elements, where `max_arity` is the largest entry of
/// `params.fold_arities()`, or `0` when there are none.
fn create_scratch_buffer(&self) -> Vec<F> {
    let largest_arity = self
        .params
        .fold_arities()
        .iter()
        .max()
        .copied()
        .unwrap_or_default();
    vec![F::default(); 2 * (1 << largest_arity)]
}
}
#[allow(clippy::too_many_arguments)]
fn verify_coset_opening<F, MTScheme, Advice>(
vcs: &MTScheme,
coset_index: usize,
log_coset_size: usize,
optimal_layer_depth: usize,
tree_depth: usize,
layer_digests: &[MTScheme::Digest],
advice: &mut Advice,
) -> Result<Vec<F>, Error>
where
F: TowerField,
MTScheme: MerkleTreeScheme<F>,
Advice: CanRead,
{
let values = advice.read_scalar_slice::<F>(1 << log_coset_size)?;
vcs.verify_opening(
coset_index,
&values,
optimal_layer_depth,
tree_depth,
layer_digests,
advice,
)
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
Ok(values)
}