binius_core/protocols/fri/common.rs
// Copyright 2024-2025 Irreducible Inc.

use std::marker::PhantomData;

use binius_field::{BinaryField, ExtensionField};
use binius_ntt::AdditiveNTT;
use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
use getset::{CopyGetters, Getters};

use crate::{
    merkle_tree::MerkleTreeScheme, protocols::fri::Error,
    reed_solomon::reed_solomon::ReedSolomonCode,
};

/// Parameters for an FRI interleaved code proximity protocol.
#[derive(Debug, Getters, CopyGetters)]
pub struct FRIParams<F, FA>
where
    F: BinaryField,
    FA: BinaryField,
{
    /// The Reed-Solomon code the verifier is testing proximity to.
    #[getset(get = "pub")]
    rs_code: ReedSolomonCode<FA>,
    /// The base-2 logarithm of the batch size of the interleaved code.
    #[getset(get_copy = "pub")]
    log_batch_size: usize,
    /// The reduction arities between each oracle sent to the verifier.
    fold_arities: Vec<usize>,
    /// The number of oracle consistency queries required during the query phase.
    #[getset(get_copy = "pub")]
    n_test_queries: usize,
    _marker: PhantomData<F>,
}

impl<F, FA> FRIParams<F, FA>
where
    F: BinaryField + ExtensionField<FA>,
    FA: BinaryField,
{
    pub fn new(
        rs_code: ReedSolomonCode<FA>,
        log_batch_size: usize,
        fold_arities: Vec<usize>,
        n_test_queries: usize,
    ) -> Result<Self, Error> {
        if fold_arities.iter().sum::<usize>() >= rs_code.log_dim() + log_batch_size {
            bail!(Error::InvalidFoldAritySequence)
        }

        Ok(Self {
            rs_code,
            log_batch_size,
            fold_arities,
            n_test_queries,
            _marker: PhantomData,
        })
    }

    /// Choose commit parameters based on protocol parameters, using a constant fold arity.
    ///
    /// ## Arguments
    ///
    /// * `log_msg_len` - the binary logarithm of the length of the message to commit.
    /// * `security_bits` - the target security level in bits.
    /// * `log_inv_rate` - the binary logarithm of the inverse Reed–Solomon code rate.
    /// * `arity` - the folding arity.
    pub fn choose_with_constant_fold_arity(
        ntt: &impl AdditiveNTT<FA>,
        log_msg_len: usize,
        security_bits: usize,
        log_inv_rate: usize,
        arity: usize,
    ) -> Result<Self, Error> {
        assert!(arity > 0);

        let log_dim = log_msg_len.saturating_sub(arity);
        let log_batch_size = log_msg_len.min(arity);
        let rs_code = ReedSolomonCode::with_ntt_subspace(ntt, log_dim, log_inv_rate)?;
        let n_test_queries = calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;

        let cap_height = log2_ceil_usize(n_test_queries);
        let fold_arities = std::iter::repeat_n(
            arity,
            log_msg_len.saturating_sub(cap_height.saturating_sub(log_inv_rate)) / arity,
        )
        .collect::<Vec<_>>();
        // here is the down-to-earth explanation of what we're doing: we want the terminal
        // codeword's log-length to be at least as large as the Merkle cap height. note that
        // `total_vars + log_inv_rate - sum(fold_arities)` is exactly the log-length of the
        // terminal codeword; we want this number to be ≥ cap height. so fold_arities will
        // repeat `arity` the maximal number of times possible, while keeping `total_vars +
        // log_inv_rate - sum(fold_arities) ≥ cap_height` true. this arity-selection strategy
        // can be characterized as: "terminate as late as you can, while maintaining that no
        // Merkle cap is strictly smaller than `cap_height`." this strategy does attain that
        // property: the Merkle path height of the last non-terminal codeword will equal the
        // log-length of the terminal codeword, which is ≥ cap height by fiat. moreover, if we
        // terminated later than we do above, this would stop being true. imagine taking the
        // above terminal codeword and continuing to fold. in that case, we would Merklize
        // this word, again with the coset-bundling trick; the post-bundling path height would
        // thus be `total_vars + log_inv_rate - sum(fold_arities) - arity`. but we already
        // agreed (by the maximality of the number of times we subtracted `arity`) that this
        // number is < cap_height. in other words, its Merkle cap would be short.
        // equivalently: this is the latest termination for which the `min` in
        // `optimal_verify_layer` never triggers; i.e., we will have
        // log2_ceil_usize(n_queries) ≤ tree_depth there. it can be shown that this strategy
        // beats any strategy which terminates later than it does (in other words, by doing
        // this, we are NOT terminating TOO early!). this doesn't mean that we shouldn't
        // terminate EVEN earlier (maybe we should). but this approach is conservative and
        // simple, and it's easy to show that you won't lose by doing it.

        // see https://github.com/IrreducibleOSS/binius/pull/300 for a proof of this fact

        // how should we handle the case `fold_arities = []`, i.e. `total_vars + log_inv_rate -
        // cap_height < arity`? in that case, we would lose nothing by making the entire thing
        // interleaved, i.e., setting `log_batch_size := total_vars`, so `terminal_codeword`
        // lives in the interleaving of the repetition code (and so is itself a repetition
        // codeword!). encoding is trivial. but there's a circularity: whether `total_vars +
        // log_inv_rate - cap_height < arity` or not depends on `cap_height`, which depends on
        // `n_test_queries`, which depends on `log_dim`---soundness depends on block
        // length!---which finally itself depends on whether we're using the repetition code
        // or not. of course this circular dependency is artificial, since in the case
        // `log_batch_size = total_vars` and `log_dim = 0`, we're sending the entire message
        // anyway, so the FRI portion is essentially trivial / superfluous, and the security
        // is perfect. and in any case we could evade it simply by calculating
        // `n_test_queries` and `cap_height` using the provisional `log_dim :=
        // total_vars.saturating_sub(arity)`, proceeding as above, and only then, if we find
        // out post facto that `fold_arities = []`, overwriting `log_batch_size := total_vars`
        // and `log_dim := 0`---and even recalculating `n_test_queries` if we wanted (though
        // it doesn't matter: we could do 0 queries in that case and still get security, and
        // during the actual querying part we will skip querying anyway). in any case, from a
        // purely code-simplicity point of view, the simplest approach is to bite the bullet
        // and let `log_batch_size := min(total_vars, arity)` for good---and keep it there,
        // even if we post facto find out that `fold_arities = []`. the cost of this is that
        // the prover has to do a nontrivial (though small!) interleaved encoding, as opposed
        // to a trivial one.
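        // illustrative numbers (an assumed example, not taken from any particular caller):
        // with `log_msg_len = 20`, `arity = 4`, `log_inv_rate = 1`, and an `n_test_queries`
        // that yields `cap_height = 8`, we get `log_batch_size = 4`, `log_dim = 16`, and a
        // repeat count of (20 - (8 - 1)) / 4 = 3, so `fold_arities = [4, 4, 4]` and the
        // terminal codeword has log-length 20 + 1 - 12 = 9 ≥ 8, as required.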
        Self::new(rs_code, log_batch_size, fold_arities, n_test_queries)
    }

    /// Total number of fold rounds, equal to the log-dimension of the code plus the log batch
    /// size.
    pub const fn n_fold_rounds(&self) -> usize {
        self.rs_code.log_dim() + self.log_batch_size
    }

    /// Number of oracles sent during the fold rounds.
    pub fn n_oracles(&self) -> usize {
        self.fold_arities.len()
    }

    /// Number of bits in the query indices sampled during the query phase.
    pub fn index_bits(&self) -> usize {
        self.fold_arities
            .first()
            .map(|arity| self.log_len() - arity)
            // If there is no folding, there are no random queries either.
            .unwrap_or(0)
    }

    /// Number of folding challenges the verifier sends after receiving the last oracle.
    pub fn n_final_challenges(&self) -> usize {
        self.n_fold_rounds() - self.fold_arities.iter().sum::<usize>()
    }

    /// The reduction arities between each oracle sent to the verifier.
    pub fn fold_arities(&self) -> &[usize] {
        &self.fold_arities
    }

    /// The binary logarithm of the length of the initial oracle.
    pub fn log_len(&self) -> usize {
        self.rs_code().log_len() + self.log_batch_size()
    }
}

/// Returns an iterator over the Merkle tree layer depths at which the verifier should verify
/// each oracle's openings; choosing these layers minimizes the proof size.
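///
/// For illustration (assumed numbers): with `fold_arities() = [4, 4]` and `log_len() = 13`,
/// the iterator calls `vcs.optimal_verify_layer(n_test_queries, log_n_cosets)` at
/// `log_n_cosets = 9` and then `log_n_cosets = 5`.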
pub fn vcs_optimal_layers_depths_iter<'a, F, FA, VCS>(
    fri_params: &'a FRIParams<F, FA>,
    vcs: &'a VCS,
) -> impl Iterator<Item = usize> + 'a
where
    VCS: MerkleTreeScheme<F>,
    F: BinaryField + ExtensionField<FA>,
    FA: BinaryField,
{
    fri_params
        .fold_arities()
        .iter()
        .scan(fri_params.log_len(), |log_n_cosets, arity| {
            *log_n_cosets -= arity;
            Some(vcs.optimal_verify_layer(fri_params.n_test_queries(), *log_n_cosets))
        })
}

/// The type of the termination round codeword in the FRI protocol.
pub type TerminateCodeword<F> = Vec<F>;

/// Calculates the number of test queries required to achieve a target security level.
///
/// Returns an [`Error::ParameterError`] if the security level is unattainable given the code
/// parameters.
pub fn calculate_n_test_queries<F, FEncode>(
    security_bits: usize,
    code: &ReedSolomonCode<FEncode>,
) -> Result<usize, Error>
where
    F: BinaryField + ExtensionField<FEncode>,
    FEncode: BinaryField,
{
    let field_size = 2.0_f64.powi(F::N_BITS as i32);
    // 2 ⋅ ℓ' / |T_{τ}|
    let sumcheck_err = (2 * code.log_dim()) as f64 / field_size;
    // 2^{ℓ' + R} / |T_{τ}|
    let folding_err = code.len() as f64 / field_size;
    // (1 + 2^{-R}) / 2
    let per_query_err = 0.5 * (1f64 + 2.0f64.powi(-(code.log_inv_rate() as i32)));
    let allowed_query_err = 2.0_f64.powi(-(security_bits as i32)) - sumcheck_err - folding_err;
    if allowed_query_err <= 0.0 {
        return Err(Error::ParameterError);
    }
    let n_queries = allowed_query_err.log(per_query_err).ceil() as usize;
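    // Illustrative arithmetic (assumed numbers, mirroring the tests at the bottom of this
    // file): for a 128-bit field with security_bits = 96, log_dim = 28, and log_inv_rate = 1,
    // per_query_err = 0.75 and allowed_query_err ≈ 0.875 ⋅ 2⁻⁹⁶, giving
    // n_queries = ⌈96.19 / 0.415⌉ = 232.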
    Ok(n_queries)
}

/// Heuristic for estimating the optimal FRI folding arity that minimizes proof size.
///
/// `log_block_length` is the binary logarithm of the block length of the Reed–Solomon code.
pub fn estimate_optimal_arity(
    log_block_length: usize,
    digest_size: usize,
    field_size: usize,
) -> usize {
    (1..=log_block_length)
        .map(|arity| {
            (
                // for a given arity, return a tuple (arity, estimate of query_proof_size).
                // this estimate is based on the following approximation of a single
                // query_proof_size, where $\vartheta$ is the arity:
                // $\big((n - \vartheta) + (n - 2\vartheta) + \ldots\big)\text{digest_size} +
                // \frac{n - \vartheta}{\vartheta} 2^{\vartheta}\text{field_size}.$
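                // as a worked instance of this estimate (illustrative numbers only): with
                // n = 30, digest_size = 256, and field_size = 128, the expression below gives
                // 38272 at arity 4 and 39680 at arity 5, so the scan terminates and the
                // function returns 4, matching test_estimate_optimal_arity below.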
                arity,
                (log_block_length / 2 * digest_size + (1 << arity) * field_size)
                    * (log_block_length - arity)
                    / arity,
            )
        })
        // now scan and terminate the iterator when query_proof_size increases.
        .scan(None, |old: &mut Option<(usize, usize)>, new| {
            let should_continue = !matches!(*old, Some(ref old) if new.1 > old.1);
            *old = Some(new);
            should_continue.then_some(new)
        })
        .last()
        .map(|(arity, _)| arity)
        .unwrap_or(1)
}

#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField128b, BinaryField32b};

    use super::*;

    #[test]
    fn test_calculate_n_test_queries() {
        let security_bits = 96;
        let rs_code = ReedSolomonCode::new(28, 1).unwrap();
        let n_test_queries =
            calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
                .unwrap();
        assert_eq!(n_test_queries, 232);

        let rs_code = ReedSolomonCode::new(28, 2).unwrap();
        let n_test_queries =
            calculate_n_test_queries::<BinaryField128b, BinaryField32b>(security_bits, &rs_code)
                .unwrap();
        assert_eq!(n_test_queries, 143);
    }

    #[test]
    fn test_calculate_n_test_queries_unsatisfiable() {
        let security_bits = 128;
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(28, 1).unwrap();
        assert_matches!(
            calculate_n_test_queries::<BinaryField128b, _>(security_bits, &rs_code),
            Err(Error::ParameterError)
        );
    }

    #[test]
    fn test_estimate_optimal_arity() {
        let field_size = 128;
        for log_block_length in 22..35 {
            let digest_size = 256;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 4);
        }

        for log_block_length in 22..28 {
            let digest_size = 1024;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 6);
        }
    }
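
    // A minimal sketch (not from the upstream test suite) exercising the derived-quantity
    // accessors of `FRIParams`; the concrete code and arity values are assumptions chosen for
    // illustration.
    #[test]
    fn test_fri_params_derived_quantities() {
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(10, 1).unwrap();
        let log_code_len = rs_code.log_len();
        let params =
            FRIParams::<BinaryField128b, BinaryField32b>::new(rs_code, 2, vec![4, 4], 100)
                .unwrap();

        // One oracle is sent per fold arity.
        assert_eq!(params.n_oracles(), 2);
        // n_fold_rounds = log_dim + log_batch_size = 10 + 2.
        assert_eq!(params.n_fold_rounds(), 12);
        // The final challenges cover the fold rounds not consumed by the arities: 12 - 8.
        assert_eq!(params.n_final_challenges(), 4);
        // The initial oracle is the interleaved codeword.
        assert_eq!(params.log_len(), log_code_len + 2);
        // Query indices address cosets of the first folded oracle.
        assert_eq!(params.index_bits(), params.log_len() - 4);

        // A fold arity sequence consuming at least `log_dim + log_batch_size` rounds is
        // rejected by the constructor.
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(10, 1).unwrap();
        assert_matches!(
            FRIParams::<BinaryField128b, _>::new(rs_code, 0, vec![6, 4], 100),
            Err(Error::InvalidFoldAritySequence)
        );
    }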
}