binius_core/protocols/fri/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	packed::{iter_packed_slice_with_offset, len_packed_slice},
5	BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
6};
7use binius_math::MultilinearQuery;
8use binius_maybe_rayon::prelude::*;
9use binius_ntt::AdditiveNTT;
10use binius_utils::{bail, checked_arithmetics::log2_strict_usize, SerializeBytes};
11use bytemuck::zeroed_vec;
12use bytes::BufMut;
13use itertools::izip;
14use tracing::instrument;
15
16use super::{
17	common::{vcs_optimal_layers_depths_iter, FRIParams},
18	error::Error,
19	logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
20	TerminateCodeword,
21};
22use crate::{
23	fiat_shamir::{CanSampleBits, Challenger},
24	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
25	protocols::fri::{common::fold_interleaved_chunk, logging::FRIFoldData},
26	reed_solomon::reed_solomon::ReedSolomonCode,
27	transcript::{ProverTranscript, TranscriptWriter},
28};
29
/// FRI-fold the interleaved codeword using the given challenges.
///
/// The input is treated as `2^log_len` batches of `2^log_batch_size` interleaved scalars. The
/// first `log_batch_size` challenges combine each batch into one scalar via a tensor-expansion
/// inner product; the remaining challenges drive the FRI folds proper. Each chunk of
/// `2^challenges.len()` scalars is reduced to a single output scalar, so the returned vector has
/// `2^(log_len + log_batch_size - challenges.len())` elements.
///
/// ## Arguments
///
/// * `ntt` - the NTT instance, used to look up the twiddle values.
/// * `codeword` - an interleaved codeword.
/// * `challenges` - the folding challenges. The length must be at least `log_batch_size`.
/// * `log_len` - the binary logarithm of the code length.
/// * `log_batch_size` - the binary logarithm of the interleaved code batch size.
///
/// ## Panics
///
/// Panics if `codeword` does not hold exactly `2^(log_len + log_batch_size)` scalars, or if
/// fewer than `log_batch_size` challenges are given.
///
/// See [DP24], Def. 3.6 and Lemma 3.9 for more details.
///
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[instrument(skip_all, level = "debug")]
pub fn fold_interleaved<F, FS, NTT, P>(
	ntt: &NTT,
	codeword: &[P],
	challenges: &[F],
	log_len: usize,
	log_batch_size: usize,
) -> Vec<F>
where
	F: BinaryField + ExtensionField<FS>,
	FS: BinaryField,
	NTT: AdditiveNTT<FS> + Sync,
	P: PackedField<Scalar = F>,
{
	// `saturating_sub` covers the case where the total scalar count is below the packing width.
	assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH));
	assert!(challenges.len() >= log_batch_size);

	// Split off the challenges that combine the interleaved batches from those that fold the code.
	let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
	let tensor = MultilinearQuery::expand(interleave_challenges);

	// Each chunk of `2^challenges.len()` scalars folds down to one output scalar.
	let fold_chunk_size = 1 << fold_challenges.len();
	let chunk_size = 1 << challenges.len().saturating_sub(P::LOG_WIDTH);
	codeword
		.par_chunks(chunk_size)
		.enumerate()
		.map_init(
			// One scratch buffer per rayon worker, reused across chunks to avoid per-chunk
			// allocations.
			|| vec![F::default(); fold_chunk_size],
			|scratch_buffer, (i, chunk)| {
				fold_interleaved_chunk(
					ntt,
					log_len,
					log_batch_size,
					i,
					chunk,
					tensor.expansion(),
					fold_challenges,
					scratch_buffer,
				)
			},
		)
		.collect()
}
86
/// The output of encoding and committing an interleaved message.
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
	/// The vector commitment digest (the Merkle tree root).
	pub commitment: VCSCommitment,
	/// The prover-side committed Merkle tree data, needed later to open queries.
	pub committed: VCSCommitted,
	/// The full encoded codeword that was committed.
	pub codeword: Vec<P>,
}
93
/// Creates a parallel iterator over scalars of subfield elements.
///
/// Each yielded item iterates over the scalars of one chunk; chunks span whole packed elements,
/// so this variant is for `chunk_size >= P::WIDTH`.
///
/// Assumes `chunk_size` to be a power of two.
pub fn to_par_scalar_big_chunks<P>(
	packed_slice: &[P],
	chunk_size: usize,
) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
where
	P: PackedField,
{
	packed_slice
		.par_chunks(chunk_size / P::WIDTH)
		.map(|chunk| PackedField::iter_slice(chunk))
}
106
107pub fn to_par_scalar_small_chunks<P>(
108	packed_slice: &[P],
109	chunk_size: usize,
110) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
111where
112	P: PackedField,
113{
114	(0..packed_slice.len() * P::WIDTH)
115		.into_par_iter()
116		.step_by(chunk_size)
117		.map(move |start_index| {
118			let packed_item = &packed_slice[start_index / P::WIDTH];
119			packed_item
120				.iter()
121				.skip(start_index % P::WIDTH)
122				.take(chunk_size)
123		})
124}
125
/// Encodes and commits the input message.
///
/// ## Arguments
///
/// * `rs_code` - the Reed-Solomon code to use for encoding
/// * `params` - common FRI protocol parameters.
/// * `ntt` - the NTT instance, used to look up twiddle values during encoding.
/// * `merkle_prover` - the Merkle tree prover to use for committing
/// * `message` - the interleaved message to encode and commit
///
/// ## Errors
///
/// Returns [`Error::InvalidArgs`] if `message` does not hold exactly
/// `rs_code.dim() << params.log_batch_size()` scalars.
#[instrument(skip_all, level = "debug")]
pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
	rs_code: &ReedSolomonCode<FA>,
	params: &FRIParams<F, FA>,
	ntt: &NTT,
	merkle_prover: &MerkleProver,
	message: &[P],
) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
where
	F: BinaryField,
	FA: BinaryField,
	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
	PA: PackedField<Scalar = FA>,
	NTT: AdditiveNTT<FA> + Sync,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	// The message must contain exactly one scalar per position of the interleaved message space.
	let n_elems = rs_code.dim() << params.log_batch_size();
	if message.len() * P::WIDTH != n_elems {
		bail!(Error::InvalidArgs(
			"interleaved message length does not match code parameters".to_string()
		));
	}

	// Delegate to the closure-based variant; the writer just copies the message into the
	// encoding buffer.
	commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
		buffer.copy_from_slice(message)
	})
}
162
163/// Encodes and commits the input message with a closure for writing the message.
164///
165/// ## Arguments
166///
167/// * `rs_code` - the Reed-Solomon code to use for encoding
168/// * `params` - common FRI protocol parameters.
169/// * `merkle_prover` - the Merkle tree prover to use for committing
170/// * `message_writer` - a closure that writes the interleaved message to encode and commit
171pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
172	params: &FRIParams<F, FA>,
173	ntt: &NTT,
174	merkle_prover: &MerkleProver,
175	message_writer: impl FnOnce(&mut [P]),
176) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
177where
178	F: BinaryField,
179	FA: BinaryField,
180	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
181	PA: PackedField<Scalar = FA>,
182	NTT: AdditiveNTT<FA> + Sync,
183	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
184	VCS: MerkleTreeScheme<F>,
185{
186	let rs_code = params.rs_code();
187	let log_batch_size = params.log_batch_size();
188	let log_elems = rs_code.log_dim() + log_batch_size;
189	if log_elems < P::LOG_WIDTH {
190		todo!("can't handle this case well");
191	}
192
193	let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));
194
195	let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
196	tracing::debug_span!("[task] Sort & Merge", phase = "commit", perfetto_category = "task.main", dimensions_data = ?dimensions_data)
197		.in_scope(|| {
198			message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
199		});
200
201	let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
202	tracing::debug_span!("[task] RS Encode", phase = "commit", perfetto_category = "task.main", dimensions_data = ?dimensions_data)
203		.in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;
204
205	// Take the first arity as coset_log_len, or use the value such that the number of leaves equals 1 << log_inv_rate if arities is empty
206	let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);
207
208	let log_len = params.log_len() - coset_log_len;
209	let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
210	let merkle_tree_span = tracing::debug_span!(
211		"[task] Merkle Tree",
212		phase = "commit",
213		perfetto_category = "task.main",
214		dimensions_data = ?dimension_data
215	)
216	.entered();
217	let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
218		let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
219
220		merkle_prover
221			.commit_iterated(iterated_big_chunks, log_len)
222			.map_err(|err| Error::VectorCommit(Box::new(err)))?
223	} else {
224		let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
225
226		merkle_prover
227			.commit_iterated(iterated_small_chunks, log_len)
228			.map_err(|err| Error::VectorCommit(Box::new(err)))?
229	};
230	drop(merkle_tree_span);
231
232	Ok(CommitOutput {
233		commitment: commitment.root,
234		committed: vcs_committed,
235		codeword: encoded,
236	})
237}
238
/// The result of a single FRI fold round.
pub enum FoldRoundOutput<VCSCommitment> {
	/// No oracle was sent this round; the folding challenge was accumulated for later.
	NoCommitment,
	/// A folded codeword was committed; this is its commitment.
	Commitment(VCSCommitment),
}
243
/// A stateful prover for the FRI fold phase.
pub struct FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
where
	FA: BinaryField,
	F: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	// Common FRI protocol parameters.
	params: &'a FRIParams<F, FA>,
	// NTT instance used to look up twiddle values during folding.
	ntt: &'a NTT,
	// Merkle tree prover used to commit folded codewords.
	merkle_prover: &'a MerkleProver,
	// The original committed interleaved codeword.
	codeword: &'a [P],
	// Merkle tree data for the original codeword commitment.
	codeword_committed: &'a MerkleProver::Committed,
	// Folded codewords and their Merkle tree data, one entry per oracle committed so far.
	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
	// Number of fold rounds executed so far.
	curr_round: usize,
	// The round at which the next oracle commitment happens, or `None` when none remain.
	next_commit_round: Option<usize>,
	// Folding challenges accumulated since the last commitment, processed lazily in a batch.
	unprocessed_challenges: Vec<F>,
}
263
impl<'a, F, FA, P, NTT, MerkleProver, VCS> FRIFolder<'a, F, FA, P, NTT, MerkleProver, VCS>
where
	F: TowerField + ExtensionField<FA>,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	NTT: AdditiveNTT<FA> + Sync,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
{
	/// Constructs a new folder.
	///
	/// ## Arguments
	///
	/// * `params` - common FRI protocol parameters.
	/// * `ntt` - the NTT instance, used to look up twiddle values during folding.
	/// * `merkle_prover` - the Merkle tree prover used to commit folded codewords.
	/// * `committed_codeword` - the interleaved codeword that was originally committed.
	/// * `committed` - the Merkle tree data for `committed_codeword`.
	///
	/// ## Errors
	///
	/// Returns [`Error::InvalidArgs`] if `committed_codeword` holds fewer than
	/// `2^params.log_len()` scalars.
	pub fn new(
		params: &'a FRIParams<F, FA>,
		ntt: &'a NTT,
		merkle_prover: &'a MerkleProver,
		committed_codeword: &'a [P],
		committed: &'a MerkleProver::Committed,
	) -> Result<Self, Error> {
		if len_packed_slice(committed_codeword) < 1 << params.log_len() {
			bail!(Error::InvalidArgs(
				"Reed–Solomon code length must match interleaved codeword length".to_string(),
			));
		}

		// The first oracle is committed after the first `fold_arities()[0]` fold rounds.
		let next_commit_round = params.fold_arities().first().copied();
		Ok(Self {
			params,
			ntt,
			merkle_prover,
			codeword: committed_codeword,
			codeword_committed: committed,
			round_committed: Vec::with_capacity(params.n_oracles()),
			curr_round: 0,
			next_commit_round,
			unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
		})
	}

	/// Number of fold rounds, including the final fold.
	pub const fn n_rounds(&self) -> usize {
		self.params.n_fold_rounds()
	}

	/// Number of times `execute_fold_round` has been called.
	pub const fn curr_round(&self) -> usize {
		self.curr_round
	}

	/// The length of the current codeword, in scalars.
	pub fn current_codeword_len(&self) -> usize {
		match self.round_committed.last() {
			Some((codeword, _)) => codeword.len(),
			None => len_packed_slice(self.codeword),
		}
	}

	// Whether the current round is one where an oracle commitment must be produced.
	fn is_commitment_round(&self) -> bool {
		self.next_commit_round
			.is_some_and(|round| round == self.curr_round)
	}

	/// Executes the next fold round and returns the folded codeword commitment.
	///
	/// As a memory efficient optimization, this method may not actually do the folding, but instead accumulate the
	/// folding challenge for processing at a later time. This saves us from storing intermediate folded codewords.
	pub fn execute_fold_round(
		&mut self,
		challenge: F,
	) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
		self.unprocessed_challenges.push(challenge);
		self.curr_round += 1;

		if !self.is_commitment_round() {
			// Defer folding: the accumulated challenges are processed as a batch at the next
			// commitment round.
			return Ok(FoldRoundOutput::NoCommitment);
		}

		// Tracing dimensions: the last committed oracle is non-interleaved (batch size 0);
		// otherwise use the original interleaved codeword's parameters.
		let dimensions_data = match self.round_committed.last() {
			Some((codeword, _)) => FRIFoldData::new(
				log2_strict_usize(codeword.len()),
				0,
				self.unprocessed_challenges.len(),
			),
			None => FRIFoldData::new(
				self.params.rs_code().log_len(),
				self.params.log_batch_size(),
				self.unprocessed_challenges.len(),
			),
		};

		let fri_fold_span = tracing::debug_span!(
			"[task] FRI Fold",
			phase = "piop_compiler",
			perfetto_category = "task.main",
			dimensions_data = ?dimensions_data
		)
		.entered();
		// Fold the last codeword with the accumulated folding challenges.
		let folded_codeword = match self.round_committed.last() {
			Some((prev_codeword, _)) => {
				// Fold a full codeword committed in the previous FRI round into a codeword with
				// reduced dimension and rate.
				fold_interleaved(
					self.ntt,
					prev_codeword,
					&self.unprocessed_challenges,
					log2_strict_usize(prev_codeword.len()),
					0,
				)
			}
			None => {
				// Fold the interleaved codeword that was originally committed into a single
				// codeword with the same or reduced block length, depending on the sequence of
				// fold rounds.
				fold_interleaved(
					self.ntt,
					self.codeword,
					&self.unprocessed_challenges,
					self.params.rs_code().log_len(),
					self.params.log_batch_size(),
				)
			}
		};
		drop(fri_fold_span);
		self.unprocessed_challenges.clear();

		// The Merkle coset size for this oracle is the fold arity that FOLLOWS the oracle being
		// committed now (index `round_committed.len() + 1`, since arities[0] precedes the first
		// oracle), or 2^n_final_challenges when this is the last oracle.
		let coset_size = self
			.params
			.fold_arities()
			.get(self.round_committed.len() + 1)
			.map(|log| 1 << log)
			.unwrap_or_else(|| 1 << self.params.n_final_challenges());
		let dimension_data =
			MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
		let merkle_tree_span = tracing::debug_span!(
			"[task] Merkle Tree",
			phase = "piop_compiler",
			perfetto_category = "task.main",
			dimensions_data = ?dimension_data
		)
		.entered();
		let (commitment, committed) = self
			.merkle_prover
			.commit(&folded_codeword, coset_size)
			.map_err(|err| Error::VectorCommit(Box::new(err)))?;
		drop(merkle_tree_span);

		self.round_committed.push((folded_codeword, committed));

		// Schedule the next commitment round, or `None` once the arities are exhausted.
		self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
			let arity = self.params.fold_arities().get(self.round_committed.len())?;
			Some(next_commit_round + arity)
		});
		Ok(FoldRoundOutput::Commitment(commitment.root))
	}

	/// Finalizes the FRI folding process.
	///
	/// This step will process any unprocessed folding challenges to produce the
	/// final folded codeword. Then it will decode this final folded codeword
	/// to get the final message. The result is the final message and a query prover instance.
	///
	/// This returns the final message and a query prover instance.
	///
	/// ## Errors
	///
	/// Returns [`Error::EarlyProverFinish`] if called before all fold rounds have executed.
	#[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
	#[allow(clippy::type_complexity)]
	pub fn finalize(
		mut self,
	) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
		if self.curr_round != self.n_rounds() {
			bail!(Error::EarlyProverFinish);
		}

		// The terminate codeword is the last committed oracle, or the original codeword's
		// scalars if no oracle was ever committed.
		let terminate_codeword = self
			.round_committed
			.last()
			.map(|(codeword, _)| codeword.clone())
			.unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());

		self.unprocessed_challenges.clear();

		let Self {
			params,
			codeword,
			codeword_committed,
			round_committed,
			merkle_prover,
			..
		} = self;

		let query_prover = FRIQueryProver {
			params,
			codeword,
			codeword_committed,
			round_committed,
			merkle_prover,
		};
		Ok((terminate_codeword, query_prover))
	}

	/// Finalizes the fold phase and writes the remainder of the FRI proof to the transcript.
	///
	/// Writes the terminate codeword and each oracle's optimal Merkle layer as advice, then
	/// proves `params.n_test_queries()` queries at indices sampled from the transcript.
	pub fn finish_proof<Challenger_>(
		self,
		transcript: &mut ProverTranscript<Challenger_>,
	) -> Result<(), Error>
	where
		Challenger_: Challenger,
	{
		let (terminate_codeword, query_prover) = self.finalize()?;
		let mut advice = transcript.decommitment();
		advice.write_scalar_slice(&terminate_codeword);

		let layers = query_prover.vcs_optimal_layers()?;
		for layer in layers {
			advice.write_slice(&layer);
		}

		let params = query_prover.params;

		// Sample each query index from the transcript, then decommit the query openings.
		for _ in 0..params.n_test_queries() {
			let index = transcript.sample_bits(params.index_bits()) as usize;
			query_prover.prove_query(index, transcript.decommitment())?;
		}

		Ok(())
	}
}
488
/// A prover for the FRI query phase.
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
where
	F: BinaryField,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	// Common FRI protocol parameters.
	params: &'a FRIParams<F, FA>,
	// The original committed interleaved codeword.
	codeword: &'a [P],
	// Merkle tree data for the original codeword commitment.
	codeword_committed: &'a MerkleProver::Committed,
	// Folded codewords and their Merkle tree data, one per fold-round oracle.
	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
	// The Merkle tree prover used to open commitments.
	merkle_prover: &'a MerkleProver,
}
504
impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
where
	F: TowerField + ExtensionField<FA>,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	/// Number of oracles sent during the fold rounds.
	pub fn n_oracles(&self) -> usize {
		self.params.n_oracles()
	}

	/// Proves a FRI challenge query.
	///
	/// Decommits the queried coset of the original codeword and of each round oracle, writing
	/// the coset values and Merkle opening proofs to `advice`.
	///
	/// ## Arguments
	///
	/// * `index` - an index into the original codeword domain
	/// * `advice` - the transcript writer the coset openings are written to
	#[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
	pub fn prove_query<B>(
		&self,
		mut index: usize,
		mut advice: TranscriptWriter<B>,
	) -> Result<(), Error>
	where
		B: BufMut,
	{
		// Pair each fold arity with the optimal Merkle layer depth of its oracle.
		let mut arities_and_optimal_layers_depths = self
			.params
			.fold_arities()
			.iter()
			.copied()
			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));

		let Some((first_fold_arity, first_optimal_layer_depth)) =
			arities_and_optimal_layers_depths.next()
		else {
			// If there are no query proofs, that means that no oracles were sent during the FRI
			// fold rounds. In that case, the original interleaved codeword is decommitted and
			// the only checks that need to be performed are in `verify_last_oracle`.
			return Ok(());
		};

		// Open the queried coset of the original interleaved codeword.
		prove_coset_opening(
			self.merkle_prover,
			self.codeword,
			self.codeword_committed,
			index,
			first_fold_arity,
			first_optimal_layer_depth,
			&mut advice,
		)?;

		// Each fold round shrinks the domain by its arity; shift the index accordingly before
		// opening the corresponding coset of each round oracle.
		for ((codeword, committed), (arity, optimal_layer_depth)) in
			izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
		{
			index >>= arity;
			prove_coset_opening(
				self.merkle_prover,
				codeword,
				committed,
				index,
				arity,
				optimal_layer_depth,
				&mut advice,
			)?;
		}

		Ok(())
	}

	/// Returns the optimal Merkle layer of each committed oracle, starting with the original
	/// codeword commitment.
	pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
		let committed_iter = std::iter::once(self.codeword_committed)
			.chain(self.round_committed.iter().map(|(_, committed)| committed));

		committed_iter
			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
			.map(|(committed, optimal_layer_depth)| {
				self.merkle_prover
					.layer(committed, optimal_layer_depth)
					.map(|layer| layer.to_vec())
					.map_err(|err| Error::VectorCommit(Box::new(err)))
			})
			.collect::<Result<Vec<_>, _>>()
	}
}
591
592fn prove_coset_opening<F, P, MTProver, B>(
593	merkle_prover: &MTProver,
594	codeword: &[P],
595	committed: &MTProver::Committed,
596	coset_index: usize,
597	log_coset_size: usize,
598	optimal_layer_depth: usize,
599	advice: &mut TranscriptWriter<B>,
600) -> Result<(), Error>
601where
602	F: TowerField,
603	P: PackedField<Scalar = F>,
604	MTProver: MerkleTreeProver<F>,
605	B: BufMut,
606{
607	let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
608		.take(1 << log_coset_size);
609	advice.write_scalar_iter(values);
610
611	merkle_prover
612		.prove_opening(committed, optimal_layer_depth, coset_index, advice)
613		.map_err(|err| Error::VectorCommit(Box::new(err)))?;
614
615	Ok(())
616}