binius_core/protocols/fri/
common.rs1use std::{iter, marker::PhantomData};
4
5use binius_field::{BinaryField, ExtensionField, PackedField};
6use binius_math::extrapolate_line_scalar;
7use binius_ntt::AdditiveNTT;
8use binius_utils::bail;
9use getset::{CopyGetters, Getters};
10
11use crate::{
12 merkle_tree::MerkleTreeScheme, protocols::fri::Error,
13 reed_solomon::reed_solomon::ReedSolomonCode,
14};
15
16#[inline]
22fn fold_pair<F, FS, NTT>(ntt: &NTT, round: usize, index: usize, values: (F, F), r: F) -> F
23where
24 F: BinaryField + ExtensionField<FS>,
25 FS: BinaryField,
26 NTT: AdditiveNTT<FS>,
27{
28 let t = ntt.get_subspace_eval(round, index);
30 let (mut u, mut v) = values;
31 v += u;
32 u += v * t;
33 extrapolate_line_scalar(u, v, r)
34}
35
36#[inline]
62pub fn fold_chunk<F, FS, NTT>(
63 ntt: &NTT,
64 mut log_len: usize,
65 chunk_index: usize,
66 values: &mut [F],
67 challenges: &[F],
68) -> F
69where
70 F: BinaryField + ExtensionField<FS>,
71 FS: BinaryField,
72 NTT: AdditiveNTT<FS>,
73{
74 let mut log_size = challenges.len();
75
76 debug_assert!(log_size <= log_len);
78 debug_assert!(log_len <= ntt.log_domain_size());
79 debug_assert_eq!(values.len(), 1 << log_size);
80
81 for &challenge in challenges {
83 let ntt_round = ntt.log_domain_size() - log_len;
85 for index_offset in 0..1 << (log_size - 1) {
86 let pair = (values[index_offset << 1], values[(index_offset << 1) | 1]);
87 values[index_offset] = fold_pair(
88 ntt,
89 ntt_round,
90 (chunk_index << (log_size - 1)) | index_offset,
91 pair,
92 challenge,
93 )
94 }
95
96 log_len -= 1;
97 log_size -= 1;
98 }
99
100 values[0]
101}
102
103#[inline]
126#[allow(clippy::too_many_arguments)]
127pub fn fold_interleaved_chunk<F, FS, P, NTT>(
128 ntt: &NTT,
129 log_len: usize,
130 log_batch_size: usize,
131 chunk_index: usize,
132 values: &[P],
133 tensor: &[P],
134 fold_challenges: &[F],
135 scratch_buffer: &mut [F],
136) -> F
137where
138 F: BinaryField + ExtensionField<FS>,
139 FS: BinaryField,
140 NTT: AdditiveNTT<FS>,
141 P: PackedField<Scalar = F>,
142{
143 debug_assert!(fold_challenges.len() <= log_len);
145 debug_assert!(log_len <= ntt.log_domain_size());
146 debug_assert_eq!(
147 values.len(),
148 1 << (fold_challenges.len() + log_batch_size).saturating_sub(P::LOG_WIDTH)
149 );
150 debug_assert_eq!(tensor.len(), 1 << log_batch_size.saturating_sub(P::LOG_WIDTH));
151 debug_assert!(scratch_buffer.len() >= 1 << fold_challenges.len());
152
153 let scratch_buffer = &mut scratch_buffer[..1 << fold_challenges.len()];
154
155 if log_batch_size == 0 {
156 iter::zip(&mut *scratch_buffer, P::iter_slice(values)).for_each(|(dst, val)| *dst = val);
157 } else {
158 let folded_values = values
159 .chunks(1 << (log_batch_size - P::LOG_WIDTH))
160 .map(|chunk| {
161 iter::zip(chunk, tensor)
162 .map(|(&a_i, &b_i)| a_i * b_i)
163 .sum::<P>()
164 .into_iter()
165 .take(1 << log_batch_size)
166 .sum()
167 });
168 iter::zip(&mut *scratch_buffer, folded_values).for_each(|(dst, val)| *dst = val);
169 };
170
171 fold_chunk(ntt, log_len, chunk_index, scratch_buffer, fold_challenges)
172}
173
174#[derive(Debug, Getters, CopyGetters)]
176pub struct FRIParams<F, FA>
177where
178 F: BinaryField,
179 FA: BinaryField,
180{
181 #[getset(get = "pub")]
183 rs_code: ReedSolomonCode<FA>,
184 #[getset(get_copy = "pub")]
186 log_batch_size: usize,
187 fold_arities: Vec<usize>,
189 #[getset(get_copy = "pub")]
191 n_test_queries: usize,
192 _marker: PhantomData<F>,
193}
194
195impl<F, FA> FRIParams<F, FA>
196where
197 F: BinaryField + ExtensionField<FA>,
198 FA: BinaryField,
199{
200 pub fn new(
201 rs_code: ReedSolomonCode<FA>,
202 log_batch_size: usize,
203 fold_arities: Vec<usize>,
204 n_test_queries: usize,
205 ) -> Result<Self, Error> {
206 if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
207 bail!(Error::InvalidFoldAritySequence)
208 }
209
210 Ok(Self {
211 rs_code,
212 log_batch_size,
213 fold_arities,
214 n_test_queries,
215 _marker: PhantomData,
216 })
217 }
218
219 pub const fn n_fold_rounds(&self) -> usize {
220 self.rs_code.log_dim() + self.log_batch_size
221 }
222
223 pub fn n_oracles(&self) -> usize {
225 self.fold_arities.len()
226 }
227
228 pub fn index_bits(&self) -> usize {
230 self.fold_arities
231 .first()
232 .map(|arity| self.log_len() - arity)
233 .unwrap_or(0)
235 }
236
237 pub fn n_final_challenges(&self) -> usize {
239 self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
240 }
241
242 pub fn fold_arities(&self) -> &[usize] {
244 &self.fold_arities
245 }
246
247 pub fn log_len(&self) -> usize {
249 self.rs_code().log_len() + self.log_batch_size()
250 }
251}
252
253pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
255 fri_params: &'a FRIParams<F, FA>,
256 vcs: &'a VCS,
257) -> impl Iterator<Item = usize> + 'a
258where
259 VCS: MerkleTreeScheme<F>,
260 F: BinaryField + ExtensionField<FA>,
261 FA: BinaryField,
262{
263 fri_params
264 .fold_arities()
265 .iter()
266 .scan(fri_params.log_len(), |log_n_cosets, arity| {
267 *log_n_cosets -= arity;
268 Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
269 })
270}
271
/// Codeword type for the terminal FRI round: a plain vector of field elements.
/// (Presumably the final codeword sent in the clear; confirm against the prover/verifier.)
pub type TerminateCodeword<T> = Vec<T>;
274
275pub fn calculate_n_test_queries<F, FEncode>(
280 security_bits: usize,
281 code: &ReedSolomonCode<FEncode>,
282) -> Result<usize, Error>
283where
284 F: BinaryField + ExtensionField<FEncode>,
285 FEncode: BinaryField,
286{
287 let field_size = 2.0_f64.powi(F::N_BITS as i32);
288 let sumcheck_err = (2 * code.log_dim()) as f64 / field_size;
289 let folding_err = code.len() as f64 / field_size;
291 let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
293 let allowed_query_err = 2.0_f64.powi(-(security_bits as i32)) - sumcheck_err - folding_err;
294 if allowed_query_err <= 0.0 {
295 return Err(Error::ParameterError);
296 }
297 let n_queries = allowed_query_err.log(per_query_err).ceil() as usize;
298 Ok(n_queries)
299}
300
/// Heuristically picks a FRI fold arity for a block of `2^log_block_length` elements.
///
/// For arity `a = 1, 2, ...` the estimated cost is
/// `((log_block_length / 2) * digest_size + 2^a * field_size) * (log_block_length - a) / a`,
/// evaluated in integer arithmetic. The scan stops at the first arity whose cost exceeds its
/// predecessor's, and the last arity before that increase is returned. Returns `1` when
/// `log_block_length == 0`, since no arity is considered.
pub fn estimate_optimal_arity(
    log_block_length: usize,
    digest_size: usize,
    field_size: usize,
) -> usize {
    // (arity, cost) of the most recent arity whose cost did not increase.
    let mut best: Option<(usize, usize)> = None;

    for arity in 1..=log_block_length {
        let per_layer = log_block_length / 2 * digest_size + (1 << arity) * field_size;
        let cost = per_layer * (log_block_length - arity) / arity;

        if let Some((_, prev_cost)) = best {
            if cost > prev_cost {
                break;
            }
        }
        best = Some((arity, cost));
    }

    best.map_or(1, |(arity, _)| arity)
}
331
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField128b, BinaryField32b};

    use super::*;

    #[test]
    fn test_calculate_n_test_queries() {
        let security_bits = 96;
        // (log_inv_rate, expected query count) for a code of log dimension 28.
        let cases = [(1, 232), (2, 143)];
        for &(log_inv_rate, expected) in cases.iter() {
            let rs_code = ReedSolomonCode::new(28, log_inv_rate).unwrap();
            let n_test_queries = calculate_n_test_queries::<BinaryField128b, BinaryField32b>(
                security_bits,
                &rs_code,
            )
            .unwrap();
            assert_eq!(n_test_queries, expected);
        }
    }

    #[test]
    fn test_calculate_n_test_queries_unsatisfiable() {
        let security_bits = 128;
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(28, 1).unwrap();
        let outcome = calculate_n_test_queries::<BinaryField128b, _>(security_bits, &rs_code);
        assert_matches!(outcome, Err(Error::ParameterError));
    }

    #[test]
    fn test_estimate_optimal_arity() {
        let field_size = 128;
        (22..35).for_each(|log_block_length| {
            assert_eq!(estimate_optimal_arity(log_block_length, 256, field_size), 4);
        });
        (22..28).for_each(|log_block_length| {
            assert_eq!(estimate_optimal_arity(log_block_length, 1024, field_size), 6);
        });
    }
}