1use binius_compute::{
3 ComputeLayerExecutor,
4 alloc::ComputeAllocator,
5 layer::ComputeLayer,
6 memory::{ComputeMemory, SizedSlice},
7};
8use binius_field::{
9 BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
10 packed::{iter_packed_slice_with_offset, len_packed_slice},
11 unpack_if_possible,
12};
13use binius_maybe_rayon::prelude::*;
14use binius_ntt::AdditiveNTT;
15use binius_utils::{SerializeBytes, bail, checked_arithmetics::log2_strict_usize};
16use bytemuck::zeroed_vec;
17use bytes::BufMut;
18use itertools::izip;
19use tracing::instrument;
20
21use super::{
22 TerminateCodeword,
23 common::{FRIParams, vcs_optimal_layers_depths_iter},
24 error::Error,
25 logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
26};
27use crate::{
28 fiat_shamir::{CanSampleBits, Challenger},
29 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
30 protocols::fri::logging::FRIFoldData,
31 reed_solomon::reed_solomon::ReedSolomonCode,
32 transcript::{ProverTranscript, TranscriptWriter},
33};
34
/// Bundle returned by the commit routines in this module.
///
/// [`commit_interleaved_with`] fills it with the Merkle root, the prover-side committed state
/// returned by the Merkle tree prover, and the encoded codeword.
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
    /// The commitment value; `commit_interleaved_with` stores the Merkle tree root here.
    pub commitment: VCSCommitment,
    /// Prover-side committed state returned alongside the commitment by the Merkle tree prover.
    pub committed: VCSCommitted,
    /// The encoded codeword, stored as packed field elements.
    pub codeword: Vec<P>,
}
41
42pub fn to_par_scalar_big_chunks<P>(
45 packed_slice: &[P],
46 chunk_size: usize,
47) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
48where
49 P: PackedField,
50{
51 packed_slice
52 .par_chunks(chunk_size / P::WIDTH)
53 .map(|chunk| PackedField::iter_slice(chunk))
54}
55
56pub fn to_par_scalar_small_chunks<P>(
57 packed_slice: &[P],
58 chunk_size: usize,
59) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
60where
61 P: PackedField,
62{
63 (0..packed_slice.len() * P::WIDTH)
64 .into_par_iter()
65 .step_by(chunk_size)
66 .map(move |start_index| {
67 let packed_item = &packed_slice[start_index / P::WIDTH];
68 packed_item
69 .iter()
70 .skip(start_index % P::WIDTH)
71 .take(chunk_size)
72 })
73}
74
75#[instrument(skip_all, level = "debug")]
84pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
85 rs_code: &ReedSolomonCode<FA>,
86 params: &FRIParams<F, FA>,
87 ntt: &NTT,
88 merkle_prover: &MerkleProver,
89 message: &[P],
90) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
91where
92 F: BinaryField,
93 FA: BinaryField,
94 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
95 PA: PackedField<Scalar = FA>,
96 NTT: AdditiveNTT<FA> + Sync,
97 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
98 VCS: MerkleTreeScheme<F>,
99{
100 let n_elems = rs_code.dim() << params.log_batch_size();
101 if message.len() * P::WIDTH != n_elems {
102 bail!(Error::InvalidArgs(
103 "interleaved message length does not match code parameters".to_string()
104 ));
105 }
106
107 commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
108 buffer.copy_from_slice(message)
109 })
110}
111
112pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
121 params: &FRIParams<F, FA>,
122 ntt: &NTT,
123 merkle_prover: &MerkleProver,
124 message_writer: impl FnOnce(&mut [P]),
125) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
126where
127 F: BinaryField,
128 FA: BinaryField,
129 P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
130 PA: PackedField<Scalar = FA>,
131 NTT: AdditiveNTT<FA> + Sync,
132 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
133 VCS: MerkleTreeScheme<F>,
134{
135 let rs_code = params.rs_code();
136 let log_batch_size = params.log_batch_size();
137 let log_elems = rs_code.log_dim() + log_batch_size;
138 if log_elems < P::LOG_WIDTH {
139 todo!("can't handle this case well");
140 }
141
142 let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));
143
144 let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
145 tracing::debug_span!(
146 "[task] Sort & Merge",
147 phase = "commit",
148 perfetto_category = "task.main",
149 ?dimensions_data
150 )
151 .in_scope(|| {
152 message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
153 });
154
155 let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
156 tracing::debug_span!(
157 "[task] RS Encode",
158 phase = "commit",
159 perfetto_category = "task.main",
160 ?dimensions_data
161 )
162 .in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;
163
164 let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);
167
168 let log_len = params.log_len() - coset_log_len;
169 let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
170 let merkle_tree_span = tracing::debug_span!(
171 "[task] Merkle Tree",
172 phase = "commit",
173 perfetto_category = "task.main",
174 dimensions_data = ?dimension_data
175 )
176 .entered();
177 let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
178 let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
179
180 merkle_prover
181 .commit_iterated(iterated_big_chunks, log_len)
182 .map_err(|err| Error::VectorCommit(Box::new(err)))?
183 } else {
184 let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
185
186 merkle_prover
187 .commit_iterated(iterated_small_chunks, log_len)
188 .map_err(|err| Error::VectorCommit(Box::new(err)))?
189 };
190 drop(merkle_tree_span);
191
192 Ok(CommitOutput {
193 commitment: commitment.root,
194 committed: vcs_committed,
195 codeword: encoded,
196 })
197}
198
/// Outcome of a single FRI fold round.
///
/// A round either only records its challenge (`NoCommitment`) or folds the codeword and commits
/// to the result, in which case the new commitment is carried in `Commitment`.
///
/// The standard traits are derived so that the value can be logged, compared in tests and
/// cloned; each derived impl only applies when `VCSCommitment` implements the same trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldRoundOutput<VCSCommitment> {
    /// The round did not produce a commitment.
    NoCommitment,
    /// The round produced the contained commitment.
    Commitment(VCSCommitment),
}
203
/// State kept for one committed fold round: the folded codeword (on host and on device) together
/// with the Merkle prover's committed data for it.
pub struct SingleRoundCommitted<'b, F, Hal, MerkleProver>
where
    F: BinaryField,
    Hal: ComputeLayer<F>,
    MerkleProver: MerkleTreeProver<F>,
{
    /// Host-side copy of the folded codeword.
    pub host_codeword: Vec<F>,
    /// Device-side buffer holding the folded codeword; it is the input to the next fold.
    pub device_codeword: <Hal::DevMem as ComputeMemory<F>>::FSliceMut<'b>,
    /// Committed state returned by the Merkle tree prover for this round's codeword.
    pub committed: MerkleProver::Committed,
}
218
/// Prover state for the FRI fold phase.
///
/// Challenges are fed in one per round through `execute_fold_round`. They accumulate until a
/// commitment round is reached, at which point the current codeword is folded with all of them
/// and the result is committed.
pub struct FRIFolder<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
where
    FA: BinaryField,
    F: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
    Hal: ComputeLayer<F>,
{
    /// FRI protocol parameters.
    params: &'a FRIParams<F, FA>,
    /// Additive NTT handed to the compute layer's fold operation.
    ntt: &'a NTT,
    /// Merkle tree prover used to commit to folded codewords.
    merkle_prover: &'a MerkleProver,
    /// The initial codeword, in packed form.
    codeword: &'a [P],
    /// Committed state for the initial codeword.
    codeword_committed: &'a MerkleProver::Committed,
    /// One entry per commitment round executed so far, in order.
    round_committed: Vec<SingleRoundCommitted<'b, F, Hal, MerkleProver>>,
    /// Number of fold rounds executed so far.
    curr_round: usize,
    /// Value of `curr_round` at which the next commitment happens, if any remains.
    next_commit_round: Option<usize>,
    /// Challenges received since the last fold.
    unprocessed_challenges: Vec<F>,
    /// Compute layer used for device copies and folding.
    cl: &'b Hal,
}
239
240impl<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
241 FRIFolder<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
242where
243 F: TowerField + ExtensionField<FA>,
244 FA: BinaryField,
245 P: PackedField<Scalar = F>,
246 NTT: AdditiveNTT<FA> + Sync,
247 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
248 VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
249 Hal: ComputeLayer<F>,
250{
251 pub fn new(
253 hal: &'b Hal,
254 params: &'a FRIParams<F, FA>,
255 ntt: &'a NTT,
256 merkle_prover: &'a MerkleProver,
257 codeword: &'a [P],
258 committed: &'a MerkleProver::Committed,
259 ) -> Result<Self, Error> {
260 if len_packed_slice(codeword) < 1 << params.log_len() {
261 bail!(Error::InvalidArgs(
262 "Reed–Solomon code length must match interleaved codeword length".to_string(),
263 ));
264 }
265
266 let next_commit_round = params.fold_arities().first().copied();
267 Ok(Self {
268 params,
269 ntt,
270 merkle_prover,
271 codeword,
272 codeword_committed: committed,
273 round_committed: Vec::with_capacity(params.n_oracles()),
274 curr_round: 0,
275 next_commit_round,
276 unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
277 cl: hal,
278 })
279 }
280
/// Number of fold rounds, as reported by `FRIParams::n_fold_rounds`.
pub const fn n_rounds(&self) -> usize {
    self.params.n_fold_rounds()
}
285
/// Number of fold rounds executed so far (incremented by each `execute_fold_round` call).
pub const fn curr_round(&self) -> usize {
    self.curr_round
}
290
291 pub fn current_codeword_len(&self) -> usize {
293 match self.round_committed.last() {
294 Some(round) => round.host_codeword.len(),
295 None => len_packed_slice(self.codeword),
296 }
297 }
298
299 fn is_commitment_round(&self) -> bool {
300 self.next_commit_round
301 .is_some_and(|round| round == self.curr_round)
302 }
303
/// Executes the next fold round with the given challenge.
///
/// The challenge is appended to the pending challenges and the round counter is advanced. If
/// this is not a commitment round, nothing else happens and `FoldRoundOutput::NoCommitment` is
/// returned. Otherwise the current codeword is folded on the compute layer with all pending
/// challenges, the result is copied back to the host, committed with the Merkle tree prover and
/// recorded, and the new commitment root is returned.
///
/// `allocator` provides the device buffers; the folded buffer is kept in `round_committed`, so
/// it must live as long as `'b`.
///
/// # Errors
///
/// Propagates failures from device allocation, host/device copies and compute-layer execution.
/// Merkle prover failures are wrapped in [`Error::VectorCommit`].
///
/// # Panics
///
/// Reaches `unimplemented!` if `unpack_if_possible` invokes the packed fallback closure for the
/// original codeword.
pub fn execute_fold_round(
    &mut self,
    allocator: &'b impl ComputeAllocator<F, Hal::DevMem>,
    challenge: F,
) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
    self.unprocessed_challenges.push(challenge);
    self.curr_round += 1;

    if !self.is_commitment_round() {
        return Ok(FoldRoundOutput::NoCommitment);
    }

    // Dimension data for tracing. Folding a previously folded codeword uses batch size 0;
    // folding the original codeword uses the code's parameters.
    let dimensions_data = match self.round_committed.last() {
        Some(round) => FRIFoldData::new::<F, FA>(
            log2_strict_usize(round.host_codeword.len()),
            0,
            self.unprocessed_challenges.len(),
        ),
        None => FRIFoldData::new::<F, FA>(
            self.params.rs_code().log_len(),
            self.params.log_batch_size(),
            self.unprocessed_challenges.len(),
        ),
    };

    let fri_fold_span = tracing::debug_span!(
        "[task] FRI Fold",
        phase = "piop_compiler",
        perfetto_category = "task.main",
        ?dimensions_data
    )
    .entered();
    // Fold the latest codeword with every challenge accumulated since the last fold.
    let folded_codeword = match self.round_committed.last() {
        Some(prev_round) => {
            // Each pending challenge halves the length of the previous round's codeword.
            let mut folded_codeword = allocator.alloc(
                prev_round.device_codeword.len() / (1 << self.unprocessed_challenges.len()),
            )?;

            self.cl.execute(|exec| {
                exec.fri_fold(
                    self.ntt,
                    log2_strict_usize(prev_round.device_codeword.len()),
                    0,
                    &self.unprocessed_challenges,
                    Hal::DevMem::as_const(&prev_round.device_codeword),
                    &mut folded_codeword,
                )?;

                Ok(vec![])
            })?;

            folded_codeword
        }
        None => {
            // First fold: upload the original codeword to the device before folding it.
            let codeword_len = len_packed_slice(self.codeword);
            let mut original_codeword = allocator.alloc(codeword_len)?;
            unpack_if_possible(
                self.codeword,
                |scalars| self.cl.copy_h2d(scalars, &mut original_codeword),
                |_packed| unimplemented!("non-dense packed fields not supported"),
            )?;
            // Output log-length is the code's log-length minus the number of pending challenges
            // in excess of `log_batch_size`.
            let mut folded_codeword = allocator.alloc(
                1 << (self.params.rs_code().log_len()
                    - (self.unprocessed_challenges.len() - self.params.log_batch_size())),
            )?;
            self.cl.execute(|exec| {
                exec.fri_fold(
                    self.ntt,
                    self.params.rs_code().log_len(),
                    self.params.log_batch_size(),
                    &self.unprocessed_challenges,
                    Hal::DevMem::as_const(&original_codeword),
                    &mut folded_codeword,
                )?;

                Ok(vec![])
            })?;

            folded_codeword
        }
    };
    drop(fri_fold_span);
    self.unprocessed_challenges.clear();

    // Bring the folded codeword back to the host so that it can be committed.
    let mut folded_codeword_host = zeroed_vec(folded_codeword.len());
    self.cl
        .copy_d2h(Hal::DevMem::as_const(&folded_codeword), &mut folded_codeword_host)?;

    // Coset size for the commitment: two to the power of the fold arity at index
    // `round_committed.len() + 1`, or of `n_final_challenges()` when no such arity exists.
    let coset_size = self
        .params
        .fold_arities()
        .get(self.round_committed.len() + 1)
        .map(|log| 1 << log)
        .unwrap_or_else(|| 1 << self.params.n_final_challenges());
    let dimension_data =
        MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
    let merkle_tree_span = tracing::debug_span!(
        "[task] Merkle Tree",
        phase = "piop_compiler",
        perfetto_category = "task.main",
        dimensions_data = ?dimension_data
    )
    .entered();
    let (commitment, committed) = self
        .merkle_prover
        .commit(&folded_codeword_host, coset_size)
        .map_err(|err| Error::VectorCommit(Box::new(err)))?;
    drop(merkle_tree_span);

    self.round_committed.push(SingleRoundCommitted {
        host_codeword: folded_codeword_host,
        device_codeword: folded_codeword,
        committed,
    });

    // Schedule the next commitment round by adding the next fold arity; becomes `None` once
    // the fold arities are exhausted.
    self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
        let arity = self.params.fold_arities().get(self.round_committed.len())?;
        Some(next_commit_round + arity)
    });
    Ok(FoldRoundOutput::Commitment(commitment.root))
}
434
435 #[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
443 #[allow(clippy::type_complexity)]
444 pub fn finalize(
445 mut self,
446 ) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
447 if self.curr_round != self.n_rounds() {
448 bail!(Error::EarlyProverFinish);
449 }
450
451 let terminate_codeword = self
452 .round_committed
453 .last()
454 .map(|round| round.host_codeword.clone())
455 .unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());
456
457 self.unprocessed_challenges.clear();
458
459 let Self {
460 params,
461 codeword,
462 codeword_committed,
463 round_committed,
464 merkle_prover,
465 ..
466 } = self;
467
468 let round_committed = round_committed
469 .into_iter()
470 .map(|round| (round.host_codeword, round.committed))
471 .collect();
472
473 let query_prover = FRIQueryProver {
474 params,
475 codeword,
476 codeword_committed,
477 round_committed,
478 merkle_prover,
479 };
480 Ok((terminate_codeword, query_prover))
481 }
482
483 pub fn finish_proof<Challenger_>(
484 self,
485 transcript: &mut ProverTranscript<Challenger_>,
486 ) -> Result<(), Error>
487 where
488 Challenger_: Challenger,
489 {
490 let (terminate_codeword, query_prover) = self.finalize()?;
491 let mut advice = transcript.decommitment();
492 advice.write_scalar_slice(&terminate_codeword);
493
494 let layers = query_prover.vcs_optimal_layers()?;
495 for layer in layers {
496 advice.write_slice(&layer);
497 }
498
499 let params = query_prover.params;
500
501 for _ in 0..params.n_test_queries() {
502 let index = transcript.sample_bits(params.index_bits()) as usize;
503 query_prover.prove_query(index, transcript.decommitment())?;
504 }
505
506 Ok(())
507 }
508}
509
/// Prover for the FRI query phase, produced by `FRIFolder::finalize`.
///
/// It holds the original codeword and every folded codeword with their committed data, so that
/// coset openings can be produced for sampled query indices.
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
where
    F: BinaryField,
    FA: BinaryField,
    P: PackedField<Scalar = F>,
    MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
    VCS: MerkleTreeScheme<F>,
{
    /// FRI protocol parameters.
    params: &'a FRIParams<F, FA>,
    /// The original codeword, in packed form.
    codeword: &'a [P],
    /// Committed state for the original codeword.
    codeword_committed: &'a MerkleProver::Committed,
    /// For each committed fold round, the folded codeword and its committed state.
    round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
    /// Merkle tree prover used to produce layers and opening proofs.
    merkle_prover: &'a MerkleProver,
}
525
526impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
527where
528 F: TowerField + ExtensionField<FA>,
529 FA: BinaryField,
530 P: PackedField<Scalar = F>,
531 MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
532 VCS: MerkleTreeScheme<F>,
533{
/// Number of oracles, as reported by `FRIParams::n_oracles`.
pub fn n_oracles(&self) -> usize {
    self.params.n_oracles()
}
538
539 #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
545 pub fn prove_query<B>(
546 &self,
547 mut index: usize,
548 mut advice: TranscriptWriter<B>,
549 ) -> Result<(), Error>
550 where
551 B: BufMut,
552 {
553 let mut arities_and_optimal_layers_depths = self
554 .params
555 .fold_arities()
556 .iter()
557 .copied()
558 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));
559
560 let Some((first_fold_arity, first_optimal_layer_depth)) =
561 arities_and_optimal_layers_depths.next()
562 else {
563 return Ok(());
567 };
568
569 prove_coset_opening(
570 self.merkle_prover,
571 self.codeword,
572 self.codeword_committed,
573 index,
574 first_fold_arity,
575 first_optimal_layer_depth,
576 &mut advice,
577 )?;
578
579 for ((codeword, committed), (arity, optimal_layer_depth)) in
580 izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
581 {
582 index >>= arity;
583 prove_coset_opening(
584 self.merkle_prover,
585 codeword,
586 committed,
587 index,
588 arity,
589 optimal_layer_depth,
590 &mut advice,
591 )?;
592 }
593
594 Ok(())
595 }
596
597 pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
598 let committed_iter = std::iter::once(self.codeword_committed)
599 .chain(self.round_committed.iter().map(|(_, committed)| committed));
600
601 committed_iter
602 .zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
603 .map(|(committed, optimal_layer_depth)| {
604 self.merkle_prover
605 .layer(committed, optimal_layer_depth)
606 .map(|layer| layer.to_vec())
607 .map_err(|err| Error::VectorCommit(Box::new(err)))
608 })
609 .collect::<Result<Vec<_>, _>>()
610 }
611}
612
613fn prove_coset_opening<F, P, MTProver, B>(
614 merkle_prover: &MTProver,
615 codeword: &[P],
616 committed: &MTProver::Committed,
617 coset_index: usize,
618 log_coset_size: usize,
619 optimal_layer_depth: usize,
620 advice: &mut TranscriptWriter<B>,
621) -> Result<(), Error>
622where
623 F: TowerField,
624 P: PackedField<Scalar = F>,
625 MTProver: MerkleTreeProver<F>,
626 B: BufMut,
627{
628 let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
629 .take(1 << log_coset_size);
630 advice.write_scalar_iter(values);
631
632 merkle_prover
633 .prove_opening(committed, optimal_layer_depth, coset_index, advice)
634 .map_err(|err| Error::VectorCommit(Box::new(err)))?;
635
636 Ok(())
637}