binius_core/protocols/fri/
common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, marker::PhantomData};
4
5use binius_field::{BinaryField, ExtensionField, PackedField};
6use binius_math::extrapolate_line_scalar;
7use binius_ntt::AdditiveNTT;
8use binius_utils::bail;
9use getset::{CopyGetters, Getters};
10
11use crate::{
12	merkle_tree::MerkleTreeScheme, protocols::fri::Error,
13	reed_solomon::reed_solomon::ReedSolomonCode,
14};
15
16/// Calculate fold of `values` at `index` with `r` random coefficient.
17///
18/// See [DP24], Def. 3.6.
19///
20/// [DP24]: <https://eprint.iacr.org/2024/504>
21#[inline]
22fn fold_pair<F, FS, NTT>(ntt: &NTT, round: usize, index: usize, values: (F, F), r: F) -> F
23where
24	F: BinaryField + ExtensionField<FS>,
25	FS: BinaryField,
26	NTT: AdditiveNTT<FS>,
27{
28	// Perform inverse additive NTT butterfly
29	let t = ntt.get_subspace_eval(round, index);
30	let (mut u, mut v) = values;
31	v += u;
32	u += v * t;
33	extrapolate_line_scalar(u, v, r)
34}
35
36/// Calculate FRI fold of `values` at a `chunk_index` with random folding challenges.
37///
38/// Folds a coset of a Reed–Solomon codeword into a single value using the FRI folding algorithm.
39/// The coset has size $2^n$, where $n$ is the number of challenges.
40///
41/// See [DP24], Def. 3.6 and Lemma 3.9 for more details.
42///
43/// NB: This method is on a hot path and does not perform any allocations or
44/// precondition checks.
45///
46/// ## Arguments
47///
48/// * `ntt` - the NTT instance, used to look up the twiddle values.
49/// * `log_len` - the binary logarithm of the code length.
50/// * `chunk_index` - the index of the chunk, of size $2^n$, in the full codeword.
51/// * `values` - mutable slice of values to fold, modified in place.
52/// * `challenges` - the sequence of folding challenges, with length $n$.
53///
54/// ## Pre-conditions
55///
56/// - `challenges.len() <= log_len`.
57/// - `log_len <= ntt.log_domain_size()`, so that the NTT domain is large enough.
58/// - `values.len() == 1 << challenges.len()`.
59///
60/// [DP24]: <https://eprint.iacr.org/2024/504>
61#[inline]
62pub fn fold_chunk<F, FS, NTT>(
63	ntt: &NTT,
64	mut log_len: usize,
65	chunk_index: usize,
66	values: &mut [F],
67	challenges: &[F],
68) -> F
69where
70	F: BinaryField + ExtensionField<FS>,
71	FS: BinaryField,
72	NTT: AdditiveNTT<FS>,
73{
74	let mut log_size = challenges.len();
75
76	// Preconditions
77	debug_assert!(log_size <= log_len);
78	debug_assert!(log_len <= ntt.log_domain_size());
79	debug_assert_eq!(values.len(), 1 << log_size);
80
81	// FRI-fold the values in place.
82	for &challenge in challenges {
83		// Fold the (2i) and (2i+1)th cells of the scratch buffer in-place into the i-th cell
84		let ntt_round = ntt.log_domain_size() - log_len;
85		for index_offset in 0..1 << (log_size - 1) {
86			let pair = (values[index_offset << 1], values[(index_offset << 1) | 1]);
87			values[index_offset] = fold_pair(
88				ntt,
89				ntt_round,
90				(chunk_index << (log_size - 1)) | index_offset,
91				pair,
92				challenge,
93			)
94		}
95
96		log_len -= 1;
97		log_size -= 1;
98	}
99
100	values[0]
101}
102
103/// Calculate the fold of an interleaved chunk of values with random folding challenges.
104///
105/// The elements in the `values` vector are the interleaved cosets of a batch of codewords at the
106/// index `coset_index`. That is, the layout of elements in the values slice is
107///
108/// ```text
109/// [a0, b0, c0, d0, a1, b1, c1, d1, ...]
110/// ```
111///
112/// where `a0, a1, ...` form a coset of a codeword `a`, `b0, b1, ...` form a coset of a codeword
113/// `b`, and similarly for `c` and `d`.
114///
115/// The fold operation first folds the adjacent symbols in the slice using regular multilinear
116/// tensor folding for the symbols from different cosets and FRI folding for the cosets themselves
117/// using the remaining challenges.
118//
119/// NB: This method is on a hot path and does not perform any allocations or
120/// precondition checks.
121///
122/// See [DP24], Def. 3.6 and Lemma 3.9 for more details.
123///
124/// [DP24]: <https://eprint.iacr.org/2024/504>
125#[inline]
126#[allow(clippy::too_many_arguments)]
127pub fn fold_interleaved_chunk<F, FS, P, NTT>(
128	ntt: &NTT,
129	log_len: usize,
130	log_batch_size: usize,
131	chunk_index: usize,
132	values: &[P],
133	tensor: &[P],
134	fold_challenges: &[F],
135	scratch_buffer: &mut [F],
136) -> F
137where
138	F: BinaryField + ExtensionField<FS>,
139	FS: BinaryField,
140	NTT: AdditiveNTT<FS>,
141	P: PackedField<Scalar = F>,
142{
143	// Preconditions
144	debug_assert!(fold_challenges.len() <= log_len);
145	debug_assert!(log_len <= ntt.log_domain_size());
146	debug_assert_eq!(
147		values.len(),
148		1 << (fold_challenges.len() + log_batch_size).saturating_sub(P::LOG_WIDTH)
149	);
150	debug_assert_eq!(tensor.len(), 1 << log_batch_size.saturating_sub(P::LOG_WIDTH));
151	debug_assert!(scratch_buffer.len() >= 1 << fold_challenges.len());
152
153	let scratch_buffer = &mut scratch_buffer[..1 << fold_challenges.len()];
154
155	if log_batch_size == 0 {
156		iter::zip(&mut *scratch_buffer, P::iter_slice(values)).for_each(|(dst, val)| *dst = val);
157	} else {
158		let folded_values = values
159			.chunks(1 << (log_batch_size - P::LOG_WIDTH))
160			.map(|chunk| {
161				iter::zip(chunk, tensor)
162					.map(|(&a_i, &b_i)| a_i * b_i)
163					.sum::<P>()
164					.into_iter()
165					.take(1 << log_batch_size)
166					.sum()
167			});
168		iter::zip(&mut *scratch_buffer, folded_values).for_each(|(dst, val)| *dst = val);
169	};
170
171	fold_chunk(ntt, log_len, chunk_index, scratch_buffer, fold_challenges)
172}
173
/// Parameters for an FRI interleaved code proximity protocol.
///
/// Instances are built with [`FRIParams::new`], which validates the fold arity sequence against
/// the code dimension and batch size. `F` appears only in a `PhantomData` marker; `FA` is the
/// field of the Reed-Solomon code.
#[derive(Debug, Getters, CopyGetters)]
pub struct FRIParams<F, FA>
where
	F: BinaryField,
	FA: BinaryField,
{
	/// The Reed-Solomon code the verifier is testing proximity to.
	#[getset(get = "pub")]
	rs_code: ReedSolomonCode<FA>,
	/// The binary logarithm of the batch size. It is added to the code's log-dimension to get
	/// the number of fold rounds, and to the code's log-length to get the initial oracle length.
	#[getset(get_copy = "pub")]
	log_batch_size: usize,
	/// The reduction arities between each oracle sent to the verifier.
	fold_arities: Vec<usize>,
	/// The number of oracle consistency queries required during the query phase.
	#[getset(get_copy = "pub")]
	n_test_queries: usize,
	/// Ties the otherwise unused type parameter `F` to the struct.
	_marker: PhantomData<F>,
}
194
195impl<F, FA> FRIParams<F, FA>
196where
197	F: BinaryField + ExtensionField<FA>,
198	FA: BinaryField,
199{
200	pub fn new(
201		rs_code: ReedSolomonCode<FA>,
202		log_batch_size: usize,
203		fold_arities: Vec<usize>,
204		n_test_queries: usize,
205	) -> Result<Self, Error> {
206		if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
207			bail!(Error::InvalidFoldAritySequence)
208		}
209
210		Ok(Self {
211			rs_code,
212			log_batch_size,
213			fold_arities,
214			n_test_queries,
215			_marker: PhantomData,
216		})
217	}
218
219	pub const fn n_fold_rounds(&self) -> usize {
220		self.rs_code.log_dim() + self.log_batch_size
221	}
222
223	/// Number of oracles sent during the fold rounds.
224	pub fn n_oracles(&self) -> usize {
225		self.fold_arities.len()
226	}
227
228	/// Number of bits in the query indices sampled during the query phase.
229	pub fn index_bits(&self) -> usize {
230		self.fold_arities
231			.first()
232			.map(|arity| self.log_len() - arity)
233			// If there is no folding, there are no random queries either
234			.unwrap_or(0)
235	}
236
237	/// Number of folding challenges the verifier sends after receiving the last oracle.
238	pub fn n_final_challenges(&self) -> usize {
239		self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
240	}
241
242	/// The reduction arities between each oracle sent to the verifier.
243	pub fn fold_arities(&self) -> &[usize] {
244		&self.fold_arities
245	}
246
247	/// The binary logarithm of the length of the initial oracle.
248	pub fn log_len(&self) -> usize {
249		self.rs_code().log_len() + self.log_batch_size()
250	}
251}
252
253/// This layer allows minimizing the proof size.
254pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
255	fri_params: &'a FRIParams<F, FA>,
256	vcs: &'a VCS,
257) -> impl Iterator<Item = usize> + 'a
258where
259	VCS: MerkleTreeScheme<F>,
260	F: BinaryField + ExtensionField<FA>,
261	FA: BinaryField,
262{
263	fri_params
264		.fold_arities()
265		.iter()
266		.scan(fri_params.log_len(), |log_n_cosets, arity| {
267			*log_n_cosets -= arity;
268			Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
269		})
270}
271
/// The type of the termination round codeword in the FRI protocol.
///
/// This is a plain alias for `Vec<F>`.
pub type TerminateCodeword<F> = Vec<F>;
274
275/// Calculates the number of test queries required to achieve a target security level.
276///
277/// Throws [`Error::ParameterError`] if the security level is unattainable given the code
278/// parameters.
279pub fn calculate_n_test_queries<F, FEncode>(
280	security_bits: usize,
281	code: &ReedSolomonCode<FEncode>,
282) -> Result<usize, Error>
283where
284	F: BinaryField + ExtensionField<FEncode>,
285	FEncode: BinaryField,
286{
287	let field_size = 2.0_f64.powi(F::N_BITS as i32);
288	let sumcheck_err = (2 * code.log_dim()) as f64 / field_size;
289	// 2 ⋅ ℓ' / |T_{τ}|
290	let folding_err = code.len() as f64 / field_size;
291	// 2^{ℓ' + R} / |T_{τ}|
292	let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
293	let allowed_query_err = 2.0_f64.powi(-(security_bits as i32)) - sumcheck_err - folding_err;
294	if allowed_query_err <= 0.0 {
295		return Err(Error::ParameterError);
296	}
297	let n_queries = allowed_query_err.log(per_query_err).ceil() as usize;
298	Ok(n_queries)
299}
300
/// Heuristic for estimating the optimal FRI folding arity that minimizes proof size.
///
/// `log_block_length` is the binary logarithm of the block length of the Reed–Solomon code.
///
/// Arities are tried in increasing order starting from 1. The search stops at the first arity
/// whose estimated query proof size is strictly larger than the previous one, and the previous
/// arity is returned. If `log_block_length` is 0 there is nothing to try and 1 is returned.
pub fn estimate_optimal_arity(
	log_block_length: usize,
	digest_size: usize,
	field_size: usize,
) -> usize {
	// Estimate of a single query proof size for a given arity. It is based on the following
	// approximation, where $\vartheta$ is the arity:
	// $\big((n-\vartheta) + (n-2\vartheta) + \ldots\big)\text{digest_size} + \frac{n-\vartheta}{\vartheta}2^{\vartheta}\text{field_size}.$
	let query_proof_size = |arity: usize| {
		(log_block_length / 2 * digest_size + (1 << arity) * field_size)
			* (log_block_length - arity)
			/ arity
	};

	let mut best_arity = 1;
	let mut prev_size: Option<usize> = None;
	for arity in 1..=log_block_length {
		let size = query_proof_size(arity);
		match prev_size {
			// The estimate started growing: the previous arity is the answer.
			Some(prev) if size > prev => break,
			_ => {
				best_arity = arity;
				prev_size = Some(size);
			}
		}
	}

	best_arity
}
331
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{BinaryField128b, BinaryField32b};

	use super::*;

	/// Pins the query counts at 96 bits of security for two codes that differ only in the
	/// second constructor argument.
	#[test]
	fn test_calculate_n_test_queries() {
		let security_bits = 96;

		for (log_inv_rate, expected) in [(1, 232), (2, 143)] {
			let rs_code = ReedSolomonCode::<BinaryField32b>::new(28, log_inv_rate).unwrap();
			let n_test_queries =
				calculate_n_test_queries::<BinaryField128b, _>(security_bits, &rs_code).unwrap();
			assert_eq!(n_test_queries, expected);
		}
	}

	/// A 128-bit target cannot be met with this code, so the parameter error is reported.
	#[test]
	fn test_calculate_n_test_queries_unsatisfiable() {
		let security_bits = 128;
		let rs_code = ReedSolomonCode::<BinaryField32b>::new(28, 1).unwrap();
		let result = calculate_n_test_queries::<BinaryField128b, _>(security_bits, &rs_code);
		assert_matches!(result, Err(Error::ParameterError));
	}

	/// Checks the arity heuristic over two ranges of block lengths and two digest sizes.
	#[test]
	fn test_estimate_optimal_arity() {
		let field_size = 128;

		let cases = [(256, 22..35, 4), (1024, 22..28, 6)];
		for (digest_size, log_block_lengths, expected) in cases {
			for log_block_length in log_block_lengths {
				assert_eq!(
					estimate_optimal_arity(log_block_length, digest_size, field_size),
					expected
				);
			}
		}
	}
}