binius_core/protocols/fri/
common.rs1use std::marker::PhantomData;
4
5use binius_field::{util::inner_product_unchecked, BinaryField, ExtensionField, PackedField};
6use binius_math::extrapolate_line_scalar;
7use binius_ntt::AdditiveNTT;
8use binius_utils::bail;
9use getset::{CopyGetters, Getters};
10
11use crate::{
12 merkle_tree::MerkleTreeScheme, protocols::fri::Error,
13 reed_solomon::reed_solomon::ReedSolomonCode,
14};
15
16#[inline]
22fn fold_pair<F, FS>(
23 rs_code: &ReedSolomonCode<FS>,
24 round: usize,
25 index: usize,
26 values: (F, F),
27 r: F,
28) -> F
29where
30 F: BinaryField + ExtensionField<FS>,
31 FS: BinaryField,
32{
33 let t = rs_code.get_ntt().get_subspace_eval(round, index);
35 let (mut u, mut v) = values;
36 v += u;
37 u += v * t;
38 extrapolate_line_scalar(u, v, r)
39}
40
41#[inline]
56pub fn fold_chunk<F, FS>(
57 rs_code: &ReedSolomonCode<FS>,
58 start_round: usize,
59 chunk_index: usize,
60 values: &[F],
61 folding_challenges: &[F],
62 scratch_buffer: &mut [F],
63) -> F
64where
65 F: BinaryField + ExtensionField<FS>,
66 FS: BinaryField,
67{
68 debug_assert!(!folding_challenges.is_empty());
70 debug_assert!(start_round + folding_challenges.len() <= rs_code.log_dim());
71 debug_assert_eq!(values.len(), 1 << folding_challenges.len());
72 debug_assert!(scratch_buffer.len() >= values.len());
73
74 for n_challenges_processed in 0..folding_challenges.len() {
76 let n_remaining_challenges = folding_challenges.len() - n_challenges_processed;
77 let scratch_buffer_len = values.len() >> n_challenges_processed;
78 let new_scratch_buffer_len = scratch_buffer_len >> 1;
79 let round = start_round + n_challenges_processed;
80 let r = folding_challenges[n_challenges_processed];
81 let index_start = chunk_index << (n_remaining_challenges - 1);
82
83 if n_challenges_processed > 0 {
85 (0..new_scratch_buffer_len).for_each(|index_offset| {
86 let values =
87 (scratch_buffer[index_offset << 1], scratch_buffer[(index_offset << 1) + 1]);
88 scratch_buffer[index_offset] =
89 fold_pair(rs_code, round, index_start + index_offset, values, r)
90 });
91 } else {
92 (0..new_scratch_buffer_len).for_each(|index_offset| {
94 let values = (values[index_offset << 1], values[(index_offset << 1) + 1]);
95 scratch_buffer[index_offset] =
96 fold_pair(rs_code, round, index_start + index_offset, values, r)
97 });
98 }
99 }
100
101 scratch_buffer[0]
102}
103
104#[inline]
127pub fn fold_interleaved_chunk<F, FS>(
128 rs_code: &ReedSolomonCode<FS>,
129 log_batch_size: usize,
130 chunk_index: usize,
131 values: &[F],
132 tensor: &[F],
133 fold_challenges: &[F],
134 scratch_buffer: &mut [F],
135) -> F
136where
137 F: BinaryField + ExtensionField<FS>,
138 FS: BinaryField,
139{
140 debug_assert!(fold_challenges.len() <= rs_code.log_dim());
142 debug_assert_eq!(values.len(), 1 << (log_batch_size + fold_challenges.len()));
143 debug_assert_eq!(tensor.len(), 1 << log_batch_size);
144 debug_assert!(scratch_buffer.len() >= 2 * (values.len() >> log_batch_size));
145
146 let (buffer1, buffer2) = scratch_buffer.split_at_mut(1 << fold_challenges.len());
152
153 for (interleave_chunk, val) in values.chunks(1 << log_batch_size).zip(buffer1.iter_mut()) {
154 *val = inner_product_unchecked(interleave_chunk.iter().copied(), tensor.iter().copied());
155 }
156
157 if fold_challenges.is_empty() {
158 buffer1[0]
159 } else {
160 fold_chunk(rs_code, 0, chunk_index, buffer1, fold_challenges, buffer2)
161 }
162}
163
/// Parameters of an FRI instance.
///
/// The parameters are the Reed–Solomon code, the batch (interleaving) size, the schedule of
/// fold arities, and the number of test queries.
///
/// Type parameters:
/// * `F` — the field the protocol folds over. It is stored only as a phantom marker.
/// * `FA` — the field the Reed–Solomon code is defined over.
#[derive(Debug, Getters, CopyGetters)]
pub struct FRIParams<F, FA>
where
    F: BinaryField,
    FA: BinaryField,
{
    /// The Reed–Solomon code. A public `rs_code()` getter returns a reference to it.
    #[getset(get = "pub")]
    rs_code: ReedSolomonCode<FA>,
    /// Base-2 logarithm of the batch size.
    ///
    /// It is added to the code's log dimension in `n_fold_rounds` and to its log length in
    /// `log_len`. A public `log_batch_size()` copy getter exposes it.
    #[getset(get_copy = "pub")]
    log_batch_size: usize,
    /// Fold arities, one entry per oracle.
    ///
    /// `new` requires their sum to be strictly less than `n_fold_rounds()`.
    fold_arities: Vec<usize>,
    /// Number of test queries. A public `n_test_queries()` copy getter exposes it.
    #[getset(get_copy = "pub")]
    n_test_queries: usize,
    /// Carries the otherwise-unused type parameter `F`.
    _marker: PhantomData<F>,
}
184
185impl<F, FA> FRIParams<F, FA>
186where
187 F: BinaryField + ExtensionField<FA>,
188 FA: BinaryField,
189{
190 pub fn new(
191 rs_code: ReedSolomonCode<FA>,
192 log_batch_size: usize,
193 fold_arities: Vec<usize>,
194 n_test_queries: usize,
195 ) -> Result<Self, Error> {
196 if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
197 bail!(Error::InvalidFoldAritySequence)
198 }
199
200 Ok(Self {
201 rs_code,
202 log_batch_size,
203 fold_arities,
204 n_test_queries,
205 _marker: PhantomData,
206 })
207 }
208
209 pub const fn n_fold_rounds(&self) -> usize {
210 self.rs_code.log_dim() + self.log_batch_size
211 }
212
213 pub fn n_oracles(&self) -> usize {
215 self.fold_arities.len()
216 }
217
218 pub fn index_bits(&self) -> usize {
220 self.fold_arities
221 .first()
222 .map(|arity| self.log_len() - arity)
223 .unwrap_or(0)
225 }
226
227 pub fn n_final_challenges(&self) -> usize {
229 self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
230 }
231
232 pub fn fold_arities(&self) -> &[usize] {
234 &self.fold_arities
235 }
236
237 pub fn log_len(&self) -> usize {
239 self.rs_code().log_len() + self.log_batch_size()
240 }
241}
242
243pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
245 fri_params: &'a FRIParams<F, FA>,
246 vcs: &'a VCS,
247) -> impl Iterator<Item = usize> + 'a
248where
249 VCS: MerkleTreeScheme<F>,
250 F: BinaryField + ExtensionField<FA>,
251 FA: BinaryField,
252{
253 fri_params
254 .fold_arities()
255 .iter()
256 .scan(fri_params.log_len(), |log_n_cosets, arity| {
257 *log_n_cosets -= arity;
258 Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
259 })
260}
261
/// Terminal codeword, stored as a plain vector of field elements.
///
/// NOTE(review): presumably the final codeword produced after the last folding round; confirm
/// at the use sites.
pub type TerminateCodeword<F> = Vec<F>;
264
265pub fn calculate_n_test_queries<F, PS>(
270 security_bits: usize,
271 code: &ReedSolomonCode<PS>,
272) -> Result<usize, Error>
273where
274 F: BinaryField + ExtensionField<PS::Scalar>,
275 PS: PackedField<Scalar: BinaryField>,
276{
277 let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
278 let mut n_queries = (-(security_bits as f64) / per_query_err.log2()).ceil() as usize;
279 for _ in 0..10 {
280 if calculate_error_bound::<F, _>(code, n_queries) >= security_bits {
281 return Ok(n_queries);
282 }
283 n_queries += 1;
284 }
285 Err(Error::ParameterError)
286}
287
288fn calculate_error_bound<F, PS>(code: &ReedSolomonCode<PS>, n_queries: usize) -> usize
289where
290 F: BinaryField + ExtensionField<PS::Scalar>,
291 PS: PackedField<Scalar: BinaryField>,
292{
293 let field_size = 2.0_f64.powi(F::N_BITS as i32);
294 let sumcheck_err = code.log_dim() as f64 / field_size;
296 let folding_err = code.len() as f64 / field_size;
298 let per_query_err = 0.5 * (1.0 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
299 let query_err = per_query_err.powi(n_queries as i32);
300 let total_err = sumcheck_err + folding_err + query_err;
301 -total_err.log2() as usize
302}
303
/// Heuristically picks the fold arity that minimizes an estimate of a single query proof's size.
///
/// Let `n = log_block_length`. For each arity `θ = 1, 2, …, n`, the estimated cost is
/// `(n / 2 * digest_size + 2^θ * field_size) * (n - θ) / θ`, computed in integer arithmetic.
///
/// The search stops at the first arity whose cost is strictly larger than the previous one.
/// It returns the last arity before that point, or `n` if the cost never rises. When `n` is 0
/// there are no candidates, and the function returns 1.
pub fn estimate_optimal_arity(
    log_block_length: usize,
    digest_size: usize,
    field_size: usize,
) -> usize {
    let mut best_arity = 1;
    let mut previous_cost: Option<usize> = None;

    for arity in 1..=log_block_length {
        let cost = (log_block_length / 2 * digest_size + (1 << arity) * field_size)
            * (log_block_length - arity)
            / arity;
        if matches!(previous_cost, Some(prev) if cost > prev) {
            break;
        }
        previous_cost = Some(cost);
        best_arity = arity;
    }

    best_arity
}
334
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField128b, BinaryField32b};
    use binius_ntt::NTTOptions;

    use super::*;

    #[test]
    fn test_calculate_n_test_queries() {
        let security_bits = 96;
        let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap();
        let n_test_queries =
            calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
                .unwrap();
        assert_eq!(n_test_queries, 232);

        let rs_code = ReedSolomonCode::new(28, 2, &NTTOptions::default()).unwrap();
        let n_test_queries =
            calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
                .unwrap();
        assert_eq!(n_test_queries, 143);
    }

    #[test]
    fn test_calculate_n_test_queries_unsatisfiable() {
        let security_bits = 128;
        let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap();
        assert_matches!(
            calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code),
            Err(Error::ParameterError)
        );
    }

    #[test]
    fn test_estimate_optimal_arity() {
        let field_size = 128;
        for log_block_length in 22..35 {
            let digest_size = 256;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 4);
        }

        for log_block_length in 22..28 {
            let digest_size = 1024;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 6);
        }
    }

    /// Boundary behavior of `estimate_optimal_arity` for tiny block lengths.
    #[test]
    fn test_estimate_optimal_arity_small_block_lengths() {
        let (digest_size, field_size) = (256, 128);
        // No candidate arities: falls back to 1.
        assert_eq!(estimate_optimal_arity(0, digest_size, field_size), 1);
        // Only arity 1 is a candidate.
        assert_eq!(estimate_optimal_arity(1, digest_size, field_size), 1);
        // Arity 2 wins: its estimate is 0 because `log_block_length - arity` vanishes.
        assert_eq!(estimate_optimal_arity(2, digest_size, field_size), 2);
    }
}