1use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField};
4use binius_hal::{make_portable_backend, ComputationBackend};
5use binius_maybe_rayon::prelude::*;
6use binius_utils::{bail, SerializeBytes};
7use bytemuck::zeroed_vec;
8use bytes::BufMut;
9use itertools::izip;
10use tracing::instrument;
11
12use super::{
13 common::{vcs_optimal_layers_depths_iter, FRIParams},
14 error::Error,
15 TerminateCodeword,
16};
17use crate::{
18 fiat_shamir::{CanSampleBits, Challenger},
19 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
20 protocols::fri::common::{fold_chunk, fold_interleaved_chunk},
21 reed_solomon::reed_solomon::ReedSolomonCode,
22 transcript::{ProverTranscript, TranscriptWriter},
23};
24
25#[instrument(skip_all, level = "debug")]
26pub fn fold_codeword<F, FS>(
27 rs_code: &ReedSolomonCode<FS>,
28 codeword: &[F],
29 round: usize,
31 folding_challenges: &[F],
32) -> Vec<F>
33where
34 F: BinaryField + ExtensionField<FS>,
35 FS: BinaryField,
36{
37 assert_eq!(codeword.len() % (1 << folding_challenges.len()), 0);
39 assert!(round >= folding_challenges.len());
40 assert!(round <= rs_code.log_dim());
41
42 if folding_challenges.is_empty() {
43 return codeword.to_vec();
44 }
45
46 let start_round = round - folding_challenges.len();
47 let chunk_size = 1 << folding_challenges.len();
48
49 codeword
51 .par_chunks(chunk_size)
52 .enumerate()
53 .map_init(
54 || vec![F::default(); chunk_size],
55 |scratch_buffer, (chunk_index, chunk)| {
56 fold_chunk(
57 rs_code,
58 start_round,
59 chunk_index,
60 chunk,
61 folding_challenges,
62 scratch_buffer,
63 )
64 },
65 )
66 .collect()
67}
68
69#[instrument(skip_all, level = "debug")]
78fn fold_interleaved<F, FS>(
79 rs_code: &ReedSolomonCode<FS>,
80 codeword: &[F],
81 challenges: &[F],
82 log_batch_size: usize,
83) -> Vec<F>
84where
85 F: BinaryField + ExtensionField<FS>,
86 FS: BinaryField,
87{
88 assert_eq!(codeword.len(), 1 << (rs_code.log_len() + log_batch_size));
89 assert!(challenges.len() >= log_batch_size);
90
91 let backend = make_portable_backend();
92
93 let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
94 let tensor = backend
95 .tensor_product_full_query(interleave_challenges)
96 .expect("number of challenges is less than 32");
97
98 let fold_chunk_size = 1 << fold_challenges.len();
100 let interleave_chunk_size = 1 << log_batch_size;
101 let chunk_size = fold_chunk_size * interleave_chunk_size;
102 codeword
103 .par_chunks(chunk_size)
104 .enumerate()
105 .map_init(
106 || vec![F::default(); 2 * fold_chunk_size],
107 |scratch_buffer, (i, chunk)| {
108 fold_interleaved_chunk(
109 rs_code,
110 log_batch_size,
111 i,
112 chunk,
113 &tensor,
114 fold_challenges,
115 scratch_buffer,
116 )
117 },
118 )
119 .collect()
120}
121
/// The values produced by committing to an interleaved message, as returned by
/// [`commit_interleaved`] and [`commit_interleaved_with`].
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
    /// The commitment value; the commit functions in this module fill it with the Merkle root.
    pub commitment: VCSCommitment,
    /// The prover-side data returned by the Merkle tree prover alongside the commitment.
    pub committed: VCSCommitted,
    /// The encoded codeword, stored as packed field elements.
    pub codeword: Vec<P>,
}
128
129pub fn to_par_scalar_big_chunks<P>(
131 packed_slice: &[P],
132 chunk_size: usize,
133) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
134where
135 P: PackedField,
136{
137 packed_slice
138 .par_chunks(chunk_size / P::WIDTH)
139 .map(|chunk| PackedField::iter_slice(chunk))
140}
141
142pub fn to_par_scalar_small_chunks<P>(
143 packed_slice: &[P],
144 chunk_size: usize,
145) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
146where
147 P: PackedField,
148{
149 (0..packed_slice.len() * P::WIDTH)
150 .into_par_iter()
151 .step_by(chunk_size)
152 .map(move |start_index| {
153 let packed_item = &packed_slice[start_index / P::WIDTH];
154 packed_item
155 .iter()
156 .skip(start_index % P::WIDTH)
157 .take(chunk_size)
158 })
159}
160
161#[instrument(skip_all, level = "debug")]
170pub fn commit_interleaved<F, FA, P, PA, MerkleProver, VCS>(
171 rs_code: &ReedSolomonCode<PA>,
172 params: &FRIParams<F, FA>,
173 merkle_prover: &MerkleProver,
174 message: &[P],
175) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
176where
177 F: BinaryField,
178 FA: BinaryField,
179 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
180 PA: PackedField<Scalar = FA>,
181 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
182 VCS: MerkleTreeScheme<F>,
183{
184 let n_elems = rs_code.dim() << params.log_batch_size();
185 if message.len() * P::WIDTH != n_elems {
186 bail!(Error::InvalidArgs(
187 "interleaved message length does not match code parameters".to_string()
188 ));
189 }
190
191 commit_interleaved_with(rs_code, params, merkle_prover, move |buffer| {
192 buffer.copy_from_slice(message)
193 })
194}
195
196#[instrument(skip_all, level = "debug")]
205pub fn commit_interleaved_with<F, FA, P, PA, MerkleProver, VCS>(
206 rs_code: &ReedSolomonCode<PA>,
207 params: &FRIParams<F, FA>,
208 merkle_prover: &MerkleProver,
209 message_writer: impl FnOnce(&mut [P]),
210) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
211where
212 F: BinaryField,
213 FA: BinaryField,
214 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
215 PA: PackedField<Scalar = FA>,
216 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
217 VCS: MerkleTreeScheme<F>,
218{
219 let log_batch_size = params.log_batch_size();
220 let log_elems = rs_code.log_dim() + log_batch_size;
221 if log_elems < P::LOG_WIDTH {
222 todo!("can't handle this case well");
223 }
224
225 let mut encoded = tracing::debug_span!("allocate codeword")
226 .in_scope(|| zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate())));
227 message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
228 rs_code.encode_ext_batch_inplace(&mut encoded, log_batch_size)?;
229
230 let coset_log_len = params
232 .fold_arities()
233 .first()
234 .copied()
235 .unwrap_or_else(|| rs_code.log_inv_rate());
236
237 let log_len = params.log_len() - coset_log_len;
238
239 let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
240 let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
241
242 merkle_prover
243 .commit_iterated(iterated_big_chunks, log_len)
244 .map_err(|err| Error::VectorCommit(Box::new(err)))?
245 } else {
246 let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
247
248 merkle_prover
249 .commit_iterated(iterated_small_chunks, log_len)
250 .map_err(|err| Error::VectorCommit(Box::new(err)))?
251 };
252
253 Ok(CommitOutput {
254 commitment: commitment.root,
255 committed: vcs_committed,
256 codeword: encoded,
257 })
258}
259
/// The outcome of a single call to [`FRIFolder::execute_fold_round`].
pub enum FoldRoundOutput<VCSCommitment> {
    /// The challenge was recorded, but no codeword was committed in this round.
    NoCommitment,
    /// A folded codeword was committed in this round; carries the commitment's root.
    Commitment(VCSCommitment),
}
264
/// Prover state for the FRI fold rounds.
///
/// Challenges are fed in one at a time through `execute_fold_round`; the folder buffers them
/// and, on commitment rounds, folds the latest codeword and commits to the result.
pub struct FRIFolder<'a, F, FA, MerkleProver, VCS>
where
    FA: BinaryField,
    F: BinaryField,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    /// The FRI parameters (fold arities, batch size, Reed–Solomon code, ...).
    params: &'a FRIParams<F, FA>,
    /// The Merkle tree prover used to commit to folded codewords.
    merkle_prover: &'a MerkleProver,
    /// The initial (interleaved) codeword that was committed before folding started.
    codeword: &'a [F],
    /// The Merkle prover's committed data for the initial codeword.
    codeword_committed: &'a MerkleProver::Committed,
    /// The folded codewords committed so far, each with its committed data, in round order.
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    /// The number of challenges received so far.
    curr_round: usize,
    /// The round at which the next commitment happens, or `None` if there are no more.
    next_commit_round: Option<usize>,
    /// The challenges received since the last commitment, not yet folded in.
    unprocessed_challenges: Vec<F>,
}
282
283impl<'a, F, FA, MerkleProver, VCS> FRIFolder<'a, F, FA, MerkleProver, VCS>
284where
285 F: TowerField + ExtensionField<FA>,
286 FA: BinaryField,
287 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
288 VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
289{
290 pub fn new(
292 params: &'a FRIParams<F, FA>,
293 merkle_prover: &'a MerkleProver,
294 committed_codeword: &'a [F],
295 committed: &'a MerkleProver::Committed,
296 ) -> Result<Self, Error> {
297 if committed_codeword.len() != 1 << params.log_len() {
298 bail!(Error::InvalidArgs(
299 "Reed–Solomon code length must match interleaved codeword length".to_string(),
300 ));
301 }
302
303 let next_commit_round = params.fold_arities().first().copied();
304 Ok(Self {
305 params,
306 merkle_prover,
307 codeword: committed_codeword,
308 codeword_committed: committed,
309 round_committed: Vec::with_capacity(params.n_oracles()),
310 curr_round: 0,
311 next_commit_round,
312 unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
313 })
314 }
315
316 pub const fn n_rounds(&self) -> usize {
318 self.params.n_fold_rounds()
319 }
320
321 pub const fn curr_round(&self) -> usize {
323 self.curr_round
324 }
325
326 fn is_commitment_round(&self) -> bool {
327 self.next_commit_round
328 .is_some_and(|round| round == self.curr_round)
329 }
330
331 #[instrument(skip_all, name = "fri::FRIFolder::execute_fold_round", level = "debug")]
336 pub fn execute_fold_round(
337 &mut self,
338 challenge: F,
339 ) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
340 self.unprocessed_challenges.push(challenge);
341 self.curr_round += 1;
342
343 if !self.is_commitment_round() {
344 return Ok(FoldRoundOutput::NoCommitment);
345 }
346
347 let folded_codeword = match self.round_committed.last() {
349 Some((prev_codeword, _)) => {
350 fold_codeword(
353 self.params.rs_code(),
354 prev_codeword,
355 self.curr_round - self.params.log_batch_size(),
356 &self.unprocessed_challenges,
357 )
358 }
359 None => {
360 fold_interleaved(
364 self.params.rs_code(),
365 self.codeword,
366 &self.unprocessed_challenges,
367 self.params.log_batch_size(),
368 )
369 }
370 };
371 self.unprocessed_challenges.clear();
372
373 let coset_size = self
375 .params
376 .fold_arities()
377 .get(self.round_committed.len() + 1)
378 .map(|log| 1 << log)
379 .unwrap_or_else(|| self.params.rs_code().inv_rate());
380
381 let (commitment, committed) = self
382 .merkle_prover
383 .commit(&folded_codeword, coset_size)
384 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
385
386 self.round_committed.push((folded_codeword, committed));
387
388 self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
389 let arity = self.params.fold_arities().get(self.round_committed.len())?;
390 Some(next_commit_round + arity)
391 });
392 Ok(FoldRoundOutput::Commitment(commitment.root))
393 }
394
395 #[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
403 #[allow(clippy::type_complexity)]
404 pub fn finalize(
405 mut self,
406 ) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, MerkleProver, VCS>), Error> {
407 if self.curr_round != self.n_rounds() {
408 bail!(Error::EarlyProverFinish);
409 }
410
411 let terminate_codeword = self
412 .round_committed
413 .last()
414 .map(|(codeword, _)| codeword.clone())
415 .unwrap_or_else(|| self.codeword.to_vec());
416
417 self.unprocessed_challenges.clear();
418
419 let Self {
420 params,
421 codeword,
422 codeword_committed,
423 round_committed,
424 merkle_prover,
425 ..
426 } = self;
427
428 let query_prover = FRIQueryProver {
429 params,
430 codeword,
431 codeword_committed,
432 round_committed,
433 merkle_prover,
434 };
435 Ok((terminate_codeword, query_prover))
436 }
437
438 pub fn finish_proof<Challenger_>(
439 self,
440 transcript: &mut ProverTranscript<Challenger_>,
441 ) -> Result<(), Error>
442 where
443 Challenger_: Challenger,
444 {
445 let (terminate_codeword, query_prover) = self.finalize()?;
446 let mut advice = transcript.decommitment();
447 advice.write_scalar_slice(&terminate_codeword);
448
449 let layers = query_prover.vcs_optimal_layers()?;
450 for layer in layers {
451 advice.write_slice(&layer);
452 }
453
454 let params = query_prover.params;
455
456 for _ in 0..params.n_test_queries() {
457 let index = transcript.sample_bits(params.index_bits());
458 query_prover.prove_query(index, transcript.decommitment())?;
459 }
460
461 Ok(())
462 }
463}
464
/// Prover state for the FRI query phase, produced by `FRIFolder::finalize`.
pub struct FRIQueryProver<'a, F, FA, MerkleProver, VCS>
where
    F: BinaryField,
    FA: BinaryField,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    /// The FRI parameters (fold arities, number of oracles, ...).
    params: &'a FRIParams<F, FA>,
    /// The initial codeword that was committed before folding started.
    codeword: &'a [F],
    /// The Merkle prover's committed data for the initial codeword.
    codeword_committed: &'a MerkleProver::Committed,
    /// The folded codewords committed during the fold rounds, each with its committed data.
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    /// The Merkle tree prover used to open the commitments.
    merkle_prover: &'a MerkleProver,
}
479
480impl<F, FA, MerkleProver, VCS> FRIQueryProver<'_, F, FA, MerkleProver, VCS>
481where
482 F: TowerField + ExtensionField<FA>,
483 FA: BinaryField,
484 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
485 VCS: MerkleTreeScheme<F>,
486{
487 pub fn n_oracles(&self) -> usize {
489 self.params.n_oracles()
490 }
491
492 #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
498 pub fn prove_query<B>(
499 &self,
500 mut index: usize,
501 mut advice: TranscriptWriter<B>,
502 ) -> Result<(), Error>
503 where
504 B: BufMut,
505 {
506 let mut arities_and_optimal_layers_depths = self
507 .params
508 .fold_arities()
509 .iter()
510 .copied()
511 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));
512
513 let Some((first_fold_arity, first_optimal_layer_depth)) =
514 arities_and_optimal_layers_depths.next()
515 else {
516 return Ok(());
520 };
521
522 prove_coset_opening(
523 self.merkle_prover,
524 self.codeword,
525 self.codeword_committed,
526 index,
527 first_fold_arity,
528 first_optimal_layer_depth,
529 &mut advice,
530 )?;
531
532 for ((codeword, committed), (arity, optimal_layer_depth)) in
533 izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
534 {
535 index >>= arity;
536 prove_coset_opening(
537 self.merkle_prover,
538 codeword,
539 committed,
540 index,
541 arity,
542 optimal_layer_depth,
543 &mut advice,
544 )?;
545 }
546
547 Ok(())
548 }
549
550 pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
551 let committed_iter = std::iter::once(self.codeword_committed)
552 .chain(self.round_committed.iter().map(|(_, committed)| committed));
553
554 committed_iter
555 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
556 .map(|(committed, optimal_layer_depth)| {
557 self.merkle_prover
558 .layer(committed, optimal_layer_depth)
559 .map(|layer| layer.to_vec())
560 .map_err(|err| Error::VectorCommit(Box::new(err)))
561 })
562 .collect::<Result<Vec<_>, _>>()
563 }
564}
565
566fn prove_coset_opening<F, MTProver, B>(
567 merkle_prover: &MTProver,
568 codeword: &[F],
569 committed: &MTProver::Committed,
570 coset_index: usize,
571 log_coset_size: usize,
572 optimal_layer_depth: usize,
573 advice: &mut TranscriptWriter<B>,
574) -> Result<(), Error>
575where
576 F: TowerField,
577 MTProver: MerkleTreeProver<F>,
578 B: BufMut,
579{
580 let values = &codeword[(coset_index << log_coset_size)..((coset_index + 1) << log_coset_size)];
581 advice.write_scalar_slice(values);
582
583 merkle_prover
584 .prove_opening(committed, optimal_layer_depth, coset_index, advice)
585 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
586
587 Ok(())
588}