binius_core/protocols/fri/
common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use binius_field::{BinaryField, ExtensionField};
6use binius_ntt::AdditiveNTT;
7use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
8use getset::{CopyGetters, Getters};
9
10use crate::{
11	merkle_tree::MerkleTreeScheme, protocols::fri::Error,
12	reed_solomon::reed_solomon::ReedSolomonCode,
13};
14
15/// Parameters for an FRI interleaved code proximity protocol.
16#[derive(Debug, Getters, CopyGetters)]
17pub struct FRIParams<F, FA>
18where
19	F: BinaryField,
20	FA: BinaryField,
21{
22	/// The Reed-Solomon code the verifier is testing proximity to.
23	#[getset(get = "pub")]
24	rs_code: ReedSolomonCode<FA>,
25	/// Vector commitment scheme for the codeword oracle.
26	#[getset(get_copy = "pub")]
27	log_batch_size: usize,
28	/// The reduction arities between each oracle sent to the verifier.
29	fold_arities: Vec<usize>,
30	/// The number oracle consistency queries required during the query phase.
31	#[getset(get_copy = "pub")]
32	n_test_queries: usize,
33	_marker: PhantomData<F>,
34}
35
36impl<F, FA> FRIParams<F, FA>
37where
38	F: BinaryField + ExtensionField<FA>,
39	FA: BinaryField,
40{
41	pub fn new(
42		rs_code: ReedSolomonCode<FA>,
43		log_batch_size: usize,
44		fold_arities: Vec<usize>,
45		n_test_queries: usize,
46	) -> Result<Self, Error> {
47		if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
48			bail!(Error::InvalidFoldAritySequence)
49		}
50
51		Ok(Self {
52			rs_code,
53			log_batch_size,
54			fold_arities,
55			n_test_queries,
56			_marker: PhantomData,
57		})
58	}
59
60	/// Choose commit parameters based on protocol parameters, using a constant fold arity.
61	///
62	/// ## Arguments
63	///
64	/// * `log_msg_len` - the binary logarithm of the length of the message to commit.
65	/// * `security_bits` - the target security level in bits.
66	/// * `log_inv_rate` - the binary logarithm of the inverse Reed–Solomon code rate.
67	/// * `arity` - the folding arity.
68	pub fn choose_with_constant_fold_arity(
69		ntt: &impl AdditiveNTT<FA>,
70		log_msg_len: usize,
71		security_bits: usize,
72		log_inv_rate: usize,
73		arity: usize,
74	) -> Result<Self, Error> {
75		assert!(arity > 0);
76
77		let log_dim = log_msg_len.saturating_sub(arity);
78		let log_batch_size = log_msg_len.min(arity);
79		let rs_code = ReedSolomonCode::with_ntt_subspace(ntt, log_dim, log_inv_rate)?;
80		let n_test_queries = calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;
81
82		let cap_height = log2_ceil_usize(n_test_queries);
83		let fold_arities = std::iter::repeat_n(
84			arity,
85			log_msg_len.saturating_sub(cap_height.saturating_sub(log_inv_rate)) / arity,
86		)
87		.collect::<Vec<_>>();
88		// here is the down-to-earth explanation of what we're doing: we want the terminal
89		// codeword's log-length to be at least as large as the Merkle cap height. note that
90		// `total_vars + log_inv_rate - sum(fold_arities)` is exactly the log-length of the
91		// terminal codeword; we want this number to be ≥ cap height. so fold_arities will repeat
92		// `arity` the maximal number of times possible, while maintaining that `total_vars +
93		// log_inv_rate - sum(fold_arities) ≥ cap_height` stays true. this arity-selection
94		// strategy can be characterized as: "terminate as late as you can, while maintaining that
95		// no Merkle cap is strictly smaller than `cap_height`." this strategy does attain that
96		// property: the Merkle path height of the last non-terminal codeword will equal the
97		// log-length of the terminal codeword, which is ≥ cap height by fiat. moreover, if we
98		// terminated later than we are above, then this would stop being true. imagine what would
99		// happen if we took the above terminal codeword and continued folding. in that case, we
100		// would Merklize this word, again with the coset-bundling trick; the post-bundling path
101		// height would thus be `total_vars + log_inv_rate - sum(fold_arities) - arity`. but we
102		// already agreed (by the maximality of the number of times we subtracted `arity`) that
103		// the above number will be < cap_height. in other words, its Merkle cap will be
104		// short. equivalently: this is the latest termination for which the `min` in
105		// `optimal_verify_layer` will never trigger; i.e., we will have log2_ceil_usize(n_queries)
106		// ≤ tree_depth there. it can be shown that this strategy beats any strategy which
107		// terminates later than it does (in other words, by doing this, we are NOT terminating
108		// TOO early!). this doesn't mean that we should't terminate EVEN earlier (maybe we
109		// should). but this approach is conservative and simple; and it's easy to show that you
110		// won't lose by doing this.
111
112		// see https://github.com/IrreducibleOSS/binius/pull/300 for proof of this fact
113
114		// how should we handle the case `fold_arities = []`, i.e. total_vars + log_inv_rate -
115		// cap_height < arity? in that case, we would lose nothing by making the entire thing
116		// interleaved, i.e., setting `log_batch_size := total_vars`, so `terminal_codeword` lives
117		// in the interleaving of the repetition code (and so is itself a repetition codeword!).
118		// encoding is trivial. but there's a circularity: whether `total_vars + log_inv_rate -
119		// cap_height < arity` or not depends on `cap_height`, which depends on `n_test_queries`,
120		// which depends on `log_dim`--- soundness depends on block length!---which finally itself
121		// depends on whether we're using the repetition code or not. of course this circular
122		// dependency is artificial, since in the case `log_batch_size = total_vars` and `log_dim
123		// = 0`, we're sending the entire message anyway, so the FRI portion is essentially
124		// trivial / superfluous, and the security is perfect. and in any case we could evade it
125		// simply by calculating `n_test_queries` and `cap_height` using the provisional `log_dim
126		// := total_vars.saturating_sub(arity)`, proceeding as above, and only then, if we find
127		// out post facto that `fold_arities = []`, overwriting `log_batch_size := total_vars` and
128		// `log_dim = 0`---and even recalculating `n_test_queries` if we wanted (though of course
129		// it doesn't matter---we could do 0 queries in that case, and we would still get
130		// security---and in fact during the actual querying part we will skip querying
131		// anyway). in any case, from a purely code-simplicity point of view, the simplest approach
132		// is to bite the bullet and let `log_batch_size := min(total_vars, arity)` for good---and
133		// keep it there, even if we post-facto find out that `fold_arities = []`. the cost of
134		// this is that the prover has to do a nontrivial (though small!) interleaved encoding, as
135		// opposed to a trivial one.
136		Self::new(rs_code, log_batch_size, fold_arities, n_test_queries)
137	}
138
139	pub const fn n_fold_rounds(&self) -> usize {
140		self.rs_code.log_dim() + self.log_batch_size
141	}
142
143	/// Number of oracles sent during the fold rounds.
144	pub fn n_oracles(&self) -> usize {
145		self.fold_arities.len()
146	}
147
148	/// Number of bits in the query indices sampled during the query phase.
149	pub fn index_bits(&self) -> usize {
150		self.fold_arities
151			.first()
152			.map(|arity| self.log_len() - arity)
153			// If there is no folding, there are no random queries either
154			.unwrap_or(0)
155	}
156
157	/// Number of folding challenges the verifier sends after receiving the last oracle.
158	pub fn n_final_challenges(&self) -> usize {
159		self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
160	}
161
162	/// The reduction arities between each oracle sent to the verifier.
163	pub fn fold_arities(&self) -> &[usize] {
164		&self.fold_arities
165	}
166
167	/// The binary logarithm of the length of the initial oracle.
168	pub fn log_len(&self) -> usize {
169		self.rs_code().log_len() + self.log_batch_size()
170	}
171}
172
173/// This layer allows minimizing the proof size.
174pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
175	fri_params: &'a FRIParams<F, FA>,
176	vcs: &'a VCS,
177) -> impl Iterator<Item = usize> + 'a
178where
179	VCS: MerkleTreeScheme<F>,
180	F: BinaryField + ExtensionField<FA>,
181	FA: BinaryField,
182{
183	fri_params
184		.fold_arities()
185		.iter()
186		.scan(fri_params.log_len(), |log_n_cosets, arity| {
187			*log_n_cosets -= arity;
188			Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
189		})
190}
191
/// The codeword sent in the termination round of the FRI protocol, stored as a plain vector of
/// field elements.
pub type TerminateCodeword<F> = Vec<F>;
194
195/// Calculates the number of test queries required to achieve a target security level.
196///
197/// Throws [`Error::ParameterError`] if the security level is unattainable given the code
198/// parameters.
199pub fn calculate_n_test_queries<F, FEncode>(
200	security_bits: usize,
201	code: &ReedSolomonCode<FEncode>,
202) -> Result<usize, Error>
203where
204	F: BinaryField + ExtensionField<FEncode>,
205	FEncode: BinaryField,
206{
207	let field_size = 2.0_f64.powi(F::N_BITS as i32);
208	let sumcheck_err = (2 * code.log_dim()) as f64 / field_size;
209	// 2 ⋅ ℓ' / |T_{τ}|
210	let folding_err = code.len() as f64 / field_size;
211	// 2^{ℓ' + R} / |T_{τ}|
212	let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
213	let allowed_query_err = 2.0_f64.powi(-(security_bits as i32)) - sumcheck_err - folding_err;
214	if allowed_query_err <= 0.0 {
215		return Err(Error::ParameterError);
216	}
217	let n_queries = allowed_query_err.log(per_query_err).ceil() as usize;
218	Ok(n_queries)
219}
220
/// Heuristic for estimating the optimal FRI folding arity that minimizes proof size.
///
/// Arities `1..=log_block_length` are tried in increasing order and the search stops at the
/// first arity whose estimated query proof size is strictly larger than that of the previous
/// arity; the previous arity is returned. If `log_block_length` is zero, the result is `1`.
///
/// `log_block_length` is the binary logarithm of the block length of the Reed–Solomon code.
pub fn estimate_optimal_arity(
	log_block_length: usize,
	digest_size: usize,
	field_size: usize,
) -> usize {
	// Estimate of the size of a single query proof for the given arity. It is based on the
	// following approximation, where $\vartheta$ is the arity: $\big((n-\vartheta) +
	// (n-2\vartheta) + \ldots\big)\text{digest_size} +
	// \frac{n-\vartheta}{\vartheta}2^{\vartheta}\text{field_size}.$
	let query_proof_size = |arity: usize| -> usize {
		(log_block_length / 2 * digest_size + (1usize << arity) * field_size)
			* (log_block_length - arity)
			/ arity
	};

	// Last accepted candidate as `(arity, estimated size)`.
	let mut best: Option<(usize, usize)> = None;
	for arity in 1..=log_block_length {
		let size = query_proof_size(arity);
		// Stop as soon as the estimate starts to grow; ties keep the search going.
		if matches!(best, Some((_, prev_size)) if size > prev_size) {
			break;
		}
		best = Some((arity, size));
	}

	best.map_or(1, |(arity, _)| arity)
}
253
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{BinaryField32b, BinaryField128b};

	use super::*;

	/// Pins the query counts for a 96-bit target over a code of log-dimension 28.
	#[test]
	fn test_calculate_n_test_queries() {
		let security_bits = 96;

		// Each case is (log_inv_rate, expected number of test queries).
		for (log_inv_rate, expected) in [(1, 232), (2, 143)] {
			let rs_code = ReedSolomonCode::new(28, log_inv_rate).unwrap();
			let n_test_queries = calculate_n_test_queries::<BinaryField128b, BinaryField32b>(
				security_bits,
				&rs_code,
			)
			.unwrap();
			assert_eq!(n_test_queries, expected);
		}
	}

	/// A 128-bit target over a 128-bit field leaves no error budget for the queries.
	#[test]
	fn test_calculate_n_test_queries_unsatisfiable() {
		let security_bits = 128;
		let rs_code = ReedSolomonCode::<BinaryField32b>::new(28, 1).unwrap();
		let result = calculate_n_test_queries::<BinaryField128b, _>(security_bits, &rs_code);
		assert_matches!(result, Err(Error::ParameterError));
	}

	/// Pins the heuristic's answer across a range of block lengths for two digest sizes.
	#[test]
	fn test_estimate_optimal_arity() {
		let field_size = 128;

		// Each case is (range of log block lengths, digest size, expected arity).
		let cases = [(22..35, 256, 4), (22..28, 1024, 6)];
		for (log_block_lengths, digest_size, expected_arity) in cases {
			for log_block_length in log_block_lengths {
				assert_eq!(
					estimate_optimal_arity(log_block_length, digest_size, field_size),
					expected_arity
				);
			}
		}
	}
}