// binius_core/protocols/fri/common.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use binius_field::{util::inner_product_unchecked, BinaryField, ExtensionField, PackedField};
6use binius_math::extrapolate_line_scalar;
7use binius_ntt::AdditiveNTT;
8use binius_utils::bail;
9use getset::{CopyGetters, Getters};
10
11use crate::{
12	merkle_tree::MerkleTreeScheme, protocols::fri::Error,
13	reed_solomon::reed_solomon::ReedSolomonCode,
14};
15
16/// Calculate fold of `values` at `index` with `r` random coefficient.
17///
18/// See [DP24], Def. 3.6.
19///
20/// [DP24]: <https://eprint.iacr.org/2024/504>
21#[inline]
22fn fold_pair<F, FS>(
23	rs_code: &ReedSolomonCode<FS>,
24	round: usize,
25	index: usize,
26	values: (F, F),
27	r: F,
28) -> F
29where
30	F: BinaryField + ExtensionField<FS>,
31	FS: BinaryField,
32{
33	// Perform inverse additive NTT butterfly
34	let t = rs_code.get_ntt().get_subspace_eval(round, index);
35	let (mut u, mut v) = values;
36	v += u;
37	u += v * t;
38	extrapolate_line_scalar(u, v, r)
39}
40
41/// Calculate FRI fold of `values` at a `chunk_index` with random folding challenges.
42///
43/// REQUIRES:
44/// - `folding_challenges` is not empty.
45/// - `values.len() == 1 << folding_challenges.len()`.
46/// - `scratch_buffer.len() == values.len()`.
47/// - `start_round + folding_challenges.len() - 1 < rs_code.log_dim()`.
48///
49/// NB: This method is on a hot path and does not perform any allocations or
50/// precondition checks.
51///
52/// See [DP24], Def. 3.6 and Lemma 3.9 for more details.
53///
54/// [DP24]: <https://eprint.iacr.org/2024/504>
55#[inline]
56pub fn fold_chunk<F, FS>(
57	rs_code: &ReedSolomonCode<FS>,
58	start_round: usize,
59	chunk_index: usize,
60	values: &[F],
61	folding_challenges: &[F],
62	scratch_buffer: &mut [F],
63) -> F
64where
65	F: BinaryField + ExtensionField<FS>,
66	FS: BinaryField,
67{
68	// Preconditions
69	debug_assert!(!folding_challenges.is_empty());
70	debug_assert!(start_round + folding_challenges.len() <= rs_code.log_dim());
71	debug_assert_eq!(values.len(), 1 << folding_challenges.len());
72	debug_assert!(scratch_buffer.len() >= values.len());
73
74	// Fold the chunk with the folding challenges one by one
75	for n_challenges_processed in 0..folding_challenges.len() {
76		let n_remaining_challenges = folding_challenges.len() - n_challenges_processed;
77		let scratch_buffer_len = values.len() >> n_challenges_processed;
78		let new_scratch_buffer_len = scratch_buffer_len >> 1;
79		let round = start_round + n_challenges_processed;
80		let r = folding_challenges[n_challenges_processed];
81		let index_start = chunk_index << (n_remaining_challenges - 1);
82
83		// Fold the (2i) and (2i+1)th cells of the scratch buffer in-place into the i-th cell
84		if n_challenges_processed > 0 {
85			(0..new_scratch_buffer_len).for_each(|index_offset| {
86				let values =
87					(scratch_buffer[index_offset << 1], scratch_buffer[(index_offset << 1) + 1]);
88				scratch_buffer[index_offset] =
89					fold_pair(rs_code, round, index_start + index_offset, values, r)
90			});
91		} else {
92			// For the first round, we read values directly from the `values` slice.
93			(0..new_scratch_buffer_len).for_each(|index_offset| {
94				let values = (values[index_offset << 1], values[(index_offset << 1) + 1]);
95				scratch_buffer[index_offset] =
96					fold_pair(rs_code, round, index_start + index_offset, values, r)
97			});
98		}
99	}
100
101	scratch_buffer[0]
102}
103
104/// Calculate the fold of an interleaved chunk of values with random folding challenges.
105///
106/// The elements in the `values` vector are the interleaved cosets of a batch of codewords at the
107/// index `coset_index`. That is, the layout of elements in the values slice is
108///
109/// ```text
110/// [a0, b0, c0, d0, a1, b1, c1, d1, ...]
111/// ```
112///
113/// where `a0, a1, ...` form a coset of a codeword `a`, `b0, b1, ...` form a coset of a codeword
114/// `b`, and similarly for `c` and `d`.
115///
116/// The fold operation first folds the adjacent symbols in the slice using regular multilinear
117/// tensor folding for the symbols from different cosets and FRI folding for the cosets themselves
118/// using the remaining challenges.
119//
120/// NB: This method is on a hot path and does not perform any allocations or
121/// precondition checks.
122///
123/// See [DP24], Def. 3.6 and Lemma 3.9 for more details.
124///
125/// [DP24]: <https://eprint.iacr.org/2024/504>
126#[inline]
127pub fn fold_interleaved_chunk<F, FS>(
128	rs_code: &ReedSolomonCode<FS>,
129	log_batch_size: usize,
130	chunk_index: usize,
131	values: &[F],
132	tensor: &[F],
133	fold_challenges: &[F],
134	scratch_buffer: &mut [F],
135) -> F
136where
137	F: BinaryField + ExtensionField<FS>,
138	FS: BinaryField,
139{
140	// Preconditions
141	debug_assert!(fold_challenges.len() <= rs_code.log_dim());
142	debug_assert_eq!(values.len(), 1 << (log_batch_size + fold_challenges.len()));
143	debug_assert_eq!(tensor.len(), 1 << log_batch_size);
144	debug_assert!(scratch_buffer.len() >= 2 * (values.len() >> log_batch_size));
145
146	// There are two types of mixing we do in this loop. Buffer 1 is populated with the
147	// folding of symbols from the interleaved codewords into a single codeword. These
148	// values are mixed as a regular tensor product combination. Buffer 2 is then
149	// populated with `fold_chunk`, which folds a coset of a codeword using the FRI
150	// folding algorithm.
151	let (buffer1, buffer2) = scratch_buffer.split_at_mut(1 << fold_challenges.len());
152
153	for (interleave_chunk, val) in values.chunks(1 << log_batch_size).zip(buffer1.iter_mut()) {
154		*val = inner_product_unchecked(interleave_chunk.iter().copied(), tensor.iter().copied());
155	}
156
157	if fold_challenges.is_empty() {
158		buffer1[0]
159	} else {
160		fold_chunk(rs_code, 0, chunk_index, buffer1, fold_challenges, buffer2)
161	}
162}
163
164/// Parameters for an FRI interleaved code proximity protocol.
165#[derive(Debug, Getters, CopyGetters)]
166pub struct FRIParams<F, FA>
167where
168	F: BinaryField,
169	FA: BinaryField,
170{
171	/// The Reed-Solomon code the verifier is testing proximity to.
172	#[getset(get = "pub")]
173	rs_code: ReedSolomonCode<FA>,
174	/// Vector commitment scheme for the codeword oracle.
175	#[getset(get_copy = "pub")]
176	log_batch_size: usize,
177	/// The reduction arities between each oracle sent to the verifier.
178	fold_arities: Vec<usize>,
179	/// The number oracle consistency queries required during the query phase.
180	#[getset(get_copy = "pub")]
181	n_test_queries: usize,
182	_marker: PhantomData<F>,
183}
184
185impl<F, FA> FRIParams<F, FA>
186where
187	F: BinaryField + ExtensionField<FA>,
188	FA: BinaryField,
189{
190	pub fn new(
191		rs_code: ReedSolomonCode<FA>,
192		log_batch_size: usize,
193		fold_arities: Vec<usize>,
194		n_test_queries: usize,
195	) -> Result<Self, Error> {
196		if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
197			bail!(Error::InvalidFoldAritySequence)
198		}
199
200		Ok(Self {
201			rs_code,
202			log_batch_size,
203			fold_arities,
204			n_test_queries,
205			_marker: PhantomData,
206		})
207	}
208
209	pub const fn n_fold_rounds(&self) -> usize {
210		self.rs_code.log_dim() + self.log_batch_size
211	}
212
213	/// Number of oracles sent during the fold rounds.
214	pub fn n_oracles(&self) -> usize {
215		self.fold_arities.len()
216	}
217
218	/// Number of bits in the query indices sampled during the query phase.
219	pub fn index_bits(&self) -> usize {
220		self.fold_arities
221			.first()
222			.map(|arity| self.log_len() - arity)
223			// If there is no folding, there are no random queries either
224			.unwrap_or(0)
225	}
226
227	/// Number of folding challenges the verifier sends after receiving the last oracle.
228	pub fn n_final_challenges(&self) -> usize {
229		self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
230	}
231
232	/// The reduction arities between each oracle sent to the verifier.
233	pub fn fold_arities(&self) -> &[usize] {
234		&self.fold_arities
235	}
236
237	/// The binary logarithm of the length of the initial oracle.
238	pub fn log_len(&self) -> usize {
239		self.rs_code().log_len() + self.log_batch_size()
240	}
241}
242
243/// This layer allows minimizing the proof size.
244pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
245	fri_params: &'a FRIParams<F, FA>,
246	vcs: &'a VCS,
247) -> impl Iterator<Item = usize> + 'a
248where
249	VCS: MerkleTreeScheme<F>,
250	F: BinaryField + ExtensionField<FA>,
251	FA: BinaryField,
252{
253	fri_params
254		.fold_arities()
255		.iter()
256		.scan(fri_params.log_len(), |log_n_cosets, arity| {
257			*log_n_cosets -= arity;
258			Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
259		})
260}
261
/// Owned vector holding the codeword of the FRI termination round.
pub type TerminateCodeword<F> = std::vec::Vec<F>;
264
265/// Calculates the number of test queries required to achieve a target security level.
266///
267/// Throws [`Error::ParameterError`] if the security level is unattainable given the code
268/// parameters.
269pub fn calculate_n_test_queries<F, PS>(
270	security_bits: usize,
271	code: &ReedSolomonCode<PS>,
272) -> Result<usize, Error>
273where
274	F: BinaryField + ExtensionField<PS::Scalar>,
275	PS: PackedField<Scalar: BinaryField>,
276{
277	let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
278	let mut n_queries = (-(security_bits as f64) / per_query_err.log2()).ceil() as usize;
279	for _ in 0..10 {
280		if calculate_error_bound::<F, _>(code, n_queries) >= security_bits {
281			return Ok(n_queries);
282		}
283		n_queries += 1;
284	}
285	Err(Error::ParameterError)
286}
287
288fn calculate_error_bound<F, PS>(code: &ReedSolomonCode<PS>, n_queries: usize) -> usize
289where
290	F: BinaryField + ExtensionField<PS::Scalar>,
291	PS: PackedField<Scalar: BinaryField>,
292{
293	let field_size = 2.0_f64.powi(F::N_BITS as i32);
294	// ℓ' / |T_{τ}|
295	let sumcheck_err = code.log_dim() as f64 / field_size;
296	// 2^{ℓ' + R} / |T_{τ}|
297	let folding_err = code.len() as f64 / field_size;
298	let per_query_err = 0.5 * (1.0 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
299	let query_err = per_query_err.powi(n_queries as i32);
300	let total_err = sumcheck_err + folding_err + query_err;
301	-total_err.log2() as usize
302}
303
/// Heuristic for estimating the optimal FRI folding arity that minimizes proof size.
///
/// Parameters:
/// - `log_block_length` is the binary logarithm of the block length of the Reed–Solomon code.
/// - `digest_size` and `field_size` are the sizes of a Merkle digest and of a field element,
///   expressed in a common unit (such as bits).
///
/// With `n = log_block_length` and arity `ϑ`, a single query proof is approximated as
/// `((n - ϑ) + (n - 2ϑ) + …) · digest_size + ((n - ϑ) / ϑ) · 2^ϑ · field_size`. This is
/// evaluated in integer arithmetic as
/// `(n / 2 * digest_size + 2^ϑ * field_size) * (n - ϑ) / ϑ`.
///
/// Arities `1..=log_block_length` are scanned in increasing order, and the scan stops as soon as
/// the estimate strictly increases. The last arity before that increase is returned, so ties
/// favour the larger arity. Returns 1 when `log_block_length` is 0.
pub fn estimate_optimal_arity(
	log_block_length: usize,
	digest_size: usize,
	field_size: usize,
) -> usize {
	// Estimated size of a single query proof when folding with `arity`.
	let query_proof_size = |arity: usize| {
		(log_block_length / 2 * digest_size + (1 << arity) * field_size)
			* (log_block_length - arity)
			/ arity
	};

	// Walk the arities upward and stop at the first increase of the estimate.
	let mut prev: Option<(usize, usize)> = None;
	for arity in 1..=log_block_length {
		let size = query_proof_size(arity);
		if matches!(prev, Some((_, prev_size)) if size > prev_size) {
			break;
		}
		prev = Some((arity, size));
	}

	prev.map_or(1, |(arity, _)| arity)
}
334
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{BinaryField128b, BinaryField32b};
	use binius_ntt::NTTOptions;

	use super::*;

	#[test]
	fn test_calculate_n_test_queries() {
		const SECURITY_BITS: usize = 96;
		// (log_inv_rate, expected number of test queries)
		for (log_inv_rate, expected) in [(1, 232), (2, 143)] {
			let rs_code =
				ReedSolomonCode::new(28, log_inv_rate, &NTTOptions::default()).unwrap();
			let n_test_queries =
				calculate_n_test_queries::<BinaryField128b, BinaryField32b>(SECURITY_BITS, &rs_code)
					.unwrap();
			assert_eq!(n_test_queries, expected);
		}
	}

	#[test]
	fn test_calculate_n_test_queries_unsatisfiable() {
		const SECURITY_BITS: usize = 128;
		let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap();
		let result =
			calculate_n_test_queries::<BinaryField128b, BinaryField32b>(SECURITY_BITS, &rs_code);
		assert_matches!(result, Err(Error::ParameterError));
	}

	#[test]
	fn test_estimate_optimal_arity() {
		let field_size = 128;
		// (digest_size, log_block_length range, expected arity)
		let cases = [(256, 22..35, 4), (1024, 22..28, 6)];
		for (digest_size, log_block_lengths, expected_arity) in cases {
			for log_block_length in log_block_lengths {
				assert_eq!(
					estimate_optimal_arity(log_block_length, digest_size, field_size),
					expected_arity
				);
			}
		}
	}
}