// binius_core/protocols/fri/common.rs

use crate::{
linear_code::LinearCode, merkle_tree_vcs::MerkleTreeScheme, protocols::fri::Error,
reed_solomon::reed_solomon::ReedSolomonCode,
};
use binius_field::{util::inner_product_unchecked, BinaryField, ExtensionField, PackedField};
use binius_math::extrapolate_line_scalar;
use binius_ntt::AdditiveNTT;
use binius_utils::bail;
use getset::{CopyGetters, Getters};
use std::marker::PhantomData;
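
/// Folds a pair of adjacent codeword values into one, using the folding challenge `r`.
///
/// The pair is first sent through one inverse additive NTT butterfly, using the subspace
/// evaluation for the given `round` and coset `index`, and the two resulting values are then
/// interpolated at `r` on the line through them.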
#[inline]
fn fold_pair<F, FS>(
rs_code: &ReedSolomonCode<FS>,
round: usize,
index: usize,
values: (F, F),
r: F,
) -> F
where
F: BinaryField + ExtensionField<FS>,
FS: BinaryField,
{
let t = rs_code.get_ntt().get_subspace_eval(round, index);
let (mut u, mut v) = values;
v += u;
u += v * t;
extrapolate_line_scalar(u, v, r)
}
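
/// Folds a chunk of codeword values with a sequence of folding challenges, returning the single
/// folded value.
///
/// ## Arguments
///
/// * `rs_code` - the Reed-Solomon code the values belong to.
/// * `start_round` - the fold round of the first challenge in the sequence.
/// * `chunk_index` - the index of this chunk within the codeword.
/// * `values` - the chunk of `2^folding_challenges.len()` codeword values.
/// * `folding_challenges` - the folding challenges; must be non-empty.
/// * `scratch_buffer` - scratch space of length at least `values.len()`.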
#[inline]
pub fn fold_chunk<F, FS>(
rs_code: &ReedSolomonCode<FS>,
start_round: usize,
chunk_index: usize,
values: &[F],
folding_challenges: &[F],
scratch_buffer: &mut [F],
) -> F
where
F: BinaryField + ExtensionField<FS>,
FS: BinaryField,
{
debug_assert!(!folding_challenges.is_empty());
debug_assert!(start_round + folding_challenges.len() <= rs_code.log_dim());
debug_assert_eq!(values.len(), 1 << folding_challenges.len());
debug_assert!(scratch_buffer.len() >= values.len());
for n_challenges_processed in 0..folding_challenges.len() {
let n_remaining_challenges = folding_challenges.len() - n_challenges_processed;
let scratch_buffer_len = values.len() >> n_challenges_processed;
let new_scratch_buffer_len = scratch_buffer_len >> 1;
let round = start_round + n_challenges_processed;
let r = folding_challenges[n_challenges_processed];
let index_start = chunk_index << (n_remaining_challenges - 1);
if n_challenges_processed > 0 {
(0..new_scratch_buffer_len).for_each(|index_offset| {
let values =
(scratch_buffer[index_offset << 1], scratch_buffer[(index_offset << 1) + 1]);
scratch_buffer[index_offset] =
fold_pair(rs_code, round, index_start + index_offset, values, r)
});
} else {
(0..new_scratch_buffer_len).for_each(|index_offset| {
let values = (values[index_offset << 1], values[(index_offset << 1) + 1]);
scratch_buffer[index_offset] =
fold_pair(rs_code, round, index_start + index_offset, values, r)
});
}
}
scratch_buffer[0]
}
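
/// Folds a chunk of an interleaved codeword: the interleaved batch is first collapsed with the
/// `tensor` combination coefficients, and the result is then folded with `fold_challenges`.
///
/// ## Arguments
///
/// * `rs_code` - the Reed-Solomon code the values belong to.
/// * `log_batch_size` - the base-2 logarithm of the interleaved batch size.
/// * `chunk_index` - the index of this chunk within the codeword.
/// * `values` - the chunk of `2^(log_batch_size + fold_challenges.len())` interleaved values.
/// * `tensor` - the `2^log_batch_size` batch combination coefficients.
/// * `fold_challenges` - the folding challenges; may be empty.
/// * `scratch_buffer` - scratch space of length at least `2 * (values.len() >> log_batch_size)`.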
#[inline]
pub fn fold_interleaved_chunk<F, FS>(
rs_code: &ReedSolomonCode<FS>,
log_batch_size: usize,
chunk_index: usize,
values: &[F],
tensor: &[F],
fold_challenges: &[F],
scratch_buffer: &mut [F],
) -> F
where
F: BinaryField + ExtensionField<FS>,
FS: BinaryField,
{
debug_assert!(fold_challenges.len() <= rs_code.log_dim());
debug_assert_eq!(values.len(), 1 << (log_batch_size + fold_challenges.len()));
debug_assert_eq!(tensor.len(), 1 << log_batch_size);
debug_assert!(scratch_buffer.len() >= 2 * (values.len() >> log_batch_size));
let (buffer1, buffer2) = scratch_buffer.split_at_mut(1 << fold_challenges.len());
for (interleave_chunk, val) in values.chunks(1 << log_batch_size).zip(buffer1.iter_mut()) {
*val = inner_product_unchecked(interleave_chunk.iter().copied(), tensor.iter().copied());
}
if fold_challenges.is_empty() {
buffer1[0]
} else {
fold_chunk(rs_code, 0, chunk_index, buffer1, fold_challenges, buffer2)
}
}
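
/// Parameters of an FRI instance.
///
/// Bundles the Reed-Solomon code, the base-2 logarithm of the interleaved batch size, the fold
/// arity of each round oracle, and the number of verifier test queries.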
#[derive(Debug, Getters, CopyGetters)]
pub struct FRIParams<F, FA>
where
F: BinaryField,
FA: BinaryField,
{
#[getset(get = "pub")]
rs_code: ReedSolomonCode<FA>,
#[getset(get_copy = "pub")]
log_batch_size: usize,
fold_arities: Vec<usize>,
#[getset(get_copy = "pub")]
n_test_queries: usize,
_marker: PhantomData<F>,
}
impl<F, FA> FRIParams<F, FA>
where
F: BinaryField + ExtensionField<FA>,
FA: BinaryField,
{
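    /// Constructs new FRI parameters.
    ///
    /// Returns [`Error::InvalidFoldAritySequence`] if the fold arities sum to
    /// `rs_code.log_dim() + log_batch_size` or more, which would leave no final fold challenges.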
pub fn new(
rs_code: ReedSolomonCode<FA>,
log_batch_size: usize,
fold_arities: Vec<usize>,
n_test_queries: usize,
) -> Result<Self, Error> {
if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
bail!(Error::InvalidFoldAritySequence)
}
Ok(Self {
rs_code,
log_batch_size,
fold_arities,
n_test_queries,
_marker: PhantomData,
})
}
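    /// The total number of fold rounds: the log dimension of the code plus the log batch size.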
pub fn n_fold_rounds(&self) -> usize {
self.rs_code.log_dim() + self.log_batch_size
}
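    /// The number of round oracles, one per entry in the fold arity sequence.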
pub fn n_oracles(&self) -> usize {
self.fold_arities.len()
}
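    /// The number of bits in a query index into the first round oracle, or 0 if there is no
    /// folding (in which case no queries are made).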
pub fn index_bits(&self) -> usize {
self.fold_arities
.first()
.map(|arity| self.log_len() - arity)
.unwrap_or(0)
}
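    /// The number of fold challenges remaining after the last round oracle.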
pub fn n_final_challenges(&self) -> usize {
self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
}
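    /// The reduction arities between successive round oracles.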
pub fn fold_arities(&self) -> &[usize] {
&self.fold_arities
}
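    /// The base-2 logarithm of the length of the interleaved (batched) codeword.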
pub fn log_len(&self) -> usize {
self.rs_code().log_len() + self.log_batch_size()
}
}
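
/// Iterates the optimal Merkle tree verification layer depths for each round oracle, as chosen
/// by the vector commitment scheme for the configured number of test queries and the shrinking
/// number of cosets per oracle.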
pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
fri_params: &'a FRIParams<F, FA>,
vcs: &'a VCS,
) -> impl Iterator<Item = usize> + 'a
where
VCS: MerkleTreeScheme<F>,
F: BinaryField + ExtensionField<FA>,
FA: BinaryField,
{
fri_params
.fold_arities()
.iter()
.scan(fri_params.log_len(), |log_n_cosets, arity| {
*log_n_cosets -= arity;
Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
})
}
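
/// A proof for a single FRI test query: one [`QueryRoundProof`] per round oracle.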
pub type QueryProof<F, VCSProof> = Vec<QueryRoundProof<F, VCSProof>>;
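
/// The codeword that terminates the FRI folding phase and is sent to the verifier in full.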
pub type TerminateCodeword<F> = Vec<F>;
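
/// A complete FRI proof: the terminate codeword, one [`QueryProof`] per test query, and the
/// digests of the optimal verification layer of each committed oracle.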
#[derive(Debug, Clone)]
pub struct FRIProof<F, VCS: MerkleTreeScheme<F>> {
pub terminate_codeword: TerminateCodeword<F>,
pub proofs: Vec<QueryProof<F, VCS::Proof>>,
pub layers: Vec<Vec<VCS::Digest>>,
}
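
/// The opened values and vector commitment opening proof for a single coset queried in one
/// round.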
#[derive(Debug, Clone)]
pub struct QueryRoundProof<F, VCSProof> {
pub values: Vec<F>,
pub vcs_proof: VCSProof,
}
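
/// Calculates the number of test queries required to achieve `security_bits` bits of security
/// for the given Reed-Solomon code.
///
/// Returns [`Error::ParameterError`] if the target security level cannot be reached.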
pub fn calculate_n_test_queries<F, PS>(
security_bits: usize,
code: &ReedSolomonCode<PS>,
) -> Result<usize, Error>
where
F: BinaryField + ExtensionField<PS::Scalar>,
PS: PackedField<Scalar: BinaryField>,
{
let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
let mut n_queries = (-(security_bits as f64) / per_query_err.log2()).ceil() as usize;
for _ in 0..10 {
if calculate_error_bound::<F, _>(code, n_queries) >= security_bits {
return Ok(n_queries);
}
n_queries += 1;
}
Err(Error::ParameterError)
}
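
/// Estimates the bits of security achieved with `n_queries` test queries against the given
/// code, accounting for the sumcheck, folding, and per-query error terms.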
fn calculate_error_bound<F, PS>(code: &ReedSolomonCode<PS>, n_queries: usize) -> usize
where
F: BinaryField + ExtensionField<PS::Scalar>,
PS: PackedField<Scalar: BinaryField>,
{
let field_size = 2.0_f64.powi(F::N_BITS as i32);
let sumcheck_err = code.log_dim() as f64 / field_size;
let folding_err = code.len() as f64 / field_size;
let per_query_err = 0.5 * (1.0 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
let query_err = per_query_err.powi(n_queries as i32);
let total_err = sumcheck_err + folding_err + query_err;
-total_err.log2() as usize
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use binius_field::{BinaryField128b, BinaryField32b};
use binius_ntt::NTTOptions;
#[test]
fn test_calculate_n_test_queries() {
let security_bits = 96;
let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap();
let n_test_queries =
calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
.unwrap();
assert_eq!(n_test_queries, 232);
let rs_code = ReedSolomonCode::new(28, 2, NTTOptions::default()).unwrap();
let n_test_queries =
calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
.unwrap();
assert_eq!(n_test_queries, 143);
}
#[test]
fn test_calculate_n_test_queries_unsatisfiable() {
let security_bits = 128;
let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap();
assert_matches!(
calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code),
Err(Error::ParameterError)
);
}
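
    // A minimal usage sketch, not part of the original test suite: the code size, batch size,
    // arities, and query count below are arbitrary assumptions chosen only to exercise the
    // arity accounting of `FRIParams`.
    #[test]
    fn test_fri_params_arity_accounting() {
        // 2^8-message code at rate 1/2, interleaved batch of 2^2, folded with arities 2, 2, 2.
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(8, 1, NTTOptions::default()).unwrap();
        let params =
            FRIParams::<BinaryField128b, BinaryField32b>::new(rs_code, 2, vec![2, 2, 2], 32)
                .unwrap();
        // 8 message bits plus 2 batch bits are folded in total; 6 of them by the round oracles.
        assert_eq!(params.n_fold_rounds(), 10);
        assert_eq!(params.n_oracles(), 3);
        assert_eq!(params.n_final_challenges(), 4);
        // The committed codeword has 2^(9 + 2) symbols; the first oracle folds by arity 2,
        // so query indices into it use 11 - 2 = 9 bits.
        assert_eq!(params.index_bits(), 9);
    }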
}