// binius_core/protocols/fri/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField};
4use binius_hal::{make_portable_backend, ComputationBackend};
5use binius_maybe_rayon::prelude::*;
6use binius_utils::{bail, SerializeBytes};
7use bytemuck::zeroed_vec;
8use bytes::BufMut;
9use itertools::izip;
10use tracing::instrument;
11
12use super::{
13	common::{vcs_optimal_layers_depths_iter, FRIParams},
14	error::Error,
15	TerminateCodeword,
16};
17use crate::{
18	fiat_shamir::{CanSampleBits, Challenger},
19	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
20	protocols::fri::common::{fold_chunk, fold_interleaved_chunk},
21	reed_solomon::reed_solomon::ReedSolomonCode,
22	transcript::{ProverTranscript, TranscriptWriter},
23};
24
25#[instrument(skip_all, level = "debug")]
26pub fn fold_codeword<F, FS>(
27	rs_code: &ReedSolomonCode<FS>,
28	codeword: &[F],
29	// Round is the number of total folding challenges received so far.
30	round: usize,
31	folding_challenges: &[F],
32) -> Vec<F>
33where
34	F: BinaryField + ExtensionField<FS>,
35	FS: BinaryField,
36{
37	// Preconditions
38	assert_eq!(codeword.len() % (1 << folding_challenges.len()), 0);
39	assert!(round >= folding_challenges.len());
40	assert!(round <= rs_code.log_dim());
41
42	if folding_challenges.is_empty() {
43		return codeword.to_vec();
44	}
45
46	let start_round = round - folding_challenges.len();
47	let chunk_size = 1 << folding_challenges.len();
48
49	// For each chunk of size `2^chunk_size` in the codeword, fold it with the folding challenges
50	codeword
51		.par_chunks(chunk_size)
52		.enumerate()
53		.map_init(
54			|| vec![F::default(); chunk_size],
55			|scratch_buffer, (chunk_index, chunk)| {
56				fold_chunk(
57					rs_code,
58					start_round,
59					chunk_index,
60					chunk,
61					folding_challenges,
62					scratch_buffer,
63				)
64			},
65		)
66		.collect()
67}
68
69/// Fold the interleaved codeword into a single codeword with the same block length.
70///
71/// ## Arguments
72///
73/// * `rs_code` - the Reed–Solomon code the protocol tests proximity to.
74/// * `codeword` - an interleaved codeword.
75/// * `challenges` - the folding challenges. The length must be at least `log_batch_size`.
76/// * `log_batch_size` - the base-2 logarithm of the batch size of the interleaved code.
77#[instrument(skip_all, level = "debug")]
78fn fold_interleaved<F, FS>(
79	rs_code: &ReedSolomonCode<FS>,
80	codeword: &[F],
81	challenges: &[F],
82	log_batch_size: usize,
83) -> Vec<F>
84where
85	F: BinaryField + ExtensionField<FS>,
86	FS: BinaryField,
87{
88	assert_eq!(codeword.len(), 1 << (rs_code.log_len() + log_batch_size));
89	assert!(challenges.len() >= log_batch_size);
90
91	let backend = make_portable_backend();
92
93	let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
94	let tensor = backend
95		.tensor_product_full_query(interleave_challenges)
96		.expect("number of challenges is less than 32");
97
98	// For each chunk of size `2^chunk_size` in the codeword, fold it with the folding challenges
99	let fold_chunk_size = 1 << fold_challenges.len();
100	let interleave_chunk_size = 1 << log_batch_size;
101	let chunk_size = fold_chunk_size * interleave_chunk_size;
102	codeword
103		.par_chunks(chunk_size)
104		.enumerate()
105		.map_init(
106			|| vec![F::default(); 2 * fold_chunk_size],
107			|scratch_buffer, (i, chunk)| {
108				fold_interleaved_chunk(
109					rs_code,
110					log_batch_size,
111					i,
112					chunk,
113					&tensor,
114					fold_challenges,
115					scratch_buffer,
116				)
117			},
118		)
119		.collect()
120}
121
/// Everything produced by committing an interleaved message.
#[derive(Debug)]
pub struct CommitOutput<P, Commitment, Committed> {
	/// The public commitment to the encoded codeword.
	pub commitment: Commitment,
	/// Prover-side state of the committed Merkle tree.
	pub committed: Committed,
	/// The full encoded codeword, as packed field elements.
	pub codeword: Vec<P>,
}
128
129/// Creates a parallel iterator over scalars of subfield elementsAssumes chunk_size to be a power of two
130pub fn to_par_scalar_big_chunks<P>(
131	packed_slice: &[P],
132	chunk_size: usize,
133) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
134where
135	P: PackedField,
136{
137	packed_slice
138		.par_chunks(chunk_size / P::WIDTH)
139		.map(|chunk| PackedField::iter_slice(chunk))
140}
141
142pub fn to_par_scalar_small_chunks<P>(
143	packed_slice: &[P],
144	chunk_size: usize,
145) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
146where
147	P: PackedField,
148{
149	(0..packed_slice.len() * P::WIDTH)
150		.into_par_iter()
151		.step_by(chunk_size)
152		.map(move |start_index| {
153			let packed_item = &packed_slice[start_index / P::WIDTH];
154			packed_item
155				.iter()
156				.skip(start_index % P::WIDTH)
157				.take(chunk_size)
158		})
159}
160
161/// Encodes and commits the input message.
162///
163/// ## Arguments
164///
165/// * `rs_code` - the Reed-Solomon code to use for encoding
166/// * `params` - common FRI protocol parameters.
167/// * `merkle_prover` - the merke tree prover to use for committing
168/// * `message` - the interleaved message to encode and commit
169#[instrument(skip_all, level = "debug")]
170pub fn commit_interleaved<F, FA, P, PA, MerkleProver, VCS>(
171	rs_code: &ReedSolomonCode<PA>,
172	params: &FRIParams<F, FA>,
173	merkle_prover: &MerkleProver,
174	message: &[P],
175) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
176where
177	F: BinaryField,
178	FA: BinaryField,
179	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
180	PA: PackedField<Scalar = FA>,
181	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
182	VCS: MerkleTreeScheme<F>,
183{
184	let n_elems = rs_code.dim() << params.log_batch_size();
185	if message.len() * P::WIDTH != n_elems {
186		bail!(Error::InvalidArgs(
187			"interleaved message length does not match code parameters".to_string()
188		));
189	}
190
191	commit_interleaved_with(rs_code, params, merkle_prover, move |buffer| {
192		buffer.copy_from_slice(message)
193	})
194}
195
196/// Encodes and commits the input message with a closure for writing the message.
197///
198/// ## Arguments
199///
200/// * `rs_code` - the Reed-Solomon code to use for encoding
201/// * `params` - common FRI protocol parameters.
202/// * `merkle_prover` - the Merkle tree prover to use for committing
203/// * `message_writer` - a closure that writes the interleaved message to encode and commit
204#[instrument(skip_all, level = "debug")]
205pub fn commit_interleaved_with<F, FA, P, PA, MerkleProver, VCS>(
206	rs_code: &ReedSolomonCode<PA>,
207	params: &FRIParams<F, FA>,
208	merkle_prover: &MerkleProver,
209	message_writer: impl FnOnce(&mut [P]),
210) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
211where
212	F: BinaryField,
213	FA: BinaryField,
214	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
215	PA: PackedField<Scalar = FA>,
216	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
217	VCS: MerkleTreeScheme<F>,
218{
219	let log_batch_size = params.log_batch_size();
220	let log_elems = rs_code.log_dim() + log_batch_size;
221	if log_elems < P::LOG_WIDTH {
222		todo!("can't handle this case well");
223	}
224
225	let mut encoded = tracing::debug_span!("allocate codeword")
226		.in_scope(|| zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate())));
227	message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
228	rs_code.encode_ext_batch_inplace(&mut encoded, log_batch_size)?;
229
230	// take the first arity as coset_log_len, or use log_inv_rate if arities are empty
231	let coset_log_len = params
232		.fold_arities()
233		.first()
234		.copied()
235		.unwrap_or_else(|| rs_code.log_inv_rate());
236
237	let log_len = params.log_len() - coset_log_len;
238
239	let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
240		let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);
241
242		merkle_prover
243			.commit_iterated(iterated_big_chunks, log_len)
244			.map_err(|err| Error::VectorCommit(Box::new(err)))?
245	} else {
246		let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);
247
248		merkle_prover
249			.commit_iterated(iterated_small_chunks, log_len)
250			.map_err(|err| Error::VectorCommit(Box::new(err)))?
251	};
252
253	Ok(CommitOutput {
254		commitment: commitment.root,
255		committed: vcs_committed,
256		codeword: encoded,
257	})
258}
259
/// Outcome of a single call to [`FRIFolder::execute_fold_round`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldRoundOutput<VCSCommitment> {
	/// No oracle was committed this round; the challenge was buffered for later folding.
	NoCommitment,
	/// A folded codeword was committed; carries the root of its Merkle commitment.
	Commitment(VCSCommitment),
}
264
265/// A stateful prover for the FRI fold phase.
266pub struct FRIFolder<'a, F, FA, MerkleProver, VCS>
267where
268	FA: BinaryField,
269	F: BinaryField,
270	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
271	VCS: MerkleTreeScheme<F>,
272{
273	params: &'a FRIParams<F, FA>,
274	merkle_prover: &'a MerkleProver,
275	codeword: &'a [F],
276	codeword_committed: &'a MerkleProver::Committed,
277	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
278	curr_round: usize,
279	next_commit_round: Option<usize>,
280	unprocessed_challenges: Vec<F>,
281}
282
283impl<'a, F, FA, MerkleProver, VCS> FRIFolder<'a, F, FA, MerkleProver, VCS>
284where
285	F: TowerField + ExtensionField<FA>,
286	FA: BinaryField,
287	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
288	VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
289{
290	/// Constructs a new folder.
291	pub fn new(
292		params: &'a FRIParams<F, FA>,
293		merkle_prover: &'a MerkleProver,
294		committed_codeword: &'a [F],
295		committed: &'a MerkleProver::Committed,
296	) -> Result<Self, Error> {
297		if committed_codeword.len() != 1 << params.log_len() {
298			bail!(Error::InvalidArgs(
299				"Reed–Solomon code length must match interleaved codeword length".to_string(),
300			));
301		}
302
303		let next_commit_round = params.fold_arities().first().copied();
304		Ok(Self {
305			params,
306			merkle_prover,
307			codeword: committed_codeword,
308			codeword_committed: committed,
309			round_committed: Vec::with_capacity(params.n_oracles()),
310			curr_round: 0,
311			next_commit_round,
312			unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
313		})
314	}
315
316	/// Number of fold rounds, including the final fold.
317	pub const fn n_rounds(&self) -> usize {
318		self.params.n_fold_rounds()
319	}
320
321	/// Number of times `execute_fold_round` has been called.
322	pub const fn curr_round(&self) -> usize {
323		self.curr_round
324	}
325
326	fn is_commitment_round(&self) -> bool {
327		self.next_commit_round
328			.is_some_and(|round| round == self.curr_round)
329	}
330
331	/// Executes the next fold round and returns the folded codeword commitment.
332	///
333	/// As a memory efficient optimization, this method may not actually do the folding, but instead accumulate the
334	/// folding challenge for processing at a later time. This saves us from storing intermediate folded codewords.
335	#[instrument(skip_all, name = "fri::FRIFolder::execute_fold_round", level = "debug")]
336	pub fn execute_fold_round(
337		&mut self,
338		challenge: F,
339	) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
340		self.unprocessed_challenges.push(challenge);
341		self.curr_round += 1;
342
343		if !self.is_commitment_round() {
344			return Ok(FoldRoundOutput::NoCommitment);
345		}
346
347		// Fold the last codeword with the accumulated folding challenges.
348		let folded_codeword = match self.round_committed.last() {
349			Some((prev_codeword, _)) => {
350				// Fold a full codeword committed in the previous FRI round into a codeword with
351				// reduced dimension and rate.
352				fold_codeword(
353					self.params.rs_code(),
354					prev_codeword,
355					self.curr_round - self.params.log_batch_size(),
356					&self.unprocessed_challenges,
357				)
358			}
359			None => {
360				// Fold the interleaved codeword that was originally committed into a single
361				// codeword with the same or reduced block length, depending on the sequence of
362				// fold rounds.
363				fold_interleaved(
364					self.params.rs_code(),
365					self.codeword,
366					&self.unprocessed_challenges,
367					self.params.log_batch_size(),
368				)
369			}
370		};
371		self.unprocessed_challenges.clear();
372
373		// take the first arity as coset_log_len, or use inv_rate if arities are empty
374		let coset_size = self
375			.params
376			.fold_arities()
377			.get(self.round_committed.len() + 1)
378			.map(|log| 1 << log)
379			.unwrap_or_else(|| self.params.rs_code().inv_rate());
380
381		let (commitment, committed) = self
382			.merkle_prover
383			.commit(&folded_codeword, coset_size)
384			.map_err(|err| Error::VectorCommit(Box::new(err)))?;
385
386		self.round_committed.push((folded_codeword, committed));
387
388		self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
389			let arity = self.params.fold_arities().get(self.round_committed.len())?;
390			Some(next_commit_round + arity)
391		});
392		Ok(FoldRoundOutput::Commitment(commitment.root))
393	}
394
395	/// Finalizes the FRI folding process.
396	///
397	/// This step will process any unprocessed folding challenges to produce the
398	/// final folded codeword. Then it will decode this final folded codeword
399	/// to get the final message. The result is the final message and a query prover instance.
400	///
401	/// This returns the final message and a query prover instance.
402	#[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
403	#[allow(clippy::type_complexity)]
404	pub fn finalize(
405		mut self,
406	) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, MerkleProver, VCS>), Error> {
407		if self.curr_round != self.n_rounds() {
408			bail!(Error::EarlyProverFinish);
409		}
410
411		let terminate_codeword = self
412			.round_committed
413			.last()
414			.map(|(codeword, _)| codeword.clone())
415			.unwrap_or_else(|| self.codeword.to_vec());
416
417		self.unprocessed_challenges.clear();
418
419		let Self {
420			params,
421			codeword,
422			codeword_committed,
423			round_committed,
424			merkle_prover,
425			..
426		} = self;
427
428		let query_prover = FRIQueryProver {
429			params,
430			codeword,
431			codeword_committed,
432			round_committed,
433			merkle_prover,
434		};
435		Ok((terminate_codeword, query_prover))
436	}
437
438	pub fn finish_proof<Challenger_>(
439		self,
440		transcript: &mut ProverTranscript<Challenger_>,
441	) -> Result<(), Error>
442	where
443		Challenger_: Challenger,
444	{
445		let (terminate_codeword, query_prover) = self.finalize()?;
446		let mut advice = transcript.decommitment();
447		advice.write_scalar_slice(&terminate_codeword);
448
449		let layers = query_prover.vcs_optimal_layers()?;
450		for layer in layers {
451			advice.write_slice(&layer);
452		}
453
454		let params = query_prover.params;
455
456		for _ in 0..params.n_test_queries() {
457			let index = transcript.sample_bits(params.index_bits());
458			query_prover.prove_query(index, transcript.decommitment())?;
459		}
460
461		Ok(())
462	}
463}
464
465/// A prover for the FRI query phase.
466pub struct FRIQueryProver<'a, F, FA, MerkleProver, VCS>
467where
468	F: BinaryField,
469	FA: BinaryField,
470	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
471	VCS: MerkleTreeScheme<F>,
472{
473	params: &'a FRIParams<F, FA>,
474	codeword: &'a [F],
475	codeword_committed: &'a MerkleProver::Committed,
476	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
477	merkle_prover: &'a MerkleProver,
478}
479
480impl<F, FA, MerkleProver, VCS> FRIQueryProver<'_, F, FA, MerkleProver, VCS>
481where
482	F: TowerField + ExtensionField<FA>,
483	FA: BinaryField,
484	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
485	VCS: MerkleTreeScheme<F>,
486{
487	/// Number of oracles sent during the fold rounds.
488	pub fn n_oracles(&self) -> usize {
489		self.params.n_oracles()
490	}
491
492	/// Proves a FRI challenge query.
493	///
494	/// ## Arguments
495	///
496	/// * `index` - an index into the original codeword domain
497	#[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
498	pub fn prove_query<B>(
499		&self,
500		mut index: usize,
501		mut advice: TranscriptWriter<B>,
502	) -> Result<(), Error>
503	where
504		B: BufMut,
505	{
506		let mut arities_and_optimal_layers_depths = self
507			.params
508			.fold_arities()
509			.iter()
510			.copied()
511			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));
512
513		let Some((first_fold_arity, first_optimal_layer_depth)) =
514			arities_and_optimal_layers_depths.next()
515		else {
516			// If there are no query proofs, that means that no oracles were sent during the FRI
517			// fold rounds. In that case, the original interleaved codeword is decommitted and
518			// the only checks that need to be performed are in `verify_last_oracle`.
519			return Ok(());
520		};
521
522		prove_coset_opening(
523			self.merkle_prover,
524			self.codeword,
525			self.codeword_committed,
526			index,
527			first_fold_arity,
528			first_optimal_layer_depth,
529			&mut advice,
530		)?;
531
532		for ((codeword, committed), (arity, optimal_layer_depth)) in
533			izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
534		{
535			index >>= arity;
536			prove_coset_opening(
537				self.merkle_prover,
538				codeword,
539				committed,
540				index,
541				arity,
542				optimal_layer_depth,
543				&mut advice,
544			)?;
545		}
546
547		Ok(())
548	}
549
550	pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
551		let committed_iter = std::iter::once(self.codeword_committed)
552			.chain(self.round_committed.iter().map(|(_, committed)| committed));
553
554		committed_iter
555			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
556			.map(|(committed, optimal_layer_depth)| {
557				self.merkle_prover
558					.layer(committed, optimal_layer_depth)
559					.map(|layer| layer.to_vec())
560					.map_err(|err| Error::VectorCommit(Box::new(err)))
561			})
562			.collect::<Result<Vec<_>, _>>()
563	}
564}
565
566fn prove_coset_opening<F, MTProver, B>(
567	merkle_prover: &MTProver,
568	codeword: &[F],
569	committed: &MTProver::Committed,
570	coset_index: usize,
571	log_coset_size: usize,
572	optimal_layer_depth: usize,
573	advice: &mut TranscriptWriter<B>,
574) -> Result<(), Error>
575where
576	F: TowerField,
577	MTProver: MerkleTreeProver<F>,
578	B: BufMut,
579{
580	let values = &codeword[(coset_index << log_coset_size)..((coset_index + 1) << log_coset_size)];
581	advice.write_scalar_slice(values);
582
583	merkle_prover
584		.prove_opening(committed, optimal_layer_depth, coset_index, advice)
585		.map_err(|err| Error::VectorCommit(Box::new(err)))?;
586
587	Ok(())
588}