1use binius_field::{
4 packed::{iter_packed_slice_with_offset, len_packed_slice},
5 BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
6};
7use binius_math::MultilinearQuery;
8use binius_maybe_rayon::prelude::*;
9use binius_ntt::AdditiveNTT;
10use binius_utils::{bail, checked_arithmetics::log2_strict_usize, SerializeBytes};
11use bytemuck::zeroed_vec;
12use bytes::BufMut;
13use itertools::izip;
14use tracing::instrument;
15
16use super::{
17 common::{vcs_optimal_layers_depths_iter, FRIParams},
18 error::Error,
19 logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
20 TerminateCodeword,
21};
22use crate::{
23 fiat_shamir::{CanSampleBits, Challenger},
24 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
25 protocols::fri::{common::fold_interleaved_chunk, logging::FRIFoldData},
26 reed_solomon::reed_solomon::ReedSolomonCode,
27 transcript::{ProverTranscript, TranscriptWriter},
28};
29
30#[instrument(skip_all, level = "debug")]
44pub fn fold_interleaved<F, FS, NTT, P>(
45 ntt: &NTT,
46 codeword: &[P],
47 challenges: &[F],
48 log_len: usize,
49 log_batch_size: usize,
50) -> Vec<F>
51where
52 F: BinaryField + ExtensionField<FS>,
53 FS: BinaryField,
54 NTT: AdditiveNTT<FS> + Sync,
55 P: PackedField<Scalar = F>,
56{
57 assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH));
58 assert!(challenges.len() >= log_batch_size);
59
60 let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
61 let tensor = MultilinearQuery::expand(interleave_challenges);
62
63 let fold_chunk_size = 1 << fold_challenges.len();
65 let chunk_size = 1 << challenges.len().saturating_sub(P::LOG_WIDTH);
66 codeword
67 .par_chunks(chunk_size)
68 .enumerate()
69 .map_init(
70 || vec![F::default(); fold_chunk_size],
71 |scratch_buffer, (i, chunk)| {
72 fold_interleaved_chunk(
73 ntt,
74 log_len,
75 log_batch_size,
76 i,
77 chunk,
78 tensor.expansion(),
79 fold_challenges,
80 scratch_buffer,
81 )
82 },
83 )
84 .collect()
85}
86
/// Result of committing an interleaved message with [`commit_interleaved`] or
/// [`commit_interleaved_with`].
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
    /// Root of the Merkle tree built over `codeword`.
    pub commitment: VCSCommitment,
    /// Prover-side Merkle tree data; later passed to the FRI folder / query prover to open
    /// the commitment.
    pub committed: VCSCommitted,
    /// The Reed–Solomon encoded interleaved codeword, in packed form.
    pub codeword: Vec<P>,
}
93
94pub fn to_par_scalar_big_chunks<P>(
96 packed_slice: &[P],
97 chunk_size: usize,
98) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
99where
100 P: PackedField,
101{
102 packed_slice
103 .par_chunks(chunk_size / P::WIDTH)
104 .map(|chunk| PackedField::iter_slice(chunk))
105}
106
107pub fn to_par_scalar_small_chunks<P>(
108 packed_slice: &[P],
109 chunk_size: usize,
110) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
111where
112 P: PackedField,
113{
114 (0..packed_slice.len() * P::WIDTH)
115 .into_par_iter()
116 .step_by(chunk_size)
117 .map(move |start_index| {
118 let packed_item = &packed_slice[start_index / P::WIDTH];
119 packed_item
120 .iter()
121 .skip(start_index % P::WIDTH)
122 .take(chunk_size)
123 })
124}
125
126#[instrument(skip_all, level = "debug")]
135pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
136 rs_code: &ReedSolomonCode<FA>,
137 params: &FRIParams<F, FA>,
138 ntt: &NTT,
139 merkle_prover: &MerkleProver,
140 message: &[P],
141) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
142where
143 F: BinaryField,
144 FA: BinaryField,
145 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
146 PA: PackedField<Scalar = FA>,
147 NTT: AdditiveNTT<FA> + Sync,
148 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
149 VCS: MerkleTreeScheme<F>,
150{
151 let n_elems = rs_code.dim() << params.log_batch_size();
152 if message.len() * P::WIDTH != n_elems {
153 bail!(Error::InvalidArgs(
154 "interleaved message length does not match code parameters".to_string()
155 ));
156 }
157
158 commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
159 buffer.copy_from_slice(message)
160 })
161}
162
163pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
172 params: &FRIParams<F, FA>,
173 ntt: &NTT,
174 merkle_prover: &MerkleProver,
175 message_writer: impl FnOnce(&mut [P]),
176) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
177where
178 F: BinaryField,
179 FA: BinaryField,
180 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
181 PA: PackedField<Scalar = FA>,
182 NTT: AdditiveNTT<FA> + Sync,
183 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
184 VCS: MerkleTreeScheme<F>,
185{
186 let rs_code = params.rs_code();
187 let log_batch_size = params.log_batch_size();
188 let log_elems = rs_code.log_dim() + log_batch_size;
189 if log_elems < P::LOG_WIDTH {
190 todo!("can't handle this case well");
191 }
192
193 let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));
194
195 let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
196 tracing::debug_span!("[task] Sort & Merge", phase = "commit", perfetto_category = "task.main", dimensions_data = ?dimensions_data)
197 .in_scope(|| {
198 message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
199 });
200
201 let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
202 tracing::debug_span!("[task] RS Encode", phase = "commit", perfetto_category = "task.main", dimensions_data = ?dimensions_data)
203 .in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;
204
205 let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);
207
208 let log_len = params.log_len() - coset_log_len;
209 let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
210 let merkle_tree_span = tracing::debug_span!(
211 "[task] Merkle Tree",
212 phase = "commit",
213 perfetto_category = "task.main",
214 dimensions_data = ?dimension_data
215 )
216 .entered();
217 let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
218 let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
219
220 merkle_prover
221 .commit_iterated(iterated_big_chunks, log_len)
222 .map_err(|err| Error::VectorCommit(Box::new(err)))?
223 } else {
224 let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
225
226 merkle_prover
227 .commit_iterated(iterated_small_chunks, log_len)
228 .map_err(|err| Error::VectorCommit(Box::new(err)))?
229 };
230 drop(merkle_tree_span);
231
232 Ok(CommitOutput {
233 commitment: commitment.root,
234 committed: vcs_committed,
235 codeword: encoded,
236 })
237}
238
/// Outcome of a single `FRIFolder::execute_fold_round` call.
pub enum FoldRoundOutput<VCSCommitment> {
    /// The round only recorded its challenge; nothing was folded or committed.
    NoCommitment,
    /// The round folded the codeword and committed to it; carries the Merkle root.
    Commitment(VCSCommitment),
}
243
/// Prover-side state machine for the FRI fold phase.
///
/// Challenges are fed one round at a time. On commitment rounds, the pending challenges are
/// used to fold the current codeword, and the result is committed with the Merkle prover.
pub struct FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
where
    FA: BinaryField,
    F: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    params: &'a FRIParams<F, FA>,
    ntt: &'a NTT,
    merkle_prover: &'a MerkleProver,
    // The originally committed interleaved codeword.
    codeword: &'a [P],
    // Merkle prover data for `codeword`.
    codeword_committed: &'a MerkleProver::Committed,
    // One (folded codeword, Merkle data) pair per commitment made during folding, in order.
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    // Number of fold rounds executed so far.
    curr_round: usize,
    // Round number at which the next commitment happens; `None` once no further commitments
    // remain.
    next_commit_round: Option<usize>,
    // Challenges received since the last fold.
    unprocessed_challenges: Vec<F>,
}
263
264impl<'a, F, FA, P, NTT, MerkleProver, VCS> FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
265where
266 F: TowerField + ExtensionField<FA>,
267 FA: BinaryField,
268 P: PackedField<Scalar = F>,
269 NTT: AdditiveNTT<FA> + Sync,
270 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
271 VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
272{
273 pub fn new(
275 params: &'a FRIParams<F, FA>,
276 ntt: &'a NTT,
277 merkle_prover: &'a MerkleProver,
278 committed_codeword: &'a [P],
279 committed: &'a MerkleProver::Committed,
280 ) -> Result<Self, Error> {
281 if len_packed_slice(committed_codeword) < 1 << params.log_len() {
282 bail!(Error::InvalidArgs(
283 "Reed–Solomon code length must match interleaved codeword length".to_string(),
284 ));
285 }
286
287 let next_commit_round = params.fold_arities().first().copied();
288 Ok(Self {
289 params,
290 ntt,
291 merkle_prover,
292 codeword: committed_codeword,
293 codeword_committed: committed,
294 round_committed: Vec::with_capacity(params.n_oracles()),
295 curr_round: 0,
296 next_commit_round,
297 unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
298 })
299 }
300
301 pub const fn n_rounds(&self) -> usize {
303 self.params.n_fold_rounds()
304 }
305
306 pub const fn curr_round(&self) -> usize {
308 self.curr_round
309 }
310
311 pub fn current_codeword_len(&self) -> usize {
313 match self.round_committed.last() {
314 Some((codeword, _)) => codeword.len(),
315 None => len_packed_slice(self.codeword),
316 }
317 }
318
319 fn is_commitment_round(&self) -> bool {
320 self.next_commit_round
321 .is_some_and(|round| round == self.curr_round)
322 }
323
324 pub fn execute_fold_round(
329 &mut self,
330 challenge: F,
331 ) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
332 self.unprocessed_challenges.push(challenge);
333 self.curr_round += 1;
334
335 if !self.is_commitment_round() {
336 return Ok(FoldRoundOutput::NoCommitment);
337 }
338
339 let dimensions_data = match self.round_committed.last() {
340 Some((codeword, _)) => FRIFoldData::new(
341 log2_strict_usize(codeword.len()),
342 0,
343 self.unprocessed_challenges.len(),
344 ),
345 None => FRIFoldData::new(
346 self.params.rs_code().log_len(),
347 self.params.log_batch_size(),
348 self.unprocessed_challenges.len(),
349 ),
350 };
351
352 let fri_fold_span = tracing::debug_span!(
353 "[task] FRI Fold",
354 phase = "piop_compiler",
355 perfetto_category = "task.main",
356 dimensions_data = ?dimensions_data
357 )
358 .entered();
359 let folded_codeword = match self.round_committed.last() {
361 Some((prev_codeword, _)) => {
362 fold_interleaved(
365 self.ntt,
366 prev_codeword,
367 &self.unprocessed_challenges,
368 log2_strict_usize(prev_codeword.len()),
369 0,
370 )
371 }
372 None => {
373 fold_interleaved(
377 self.ntt,
378 self.codeword,
379 &self.unprocessed_challenges,
380 self.params.rs_code().log_len(),
381 self.params.log_batch_size(),
382 )
383 }
384 };
385 drop(fri_fold_span);
386 self.unprocessed_challenges.clear();
387
388 let coset_size = self
390 .params
391 .fold_arities()
392 .get(self.round_committed.len() + 1)
393 .map(|log| 1 << log)
394 .unwrap_or_else(|| 1 << self.params.n_final_challenges());
395 let dimension_data =
396 MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
397 let merkle_tree_span = tracing::debug_span!(
398 "[task] Merkle Tree",
399 phase = "piop_compiler",
400 perfetto_category = "task.main",
401 dimensions_data = ?dimension_data
402 )
403 .entered();
404 let (commitment, committed) = self
405 .merkle_prover
406 .commit(&folded_codeword, coset_size)
407 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
408 drop(merkle_tree_span);
409
410 self.round_committed.push((folded_codeword, committed));
411
412 self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
413 let arity = self.params.fold_arities().get(self.round_committed.len())?;
414 Some(next_commit_round + arity)
415 });
416 Ok(FoldRoundOutput::Commitment(commitment.root))
417 }
418
419 #[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
427 #[allow(clippy::type_complexity)]
428 pub fn finalize(
429 mut self,
430 ) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
431 if self.curr_round != self.n_rounds() {
432 bail!(Error::EarlyProverFinish);
433 }
434
435 let terminate_codeword = self
436 .round_committed
437 .last()
438 .map(|(codeword, _)| codeword.clone())
439 .unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());
440
441 self.unprocessed_challenges.clear();
442
443 let Self {
444 params,
445 codeword,
446 codeword_committed,
447 round_committed,
448 merkle_prover,
449 ..
450 } = self;
451
452 let query_prover = FRIQueryProver {
453 params,
454 codeword,
455 codeword_committed,
456 round_committed,
457 merkle_prover,
458 };
459 Ok((terminate_codeword, query_prover))
460 }
461
462 pub fn finish_proof<Challenger_>(
463 self,
464 transcript: &mut ProverTranscript<Challenger_>,
465 ) -> Result<(), Error>
466 where
467 Challenger_: Challenger,
468 {
469 let (terminate_codeword, query_prover) = self.finalize()?;
470 let mut advice = transcript.decommitment();
471 advice.write_scalar_slice(&terminate_codeword);
472
473 let layers = query_prover.vcs_optimal_layers()?;
474 for layer in layers {
475 advice.write_slice(&layer);
476 }
477
478 let params = query_prover.params;
479
480 for _ in 0..params.n_test_queries() {
481 let index = transcript.sample_bits(params.index_bits()) as usize;
482 query_prover.prove_query(index, transcript.decommitment())?;
483 }
484
485 Ok(())
486 }
487}
488
/// Produces FRI query openings after the fold phase has been finalized.
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
where
    F: BinaryField,
    FA: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    params: &'a FRIParams<F, FA>,
    // The originally committed interleaved codeword.
    codeword: &'a [P],
    // Merkle prover data for `codeword`.
    codeword_committed: &'a MerkleProver::Committed,
    // Folded codewords and their Merkle data, one per commitment made during folding.
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    merkle_prover: &'a MerkleProver,
}
504
505impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
506where
507 F: TowerField + ExtensionField<FA>,
508 FA: BinaryField,
509 P: PackedField<Scalar = F>,
510 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
511 VCS: MerkleTreeScheme<F>,
512{
513 pub fn n_oracles(&self) -> usize {
515 self.params.n_oracles()
516 }
517
518 #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
524 pub fn prove_query<B>(
525 &self,
526 mut index: usize,
527 mut advice: TranscriptWriter<B>,
528 ) -> Result<(), Error>
529 where
530 B: BufMut,
531 {
532 let mut arities_and_optimal_layers_depths = self
533 .params
534 .fold_arities()
535 .iter()
536 .copied()
537 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));
538
539 let Some((first_fold_arity, first_optimal_layer_depth)) =
540 arities_and_optimal_layers_depths.next()
541 else {
542 return Ok(());
546 };
547
548 prove_coset_opening(
549 self.merkle_prover,
550 self.codeword,
551 self.codeword_committed,
552 index,
553 first_fold_arity,
554 first_optimal_layer_depth,
555 &mut advice,
556 )?;
557
558 for ((codeword, committed), (arity, optimal_layer_depth)) in
559 izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
560 {
561 index >>= arity;
562 prove_coset_opening(
563 self.merkle_prover,
564 codeword,
565 committed,
566 index,
567 arity,
568 optimal_layer_depth,
569 &mut advice,
570 )?;
571 }
572
573 Ok(())
574 }
575
576 pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
577 let committed_iter = std::iter::once(self.codeword_committed)
578 .chain(self.round_committed.iter().map(|(_, committed)| committed));
579
580 committed_iter
581 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
582 .map(|(committed, optimal_layer_depth)| {
583 self.merkle_prover
584 .layer(committed, optimal_layer_depth)
585 .map(|layer| layer.to_vec())
586 .map_err(|err| Error::VectorCommit(Box::new(err)))
587 })
588 .collect::<Result<Vec<_>, _>>()
589 }
590}
591
592fn prove_coset_opening<F, P, MTProver, B>(
593 merkle_prover: &MTProver,
594 codeword: &[P],
595 committed: &MTProver::Committed,
596 coset_index: usize,
597 log_coset_size: usize,
598 optimal_layer_depth: usize,
599 advice: &mut TranscriptWriter<B>,
600) -> Result<(), Error>
601where
602 F: TowerField,
603 P: PackedField<Scalar = F>,
604 MTProver: MerkleTreeProver<F>,
605 B: BufMut,
606{
607 let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
608 .take(1 << log_coset_size);
609 advice.write_scalar_iter(values);
610
611 merkle_prover
612 .prove_opening(committed, optimal_layer_depth, coset_index, advice)
613 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
614
615 Ok(())
616}