use binius_field::{
    BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
    packed::{iter_packed_slice_with_offset, len_packed_slice},
};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{AdditiveNTT, fri::fold_interleaved};
use binius_utils::{SerializeBytes, bail, checked_arithmetics::log2_strict_usize};
use bytemuck::zeroed_vec;
use bytes::BufMut;
use itertools::izip;
use tracing::instrument;

use super::{
    TerminateCodeword,
    common::{FRIParams, vcs_optimal_layers_depths_iter},
    error::Error,
    logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
};
use crate::{
    fiat_shamir::{CanSampleBits, Challenger},
    merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
    protocols::fri::logging::FRIFoldData,
    reed_solomon::reed_solomon::ReedSolomonCode,
    transcript::{ProverTranscript, TranscriptWriter},
};

#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
    pub commitment: VCSCommitment,
    pub committed: VCSCommitted,
    pub codeword: Vec<P>,
}

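/// Returns a parallel iterator over chunks of scalars from a packed slice, for chunk sizes of at
/// least `P::WIDTH` scalars (i.e. chunks spanning one or more packed elements).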
pub fn to_par_scalar_big_chunks<P>(
    packed_slice: &[P],
    chunk_size: usize,
) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
where
    P: PackedField,
{
    packed_slice
        .par_chunks(chunk_size / P::WIDTH)
        .map(|chunk| PackedField::iter_slice(chunk))
}

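/// Returns a parallel iterator over chunks of scalars from a packed slice, for chunk sizes
/// smaller than `P::WIDTH`, where each chunk falls within a single packed element.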
pub fn to_par_scalar_small_chunks<P>(
    packed_slice: &[P],
    chunk_size: usize,
) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
where
    P: PackedField,
{
    (0..packed_slice.len() * P::WIDTH)
        .into_par_iter()
        .step_by(chunk_size)
        .map(move |start_index| {
            let packed_item = &packed_slice[start_index / P::WIDTH];
            packed_item
                .iter()
                .skip(start_index % P::WIDTH)
                .take(chunk_size)
        })
}

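/// Encodes an interleaved message with the given Reed-Solomon code and commits the codeword with
/// the Merkle tree prover.
///
/// The message must contain exactly `rs_code.dim() << params.log_batch_size()` scalar elements.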
#[instrument(skip_all, level = "debug")]
pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
    rs_code: &ReedSolomonCode<FA>,
    params: &FRIParams<F, FA>,
    ntt: &NTT,
    merkle_prover: &MerkleProver,
    message: &[P],
) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
where
    F: BinaryField,
    FA: BinaryField,
    P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
    PA: PackedField<Scalar = FA>,
    NTT: AdditiveNTT<FA> + Sync,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    let n_elems = rs_code.dim() << params.log_batch_size();
    if message.len() * P::WIDTH != n_elems {
        bail!(Error::InvalidArgs(
            "interleaved message length does not match code parameters".to_string()
        ));
    }

    commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
        buffer.copy_from_slice(message)
    })
}

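/// Encodes and commits an interleaved message that is written in place by a closure.
///
/// The `message_writer` closure receives a zero-initialized buffer of
/// `rs_code.dim() << params.log_batch_size()` scalars (in packed form) and must fill it with the
/// message before encoding.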
pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
    params: &FRIParams<F, FA>,
    ntt: &NTT,
    merkle_prover: &MerkleProver,
    message_writer: impl FnOnce(&mut [P]),
) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
where
    F: BinaryField,
    FA: BinaryField,
    P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
    PA: PackedField<Scalar = FA>,
    NTT: AdditiveNTT<FA> + Sync,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    let rs_code = params.rs_code();
    let log_batch_size = params.log_batch_size();
    let log_elems = rs_code.log_dim() + log_batch_size;
    if log_elems < P::LOG_WIDTH {
        todo!("can't handle this case well");
    }

    let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));

    let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
    tracing::debug_span!(
        "[task] Sort & Merge",
        phase = "commit",
        perfetto_category = "task.main",
        ?dimensions_data
    )
    .in_scope(|| {
        message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
    });

    let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
    tracing::debug_span!(
        "[task] RS Encode",
        phase = "commit",
        perfetto_category = "task.main",
        ?dimensions_data
    )
    .in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;

    // The log coset size of the Merkle tree leaves equals the first fold arity (falling back to
    // the full message size when there are no fold arities).
    let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);

    let log_len = params.log_len() - coset_log_len;
    let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
    let merkle_tree_span = tracing::debug_span!(
        "[task] Merkle Tree",
        phase = "commit",
        perfetto_category = "task.main",
        dimensions_data = ?dimension_data
    )
    .entered();
    let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
        let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);

        merkle_prover
            .commit_iterated(iterated_big_chunks, log_len)
            .map_err(|err| Error::VectorCommit(Box::new(err)))?
    } else {
        let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);

        merkle_prover
            .commit_iterated(iterated_small_chunks, log_len)
            .map_err(|err| Error::VectorCommit(Box::new(err)))?
    };
    drop(merkle_tree_span);

    Ok(CommitOutput {
        commitment: commitment.root,
        committed: vcs_committed,
        codeword: encoded,
    })
}

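/// The result of a single fold round: either no oracle was committed this round, or the Merkle
/// root of the newly committed folded codeword.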
pub enum FoldRoundOutput<VCSCommitment> {
    NoCommitment,
    Commitment(VCSCommitment),
}

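/// A stateful prover for the FRI fold phase.
///
/// Verifier challenges are supplied one per round via `execute_fold_round`; they are buffered and
/// applied in batches, with a new folded codeword committed at each commitment round.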
pub struct FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
where
    FA: BinaryField,
    F: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    params: &'a FRIParams<F, FA>,
    ntt: &'a NTT,
    merkle_prover: &'a MerkleProver,
    codeword: &'a [P],
    codeword_committed: &'a MerkleProver::Committed,
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    curr_round: usize,
    next_commit_round: Option<usize>,
    unprocessed_challenges: Vec<F>,
}

impl<'a, F, FA, P, NTT, MerkleProver, VCS> FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
where
    F: TowerField + ExtensionField<FA>,
    FA: BinaryField,
    P: PackedField<Scalar = F>,
    NTT: AdditiveNTT<FA> + Sync,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
{
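    /// Constructs a new FRI folder for a codeword that has already been committed.
    ///
    /// Returns an error if the committed codeword is shorter than the interleaved codeword length
    /// given by `params.log_len()`.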
    pub fn new(
        params: &'a FRIParams<F, FA>,
        ntt: &'a NTT,
        merkle_prover: &'a MerkleProver,
        committed_codeword: &'a [P],
        committed: &'a MerkleProver::Committed,
    ) -> Result<Self, Error> {
        if len_packed_slice(committed_codeword) < 1 << params.log_len() {
            bail!(Error::InvalidArgs(
                "Reed–Solomon code length must match interleaved codeword length".to_string(),
            ));
        }

        let next_commit_round = params.fold_arities().first().copied();
        Ok(Self {
            params,
            ntt,
            merkle_prover,
            codeword: committed_codeword,
            codeword_committed: committed,
            round_committed: Vec::with_capacity(params.n_oracles()),
            curr_round: 0,
            next_commit_round,
            unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
        })
    }

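    /// Total number of fold rounds.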
    pub const fn n_rounds(&self) -> usize {
        self.params.n_fold_rounds()
    }

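    /// Number of fold rounds executed so far.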
    pub const fn curr_round(&self) -> usize {
        self.curr_round
    }

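    /// Length, in scalars, of the current (most recently committed) codeword.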
    pub fn current_codeword_len(&self) -> usize {
        match self.round_committed.last() {
            Some((codeword, _)) => codeword.len(),
            None => len_packed_slice(self.codeword),
        }
    }

    fn is_commitment_round(&self) -> bool {
        self.next_commit_round
            .is_some_and(|round| round == self.curr_round)
    }

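    /// Executes the next fold round using the given verifier challenge.
    ///
    /// Challenges are buffered until the next commitment round. At a commitment round the
    /// buffered challenges are used to fold the current codeword, the folded codeword is
    /// committed, and its Merkle root is returned; otherwise `FoldRoundOutput::NoCommitment` is
    /// returned.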
    pub fn execute_fold_round(
        &mut self,
        challenge: F,
    ) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
        self.unprocessed_challenges.push(challenge);
        self.curr_round += 1;

        if !self.is_commitment_round() {
            return Ok(FoldRoundOutput::NoCommitment);
        }

        let dimensions_data = match self.round_committed.last() {
            Some((codeword, _)) => FRIFoldData::new::<F, FA>(
                log2_strict_usize(codeword.len()),
                0,
                self.unprocessed_challenges.len(),
            ),
            None => FRIFoldData::new::<F, FA>(
                self.params.rs_code().log_len(),
                self.params.log_batch_size(),
                self.unprocessed_challenges.len(),
            ),
        };

        let fri_fold_span = tracing::debug_span!(
            "[task] FRI Fold",
            phase = "piop_compiler",
            perfetto_category = "task.main",
            ?dimensions_data
        )
        .entered();
        // Fold with the batch of buffered challenges. The first commitment round folds the
        // original interleaved codeword; later rounds fold the previous round's codeword.
        let folded_codeword = match self.round_committed.last() {
            Some((prev_codeword, _)) => fold_interleaved(
                self.ntt,
                prev_codeword,
                &self.unprocessed_challenges,
                log2_strict_usize(prev_codeword.len()),
                0,
            ),
            None => fold_interleaved(
                self.ntt,
                self.codeword,
                &self.unprocessed_challenges,
                self.params.rs_code().log_len(),
                self.params.log_batch_size(),
            ),
        };
        drop(fri_fold_span);
        self.unprocessed_challenges.clear();

        // The new oracle's Merkle leaves are cosets whose log-size is the next fold arity, or the
        // number of final challenges when this is the last committed oracle.
        let coset_size = self
            .params
            .fold_arities()
            .get(self.round_committed.len() + 1)
            .map(|log| 1 << log)
            .unwrap_or_else(|| 1 << self.params.n_final_challenges());
        let dimension_data =
            MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
        let merkle_tree_span = tracing::debug_span!(
            "[task] Merkle Tree",
            phase = "piop_compiler",
            perfetto_category = "task.main",
            dimensions_data = ?dimension_data
        )
        .entered();
        let (commitment, committed) = self
            .merkle_prover
            .commit(&folded_codeword, coset_size)
            .map_err(|err| Error::VectorCommit(Box::new(err)))?;
        drop(merkle_tree_span);

        self.round_committed.push((folded_codeword, committed));

        self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
            let arity = self.params.fold_arities().get(self.round_committed.len())?;
            Some(next_commit_round + arity)
        });
        Ok(FoldRoundOutput::Commitment(commitment.root))
    }

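    /// Terminates the fold phase, returning the terminate codeword and a prover for the query
    /// phase.
    ///
    /// Returns an error if not all fold rounds have been executed yet.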
    #[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
    #[allow(clippy::type_complexity)]
    pub fn finalize(
        mut self,
    ) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
        if self.curr_round != self.n_rounds() {
            bail!(Error::EarlyProverFinish);
        }

        let terminate_codeword = self
            .round_committed
            .last()
            .map(|(codeword, _)| codeword.clone())
            .unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());

        self.unprocessed_challenges.clear();

        let Self {
            params,
            codeword,
            codeword_committed,
            round_committed,
            merkle_prover,
            ..
        } = self;

        let query_prover = FRIQueryProver {
            params,
            codeword,
            codeword_committed,
            round_committed,
            merkle_prover,
        };
        Ok((terminate_codeword, query_prover))
    }

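    /// Finalizes the folder and writes the remainder of the FRI proof to the transcript: the
    /// terminate codeword, the optimal Merkle layers, and one coset-opening proof per test query,
    /// with the query indices sampled from the transcript.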
    pub fn finish_proof<Challenger_>(
        self,
        transcript: &mut ProverTranscript<Challenger_>,
    ) -> Result<(), Error>
    where
        Challenger_: Challenger,
    {
        let (terminate_codeword, query_prover) = self.finalize()?;
        let mut advice = transcript.decommitment();
        advice.write_scalar_slice(&terminate_codeword);

        let layers = query_prover.vcs_optimal_layers()?;
        for layer in layers {
            advice.write_slice(&layer);
        }

        let params = query_prover.params;

        for _ in 0..params.n_test_queries() {
            let index = transcript.sample_bits(params.index_bits()) as usize;
            query_prover.prove_query(index, transcript.decommitment())?;
        }

        Ok(())
    }
}

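/// A prover for the FRI query phase, holding the codewords and Merkle trees committed during the
/// fold phase.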
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
where
    F: BinaryField,
    FA: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    params: &'a FRIParams<F, FA>,
    codeword: &'a [P],
    codeword_committed: &'a MerkleProver::Committed,
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    merkle_prover: &'a MerkleProver,
}

impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
where
    F: TowerField + ExtensionField<FA>,
    FA: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
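    /// Number of FRI oracles, as given by `FRIParams::n_oracles`.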
    pub fn n_oracles(&self) -> usize {
        self.params.n_oracles()
    }

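    /// Proves a single FRI query by opening one coset of each committed codeword.
    ///
    /// The coset values and their Merkle opening proofs are written to the advice transcript.
    /// `index` is the query index sampled by the verifier for the initial codeword; it is shifted
    /// down by each fold arity to index the subsequent oracles.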
    #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
    pub fn prove_query<B>(
        &self,
        mut index: usize,
        mut advice: TranscriptWriter<B>,
    ) -> Result<(), Error>
    where
        B: BufMut,
    {
        let mut arities_and_optimal_layers_depths = self
            .params
            .fold_arities()
            .iter()
            .copied()
            .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));

        let Some((first_fold_arity, first_optimal_layer_depth)) =
            arities_and_optimal_layers_depths.next()
        else {
            // No fold arities means no oracles were committed, so there is nothing to open.
            return Ok(());
        };

        prove_coset_opening(
            self.merkle_prover,
            self.codeword,
            self.codeword_committed,
            index,
            first_fold_arity,
            first_optimal_layer_depth,
            &mut advice,
        )?;

        for ((codeword, committed), (arity, optimal_layer_depth)) in
            izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
        {
            index >>= arity;
            prove_coset_opening(
                self.merkle_prover,
                codeword,
                committed,
                index,
                arity,
                optimal_layer_depth,
                &mut advice,
            )?;
        }

        Ok(())
    }

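    /// Returns the Merkle tree layer at the optimal depth for each commitment, starting with the
    /// initial codeword commitment followed by each round oracle.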
    pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
        let committed_iter = std::iter::once(self.codeword_committed)
            .chain(self.round_committed.iter().map(|(_, committed)| committed));

        committed_iter
            .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
            .map(|(committed, optimal_layer_depth)| {
                self.merkle_prover
                    .layer(committed, optimal_layer_depth)
                    .map(|layer| layer.to_vec())
                    .map_err(|err| Error::VectorCommit(Box::new(err)))
            })
            .collect::<Result<Vec<_>, _>>()
    }
}

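/// Writes the scalar values of a single coset of `codeword` to the advice transcript, followed by
/// a Merkle opening proof for that coset.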
fn prove_coset_opening<F, P, MTProver, B>(
    merkle_prover: &MTProver,
    codeword: &[P],
    committed: &MTProver::Committed,
    coset_index: usize,
    log_coset_size: usize,
    optimal_layer_depth: usize,
    advice: &mut TranscriptWriter<B>,
) -> Result<(), Error>
where
    F: TowerField,
    P: PackedField<Scalar = F>,
    MTProver: MerkleTreeProver<F>,
    B: BufMut,
{
    let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
        .take(1 << log_coset_size);
    advice.write_scalar_iter(values);

    merkle_prover
        .prove_opening(committed, optimal_layer_depth, coset_index, advice)
        .map_err(|err| Error::VectorCommit(Box::new(err)))?;

    Ok(())
}