binius_core/protocols/fri/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
5	packed::{iter_packed_slice_with_offset, len_packed_slice},
6};
7use binius_maybe_rayon::prelude::*;
8use binius_ntt::{AdditiveNTT, fri::fold_interleaved};
9use binius_utils::{SerializeBytes, bail, checked_arithmetics::log2_strict_usize};
10use bytemuck::zeroed_vec;
11use bytes::BufMut;
12use itertools::izip;
13use tracing::instrument;
14
15use super::{
16	TerminateCodeword,
17	common::{FRIParams, vcs_optimal_layers_depths_iter},
18	error::Error,
19	logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
20};
21use crate::{
22	fiat_shamir::{CanSampleBits, Challenger},
23	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
24	protocols::fri::logging::FRIFoldData,
25	reed_solomon::reed_solomon::ReedSolomonCode,
26	transcript::{ProverTranscript, TranscriptWriter},
27};
28
/// Everything produced by encoding and committing an interleaved message.
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
	/// The vector commitment to the codeword (the root returned by the Merkle tree prover).
	pub commitment: VCSCommitment,
	/// The prover-side committed state returned by the Merkle tree prover alongside the root.
	pub committed: VCSCommitted,
	/// The encoded interleaved codeword, in packed form.
	pub codeword: Vec<P>,
}
35
36/// Creates a parallel iterator over scalars of subfield elementsAssumes chunk_size to be a power of
37/// two
38pub fn to_par_scalar_big_chunks<P>(
39	packed_slice: &[P],
40	chunk_size: usize,
41) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
42where
43	P: PackedField,
44{
45	packed_slice
46		.par_chunks(chunk_size / P::WIDTH)
47		.map(|chunk| PackedField::iter_slice(chunk))
48}
49
50pub fn to_par_scalar_small_chunks<P>(
51	packed_slice: &[P],
52	chunk_size: usize,
53) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
54where
55	P: PackedField,
56{
57	(0..packed_slice.len() * P::WIDTH)
58		.into_par_iter()
59		.step_by(chunk_size)
60		.map(move |start_index| {
61			let packed_item = &packed_slice[start_index / P::WIDTH];
62			packed_item
63				.iter()
64				.skip(start_index % P::WIDTH)
65				.take(chunk_size)
66		})
67}
68
69/// Encodes and commits the input message.
70///
71/// ## Arguments
72///
73/// * `rs_code` - the Reed-Solomon code to use for encoding
74/// * `params` - common FRI protocol parameters.
75/// * `merkle_prover` - the merke tree prover to use for committing
76/// * `message` - the interleaved message to encode and commit
77#[instrument(skip_all, level = "debug")]
78pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
79	rs_code: &ReedSolomonCode<FA>,
80	params: &FRIParams<F, FA>,
81	ntt: &NTT,
82	merkle_prover: &MerkleProver,
83	message: &[P],
84) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
85where
86	F: BinaryField,
87	FA: BinaryField,
88	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
89	PA: PackedField<Scalar = FA>,
90	NTT: AdditiveNTT<FA> + Sync,
91	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
92	VCS: MerkleTreeScheme<F>,
93{
94	let n_elems = rs_code.dim() << params.log_batch_size();
95	if message.len() * P::WIDTH != n_elems {
96		bail!(Error::InvalidArgs(
97			"interleaved message length does not match code parameters".to_string()
98		));
99	}
100
101	commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
102		buffer.copy_from_slice(message)
103	})
104}
105
106/// Encodes and commits the input message with a closure for writing the message.
107///
108/// ## Arguments
109///
110/// * `rs_code` - the Reed-Solomon code to use for encoding
111/// * `params` - common FRI protocol parameters.
112/// * `merkle_prover` - the Merkle tree prover to use for committing
113/// * `message_writer` - a closure that writes the interleaved message to encode and commit
114pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
115	params: &FRIParams<F, FA>,
116	ntt: &NTT,
117	merkle_prover: &MerkleProver,
118	message_writer: impl FnOnce(&mut [P]),
119) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
120where
121	F: BinaryField,
122	FA: BinaryField,
123	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
124	PA: PackedField<Scalar = FA>,
125	NTT: AdditiveNTT<FA> + Sync,
126	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
127	VCS: MerkleTreeScheme<F>,
128{
129	let rs_code = params.rs_code();
130	let log_batch_size = params.log_batch_size();
131	let log_elems = rs_code.log_dim() + log_batch_size;
132	if log_elems < P::LOG_WIDTH {
133		todo!("can't handle this case well");
134	}
135
136	let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));
137
138	let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
139	tracing::debug_span!(
140		"[task] Sort & Merge",
141		phase = "commit",
142		perfetto_category = "task.main",
143		?dimensions_data
144	)
145	.in_scope(|| {
146		message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
147	});
148
149	let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
150	tracing::debug_span!(
151		"[task] RS Encode",
152		phase = "commit",
153		perfetto_category = "task.main",
154		?dimensions_data
155	)
156	.in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;
157
158	// Take the first arity as coset_log_len, or use the value such that the number of leaves equals
159	// 1 << log_inv_rate if arities is empty
160	let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);
161
162	let log_len = params.log_len() - coset_log_len;
163	let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
164	let merkle_tree_span = tracing::debug_span!(
165		"[task] Merkle Tree",
166		phase = "commit",
167		perfetto_category = "task.main",
168		dimensions_data = ?dimension_data
169	)
170	.entered();
171	let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
172		let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
173
174		merkle_prover
175			.commit_iterated(iterated_big_chunks, log_len)
176			.map_err(|err| Error::VectorCommit(Box::new(err)))?
177	} else {
178		let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
179
180		merkle_prover
181			.commit_iterated(iterated_small_chunks, log_len)
182			.map_err(|err| Error::VectorCommit(Box::new(err)))?
183	};
184	drop(merkle_tree_span);
185
186	Ok(CommitOutput {
187		commitment: commitment.root,
188		committed: vcs_committed,
189		codeword: encoded,
190	})
191}
192
/// Outcome of a single call to `FRIFolder::execute_fold_round`.
pub enum FoldRoundOutput<VCSCommitment> {
	/// The challenge was only recorded; nothing was folded or committed in this round.
	NoCommitment,
	/// A folded codeword was committed in this round; carries the commitment root.
	Commitment(VCSCommitment),
}
197
198/// A stateful prover for the FRI fold phase.
199pub struct FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
200where
201	FA: BinaryField,
202	F: BinaryField,
203	P: PackedField<Scalar = F>,
204	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
205	VCS: MerkleTreeScheme<F>,
206{
207	params: &'a FRIParams<F, FA>,
208	ntt: &'a NTT,
209	merkle_prover: &'a MerkleProver,
210	codeword: &'a [P],
211	codeword_committed: &'a MerkleProver::Committed,
212	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
213	curr_round: usize,
214	next_commit_round: Option<usize>,
215	unprocessed_challenges: Vec<F>,
216}
217
218impl<'a, F, FA, P, NTT, MerkleProver, VCS> FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
219where
220	F: TowerField + ExtensionField<FA>,
221	FA: BinaryField,
222	P: PackedField<Scalar = F>,
223	NTT: AdditiveNTT<FA> + Sync,
224	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
225	VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
226{
227	/// Constructs a new folder.
228	pub fn new(
229		params: &'a FRIParams<F, FA>,
230		ntt: &'a NTT,
231		merkle_prover: &'a MerkleProver,
232		committed_codeword: &'a [P],
233		committed: &'a MerkleProver::Committed,
234	) -> Result<Self, Error> {
235		if len_packed_slice(committed_codeword) < 1 << params.log_len() {
236			bail!(Error::InvalidArgs(
237				"Reed–Solomon code length must match interleaved codeword length".to_string(),
238			));
239		}
240
241		let next_commit_round = params.fold_arities().first().copied();
242		Ok(Self {
243			params,
244			ntt,
245			merkle_prover,
246			codeword: committed_codeword,
247			codeword_committed: committed,
248			round_committed: Vec::with_capacity(params.n_oracles()),
249			curr_round: 0,
250			next_commit_round,
251			unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
252		})
253	}
254
255	/// Number of fold rounds, including the final fold.
256	pub const fn n_rounds(&self) -> usize {
257		self.params.n_fold_rounds()
258	}
259
260	/// Number of times `execute_fold_round` has been called.
261	pub const fn curr_round(&self) -> usize {
262		self.curr_round
263	}
264
265	/// The length of the current codeword.
266	pub fn current_codeword_len(&self) -> usize {
267		match self.round_committed.last() {
268			Some((codeword, _)) => codeword.len(),
269			None => len_packed_slice(self.codeword),
270		}
271	}
272
273	fn is_commitment_round(&self) -> bool {
274		self.next_commit_round
275			.is_some_and(|round| round == self.curr_round)
276	}
277
278	/// Executes the next fold round and returns the folded codeword commitment.
279	///
280	/// As a memory efficient optimization, this method may not actually do the folding, but instead
281	/// accumulate the folding challenge for processing at a later time. This saves us from storing
282	/// intermediate folded codewords.
283	pub fn execute_fold_round(
284		&mut self,
285		challenge: F,
286	) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
287		self.unprocessed_challenges.push(challenge);
288		self.curr_round += 1;
289
290		if !self.is_commitment_round() {
291			return Ok(FoldRoundOutput::NoCommitment);
292		}
293
294		let dimensions_data = match self.round_committed.last() {
295			Some((codeword, _)) => FRIFoldData::new::<F, FA>(
296				log2_strict_usize(codeword.len()),
297				0,
298				self.unprocessed_challenges.len(),
299			),
300			None => FRIFoldData::new::<F, FA>(
301				self.params.rs_code().log_len(),
302				self.params.log_batch_size(),
303				self.unprocessed_challenges.len(),
304			),
305		};
306
307		let fri_fold_span = tracing::debug_span!(
308			"[task] FRI Fold",
309			phase = "piop_compiler",
310			perfetto_category = "task.main",
311			?dimensions_data
312		)
313		.entered();
314		// Fold the last codeword with the accumulated folding challenges.
315		let folded_codeword = match self.round_committed.last() {
316			Some((prev_codeword, _)) => {
317				// Fold a full codeword committed in the previous FRI round into a codeword with
318				// reduced dimension and rate.
319				fold_interleaved(
320					self.ntt,
321					prev_codeword,
322					&self.unprocessed_challenges,
323					log2_strict_usize(prev_codeword.len()),
324					0,
325				)
326			}
327			None => {
328				// Fold the interleaved codeword that was originally committed into a single
329				// codeword with the same or reduced block length, depending on the sequence of
330				// fold rounds.
331				fold_interleaved(
332					self.ntt,
333					self.codeword,
334					&self.unprocessed_challenges,
335					self.params.rs_code().log_len(),
336					self.params.log_batch_size(),
337				)
338			}
339		};
340		drop(fri_fold_span);
341		self.unprocessed_challenges.clear();
342
343		// take the first arity as coset_log_len, or use inv_rate if arities are empty
344		let coset_size = self
345			.params
346			.fold_arities()
347			.get(self.round_committed.len() + 1)
348			.map(|log| 1 << log)
349			.unwrap_or_else(|| 1 << self.params.n_final_challenges());
350		let dimension_data =
351			MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
352		let merkle_tree_span = tracing::debug_span!(
353			"[task] Merkle Tree",
354			phase = "piop_compiler",
355			perfetto_category = "task.main",
356			dimensions_data = ?dimension_data
357		)
358		.entered();
359		let (commitment, committed) = self
360			.merkle_prover
361			.commit(&folded_codeword, coset_size)
362			.map_err(|err| Error::VectorCommit(Box::new(err)))?;
363		drop(merkle_tree_span);
364
365		self.round_committed.push((folded_codeword, committed));
366
367		self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
368			let arity = self.params.fold_arities().get(self.round_committed.len())?;
369			Some(next_commit_round + arity)
370		});
371		Ok(FoldRoundOutput::Commitment(commitment.root))
372	}
373
374	/// Finalizes the FRI folding process.
375	///
376	/// This step will process any unprocessed folding challenges to produce the
377	/// final folded codeword. Then it will decode this final folded codeword
378	/// to get the final message. The result is the final message and a query prover instance.
379	///
380	/// This returns the final message and a query prover instance.
381	#[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
382	#[allow(clippy::type_complexity)]
383	pub fn finalize(
384		mut self,
385	) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
386		if self.curr_round != self.n_rounds() {
387			bail!(Error::EarlyProverFinish);
388		}
389
390		let terminate_codeword = self
391			.round_committed
392			.last()
393			.map(|(codeword, _)| codeword.clone())
394			.unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());
395
396		self.unprocessed_challenges.clear();
397
398		let Self {
399			params,
400			codeword,
401			codeword_committed,
402			round_committed,
403			merkle_prover,
404			..
405		} = self;
406
407		let query_prover = FRIQueryProver {
408			params,
409			codeword,
410			codeword_committed,
411			round_committed,
412			merkle_prover,
413		};
414		Ok((terminate_codeword, query_prover))
415	}
416
417	pub fn finish_proof<Challenger_>(
418		self,
419		transcript: &mut ProverTranscript<Challenger_>,
420	) -> Result<(), Error>
421	where
422		Challenger_: Challenger,
423	{
424		let (terminate_codeword, query_prover) = self.finalize()?;
425		let mut advice = transcript.decommitment();
426		advice.write_scalar_slice(&terminate_codeword);
427
428		let layers = query_prover.vcs_optimal_layers()?;
429		for layer in layers {
430			advice.write_slice(&layer);
431		}
432
433		let params = query_prover.params;
434
435		for _ in 0..params.n_test_queries() {
436			let index = transcript.sample_bits(params.index_bits()) as usize;
437			query_prover.prove_query(index, transcript.decommitment())?;
438		}
439
440		Ok(())
441	}
442}
443
444/// A prover for the FRI query phase.
445pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
446where
447	F: BinaryField,
448	FA: BinaryField,
449	P: PackedField<Scalar = F>,
450	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
451	VCS: MerkleTreeScheme<F>,
452{
453	params: &'a FRIParams<F, FA>,
454	codeword: &'a [P],
455	codeword_committed: &'a MerkleProver::Committed,
456	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
457	merkle_prover: &'a MerkleProver,
458}
459
460impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
461where
462	F: TowerField + ExtensionField<FA>,
463	FA: BinaryField,
464	P: PackedField<Scalar = F>,
465	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
466	VCS: MerkleTreeScheme<F>,
467{
468	/// Number of oracles sent during the fold rounds.
469	pub fn n_oracles(&self) -> usize {
470		self.params.n_oracles()
471	}
472
473	/// Proves a FRI challenge query.
474	///
475	/// ## Arguments
476	///
477	/// * `index` - an index into the original codeword domain
478	#[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
479	pub fn prove_query<B>(
480		&self,
481		mut index: usize,
482		mut advice: TranscriptWriter<B>,
483	) -> Result<(), Error>
484	where
485		B: BufMut,
486	{
487		let mut arities_and_optimal_layers_depths = self
488			.params
489			.fold_arities()
490			.iter()
491			.copied()
492			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));
493
494		let Some((first_fold_arity, first_optimal_layer_depth)) =
495			arities_and_optimal_layers_depths.next()
496		else {
497			// If there are no query proofs, that means that no oracles were sent during the FRI
498			// fold rounds. In that case, the original interleaved codeword is decommitted and
499			// the only checks that need to be performed are in `verify_last_oracle`.
500			return Ok(());
501		};
502
503		prove_coset_opening(
504			self.merkle_prover,
505			self.codeword,
506			self.codeword_committed,
507			index,
508			first_fold_arity,
509			first_optimal_layer_depth,
510			&mut advice,
511		)?;
512
513		for ((codeword, committed), (arity, optimal_layer_depth)) in
514			izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
515		{
516			index >>= arity;
517			prove_coset_opening(
518				self.merkle_prover,
519				codeword,
520				committed,
521				index,
522				arity,
523				optimal_layer_depth,
524				&mut advice,
525			)?;
526		}
527
528		Ok(())
529	}
530
531	pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
532		let committed_iter = std::iter::once(self.codeword_committed)
533			.chain(self.round_committed.iter().map(|(_, committed)| committed));
534
535		committed_iter
536			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
537			.map(|(committed, optimal_layer_depth)| {
538				self.merkle_prover
539					.layer(committed, optimal_layer_depth)
540					.map(|layer| layer.to_vec())
541					.map_err(|err| Error::VectorCommit(Box::new(err)))
542			})
543			.collect::<Result<Vec<_>, _>>()
544	}
545}
546
547fn prove_coset_opening<F, P, MTProver, B>(
548	merkle_prover: &MTProver,
549	codeword: &[P],
550	committed: &MTProver::Committed,
551	coset_index: usize,
552	log_coset_size: usize,
553	optimal_layer_depth: usize,
554	advice: &mut TranscriptWriter<B>,
555) -> Result<(), Error>
556where
557	F: TowerField,
558	P: PackedField<Scalar = F>,
559	MTProver: MerkleTreeProver<F>,
560	B: BufMut,
561{
562	let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
563		.take(1 << log_coset_size);
564	advice.write_scalar_iter(values);
565
566	merkle_prover
567		.prove_opening(committed, optimal_layer_depth, coset_index, advice)
568		.map_err(|err| Error::VectorCommit(Box::new(err)))?;
569
570	Ok(())
571}