binius_core/protocols/fri/
prove.rs

// Copyright 2024-2025 Irreducible Inc.
use binius_compute::{
	ComputeLayerExecutor,
	alloc::ComputeAllocator,
	layer::ComputeLayer,
	memory::{ComputeMemory, SizedSlice},
};
use binius_field::{
	BinaryField, ExtensionField, PackedExtension, PackedField, TowerField,
	packed::{iter_packed_slice_with_offset, len_packed_slice},
	unpack_if_possible,
};
use binius_maybe_rayon::prelude::*;
use binius_ntt::AdditiveNTT;
use binius_utils::{SerializeBytes, bail, checked_arithmetics::log2_strict_usize};
use bytemuck::zeroed_vec;
use bytes::BufMut;
use itertools::izip;
use tracing::instrument;

use super::{
	TerminateCodeword,
	common::{FRIParams, vcs_optimal_layers_depths_iter},
	error::Error,
	logging::{MerkleTreeDimensionData, RSEncodeDimensionData, SortAndMergeDimensionData},
};
use crate::{
	fiat_shamir::{CanSampleBits, Challenger},
	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
	protocols::fri::logging::FRIFoldData,
	reed_solomon::reed_solomon::ReedSolomonCode,
	transcript::{ProverTranscript, TranscriptWriter},
};

#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
	pub commitment: VCSCommitment,
	pub committed: VCSCommitted,
	pub codeword: Vec<P>,
}

/// Creates a parallel iterator over scalars of subfield elements.
///
/// Assumes `chunk_size` to be a power of two that is at least `P::WIDTH`.
pub fn to_par_scalar_big_chunks<P>(
	packed_slice: &[P],
	chunk_size: usize,
) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
where
	P: PackedField,
{
	packed_slice
		.par_chunks(chunk_size / P::WIDTH)
		.map(|chunk| PackedField::iter_slice(chunk))
}

/// Creates a parallel iterator over scalars of subfield elements.
///
/// Assumes `chunk_size` to be a power of two that is at most `P::WIDTH`.
pub fn to_par_scalar_small_chunks<P>(
	packed_slice: &[P],
	chunk_size: usize,
) -> impl IndexedParallelIterator<Item: Iterator<Item = P::Scalar> + Send + '_>
where
	P: PackedField,
{
	(0..packed_slice.len() * P::WIDTH)
		.into_par_iter()
		.step_by(chunk_size)
		.map(move |start_index| {
			let packed_item = &packed_slice[start_index / P::WIDTH];
			packed_item
				.iter()
				.skip(start_index % P::WIDTH)
				.take(chunk_size)
		})
}

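// Illustrative note on the two iterators above (a sketch; the concrete width is assumed purely
// for the arithmetic): with `P::WIDTH = 4` and `chunk_size = 8`, `to_par_scalar_big_chunks`
// yields one 8-scalar chunk per pair of packed elements, while with `chunk_size = 2`,
// `to_par_scalar_small_chunks` yields two 2-scalar chunks per packed element.
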
/// Encodes and commits the input message.
///
/// ## Arguments
///
/// * `rs_code` - the Reed-Solomon code to use for encoding
/// * `params` - common FRI protocol parameters.
/// * `ntt` - the NTT instance to use for encoding
/// * `merkle_prover` - the Merkle tree prover to use for committing
/// * `message` - the interleaved message to encode and commit
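///
/// ## Example
///
/// A minimal usage sketch (illustrative only; the setup of `rs_code`, `params`, `ntt`,
/// `merkle_prover`, and `message` is assumed):
///
/// ```ignore
/// let CommitOutput {
/// 	commitment,
/// 	committed,
/// 	codeword,
/// } = commit_interleaved(&rs_code, &params, &ntt, &merkle_prover, &message)?;
/// // `commitment` is the Merkle root sent to the verifier; `codeword` and `committed`
/// // are retained by the prover for the fold and query phases (see `FRIFolder::new`).
/// ```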
#[instrument(skip_all, level = "debug")]
pub fn commit_interleaved<F, FA, P, PA, NTT, MerkleProver, VCS>(
	rs_code: &ReedSolomonCode<FA>,
	params: &FRIParams<F, FA>,
	ntt: &NTT,
	merkle_prover: &MerkleProver,
	message: &[P],
) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
where
	F: BinaryField,
	FA: BinaryField,
	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
	PA: PackedField<Scalar = FA>,
	NTT: AdditiveNTT<FA> + Sync,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	let n_elems = rs_code.dim() << params.log_batch_size();
	if message.len() * P::WIDTH != n_elems {
		bail!(Error::InvalidArgs(
			"interleaved message length does not match code parameters".to_string()
		));
	}

	commit_interleaved_with(params, ntt, merkle_prover, move |buffer| {
		buffer.copy_from_slice(message)
	})
}

/// Encodes and commits the input message with a closure for writing the message.
///
/// ## Arguments
///
/// * `params` - common FRI protocol parameters.
/// * `ntt` - the NTT instance to use for encoding
/// * `merkle_prover` - the Merkle tree prover to use for committing
/// * `message_writer` - a closure that writes the interleaved message to encode and commit
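///
/// ## Example
///
/// A sketch of committing a message that is written directly into the encoding buffer rather
/// than copied from a slice (illustrative; `write_message_into` stands in for any routine that
/// fills the buffer):
///
/// ```ignore
/// let output = commit_interleaved_with(&params, &ntt, &merkle_prover, |buffer| {
/// 	write_message_into(buffer);
/// })?;
/// ```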
pub fn commit_interleaved_with<F, FA, P, PA, NTT, MerkleProver, VCS>(
	params: &FRIParams<F, FA>,
	ntt: &NTT,
	merkle_prover: &MerkleProver,
	message_writer: impl FnOnce(&mut [P]),
) -> Result<CommitOutput<P, VCS::Digest, MerkleProver::Committed>, Error>
where
	F: BinaryField,
	FA: BinaryField,
	P: PackedField<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
	PA: PackedField<Scalar = FA>,
	NTT: AdditiveNTT<FA> + Sync,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	let rs_code = params.rs_code();
	let log_batch_size = params.log_batch_size();
	let log_elems = rs_code.log_dim() + log_batch_size;
	if log_elems < P::LOG_WIDTH {
		todo!("can't handle this case well");
	}

	let mut encoded = zeroed_vec(1 << (log_elems - P::LOG_WIDTH + rs_code.log_inv_rate()));

	let dimensions_data = SortAndMergeDimensionData::new::<F>(log_elems);
	tracing::debug_span!(
		"[task] Sort & Merge",
		phase = "commit",
		perfetto_category = "task.main",
		?dimensions_data
	)
	.in_scope(|| {
		message_writer(&mut encoded[..1 << (log_elems - P::LOG_WIDTH)]);
	});

	let dimensions_data = RSEncodeDimensionData::new::<F>(log_elems, log_batch_size);
	tracing::debug_span!(
		"[task] RS Encode",
		phase = "commit",
		perfetto_category = "task.main",
		?dimensions_data
	)
	.in_scope(|| rs_code.encode_ext_batch_inplace(ntt, &mut encoded, log_batch_size))?;

	// Take the first arity as coset_log_len, or use the value such that the number of leaves
	// equals 1 << log_inv_rate if arities is empty
	let coset_log_len = params.fold_arities().first().copied().unwrap_or(log_elems);

	let log_len = params.log_len() - coset_log_len;
	let dimension_data = MerkleTreeDimensionData::new::<F>(log_len, 1 << coset_log_len);
	let merkle_tree_span = tracing::debug_span!(
		"[task] Merkle Tree",
		phase = "commit",
		perfetto_category = "task.main",
		dimensions_data = ?dimension_data
	)
	.entered();
	let (commitment, vcs_committed) = if coset_log_len > P::LOG_WIDTH {
		let iterated_big_chunks = to_par_scalar_big_chunks(&encoded, 1 << coset_log_len);

		merkle_prover
			.commit_iterated(iterated_big_chunks, log_len)
			.map_err(|err| Error::VectorCommit(Box::new(err)))?
	} else {
		let iterated_small_chunks = to_par_scalar_small_chunks(&encoded, 1 << coset_log_len);

		merkle_prover
			.commit_iterated(iterated_small_chunks, log_len)
			.map_err(|err| Error::VectorCommit(Box::new(err)))?
	};
	drop(merkle_tree_span);

	Ok(CommitOutput {
		commitment: commitment.root,
		committed: vcs_committed,
		codeword: encoded,
	})
}

/// The output of a single FRI fold round.
pub enum FoldRoundOutput<VCSCommitment> {
	/// The challenge was accumulated; no oracle was committed this round.
	NoCommitment,
	/// A commitment to the codeword folded and committed in this round.
	Commitment(VCSCommitment),
}

/// Represents the committed data for a single round in the FRI protocol when using a compute layer
pub struct SingleRoundCommitted<'b, F, Hal, MerkleProver>
where
	F: BinaryField,
	Hal: ComputeLayer<F>,
	MerkleProver: MerkleTreeProver<F>,
{
	/// The folded codeword on the host
	pub host_codeword: Vec<F>,
	/// The folded codeword on the device
	pub device_codeword: <Hal::DevMem as ComputeMemory<F>>::FSliceMut<'b>,
	/// The Merkle tree commitment
	pub committed: MerkleProver::Committed,
}

pub struct FRIFolder<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
where
	FA: BinaryField,
	F: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
	Hal: ComputeLayer<F>,
{
	params: &'a FRIParams<F, FA>,
	ntt: &'a NTT,
	merkle_prover: &'a MerkleProver,
	codeword: &'a [P],
	codeword_committed: &'a MerkleProver::Committed,
	round_committed: Vec<SingleRoundCommitted<'b, F, Hal, MerkleProver>>,
	curr_round: usize,
	next_commit_round: Option<usize>,
	unprocessed_challenges: Vec<F>,
	cl: &'b Hal,
}

impl<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
	FRIFolder<'a, 'b, F, FA, P, NTT, MerkleProver, VCS, Hal>
where
	F: TowerField + ExtensionField<FA>,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	NTT: AdditiveNTT<FA> + Sync,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F, Digest: SerializeBytes>,
	Hal: ComputeLayer<F>,
{
	/// Constructs a new folder.
	pub fn new(
		hal: &'b Hal,
		params: &'a FRIParams<F, FA>,
		ntt: &'a NTT,
		merkle_prover: &'a MerkleProver,
		codeword: &'a [P],
		committed: &'a MerkleProver::Committed,
	) -> Result<Self, Error> {
		if len_packed_slice(codeword) < 1 << params.log_len() {
			bail!(Error::InvalidArgs(
				"Reed–Solomon code length must match interleaved codeword length".to_string(),
			));
		}

		let next_commit_round = params.fold_arities().first().copied();
		Ok(Self {
			params,
			ntt,
			merkle_prover,
			codeword,
			codeword_committed: committed,
			round_committed: Vec::with_capacity(params.n_oracles()),
			curr_round: 0,
			next_commit_round,
			unprocessed_challenges: Vec::with_capacity(params.rs_code().log_dim()),
			cl: hal,
		})
	}

	/// Number of fold rounds, including the final fold.
	pub const fn n_rounds(&self) -> usize {
		self.params.n_fold_rounds()
	}

	/// Number of times `execute_fold_round` has been called.
	pub const fn curr_round(&self) -> usize {
		self.curr_round
	}

	/// The length of the current codeword.
	pub fn current_codeword_len(&self) -> usize {
		match self.round_committed.last() {
			Some(round) => round.host_codeword.len(),
			None => len_packed_slice(self.codeword),
		}
	}

	fn is_commitment_round(&self) -> bool {
		self.next_commit_round
			.is_some_and(|round| round == self.curr_round)
	}

	/// Executes the next fold round and returns the folded codeword commitment.
	///
	/// As a memory-efficient optimization, this method may not actually perform the folding, but
	/// instead accumulate the folding challenge for processing at a later time. This saves us
	/// from storing intermediate folded codewords.
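	///
	/// A sketch of the loop that drives the fold rounds (illustrative; the sampling and
	/// transcript calls stand in for whatever Fiat-Shamir interface the caller uses):
	///
	/// ```ignore
	/// for _ in 0..folder.n_rounds() {
	/// 	let challenge = transcript.sample();
	/// 	match folder.execute_fold_round(&allocator, challenge)? {
	/// 		FoldRoundOutput::NoCommitment => {}
	/// 		FoldRoundOutput::Commitment(round_commitment) => {
	/// 			transcript.message().write(&round_commitment);
	/// 		}
	/// 	}
	/// }
	/// ```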
	pub fn execute_fold_round(
		&mut self,
		allocator: &'b impl ComputeAllocator<F, Hal::DevMem>,
		challenge: F,
	) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
		self.unprocessed_challenges.push(challenge);
		self.curr_round += 1;

		if !self.is_commitment_round() {
			return Ok(FoldRoundOutput::NoCommitment);
		}

		let dimensions_data = match self.round_committed.last() {
			Some(round) => FRIFoldData::new::<F, FA>(
				log2_strict_usize(round.host_codeword.len()),
				0,
				self.unprocessed_challenges.len(),
			),
			None => FRIFoldData::new::<F, FA>(
				self.params.rs_code().log_len(),
				self.params.log_batch_size(),
				self.unprocessed_challenges.len(),
			),
		};

		let fri_fold_span = tracing::debug_span!(
			"[task] FRI Fold",
			phase = "piop_compiler",
			perfetto_category = "task.main",
			?dimensions_data
		)
		.entered();
		// Fold the last codeword with the accumulated folding challenges.
		let folded_codeword = match self.round_committed.last() {
			Some(prev_round) => {
				// Fold a full codeword committed in the previous FRI round into a codeword with
				// reduced dimension and rate.
				let mut folded_codeword = allocator.alloc(
					prev_round.device_codeword.len() / (1 << self.unprocessed_challenges.len()),
				)?;

				self.cl.execute(|exec| {
					exec.fri_fold(
						self.ntt,
						log2_strict_usize(prev_round.device_codeword.len()),
						0,
						&self.unprocessed_challenges,
						Hal::DevMem::as_const(&prev_round.device_codeword),
						&mut folded_codeword,
					)?;

					Ok(vec![])
				})?;

				folded_codeword
			}
			None => {
				let codeword_len = len_packed_slice(self.codeword);
				let mut original_codeword = allocator.alloc(codeword_len)?;
				unpack_if_possible(
					self.codeword,
					|scalars| self.cl.copy_h2d(scalars, &mut original_codeword),
					|_packed| unimplemented!("non-dense packed fields not supported"),
				)?;
				let mut folded_codeword = allocator.alloc(
					1 << (self.params.rs_code().log_len()
						- (self.unprocessed_challenges.len() - self.params.log_batch_size())),
				)?;
				self.cl.execute(|exec| {
					exec.fri_fold(
						self.ntt,
						self.params.rs_code().log_len(),
						self.params.log_batch_size(),
						&self.unprocessed_challenges,
						Hal::DevMem::as_const(&original_codeword),
						&mut folded_codeword,
					)?;

					Ok(vec![])
				})?;

				folded_codeword
			}
		};
		drop(fri_fold_span);
		self.unprocessed_challenges.clear();

		let mut folded_codeword_host = zeroed_vec(folded_codeword.len());
		self.cl
			.copy_d2h(Hal::DevMem::as_const(&folded_codeword), &mut folded_codeword_host)?;

		// Take the next arity as the log coset size, or fall back to the number of final
		// challenges if the arities are exhausted.
		let coset_size = self
			.params
			.fold_arities()
			.get(self.round_committed.len() + 1)
			.map(|log| 1 << log)
			.unwrap_or_else(|| 1 << self.params.n_final_challenges());
		let dimension_data =
			MerkleTreeDimensionData::new::<F>(dimensions_data.log_len(), coset_size);
		let merkle_tree_span = tracing::debug_span!(
			"[task] Merkle Tree",
			phase = "piop_compiler",
			perfetto_category = "task.main",
			dimensions_data = ?dimension_data
		)
		.entered();
		let (commitment, committed) = self
			.merkle_prover
			.commit(&folded_codeword_host, coset_size)
			.map_err(|err| Error::VectorCommit(Box::new(err)))?;
		drop(merkle_tree_span);

		self.round_committed.push(SingleRoundCommitted {
			host_codeword: folded_codeword_host,
			device_codeword: folded_codeword,
			committed,
		});

		self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
			let arity = self.params.fold_arities().get(self.round_committed.len())?;
			Some(next_commit_round + arity)
		});
		Ok(FoldRoundOutput::Commitment(commitment.root))
	}

	/// Finalizes the FRI folding process.
	///
	/// This step will process any unprocessed folding challenges to produce the final folded
	/// codeword. Then it will decode this final folded codeword to get the final message.
	///
	/// This returns the final message and a query prover instance.
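	///
	/// A sketch of direct use (illustrative; most callers go through [`Self::finish_proof`],
	/// which performs these steps itself):
	///
	/// ```ignore
	/// let (terminate_codeword, query_prover) = folder.finalize()?;
	/// transcript.decommitment().write_scalar_slice(&terminate_codeword);
	/// ```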
	#[instrument(skip_all, name = "fri::FRIFolder::finalize", level = "debug")]
	#[allow(clippy::type_complexity)]
	pub fn finalize(
		mut self,
	) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
		if self.curr_round != self.n_rounds() {
			bail!(Error::EarlyProverFinish);
		}

		let terminate_codeword = self
			.round_committed
			.last()
			.map(|round| round.host_codeword.clone())
			.unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());

		self.unprocessed_challenges.clear();

		let Self {
			params,
			codeword,
			codeword_committed,
			round_committed,
			merkle_prover,
			..
		} = self;

		let round_committed = round_committed
			.into_iter()
			.map(|round| (round.host_codeword, round.committed))
			.collect();

		let query_prover = FRIQueryProver {
			params,
			codeword,
			codeword_committed,
			round_committed,
			merkle_prover,
		};
		Ok((terminate_codeword, query_prover))
	}

	/// Finalizes the FRI proof and writes it to the transcript.
	///
	/// This writes the terminate codeword, the optimal Merkle tree layers, and an opening proof
	/// for each of the sampled test queries.
	pub fn finish_proof<Challenger_>(
		self,
		transcript: &mut ProverTranscript<Challenger_>,
	) -> Result<(), Error>
	where
		Challenger_: Challenger,
	{
		let (terminate_codeword, query_prover) = self.finalize()?;
		let mut advice = transcript.decommitment();
		advice.write_scalar_slice(&terminate_codeword);

		let layers = query_prover.vcs_optimal_layers()?;
		for layer in layers {
			advice.write_slice(&layer);
		}

		let params = query_prover.params;

		for _ in 0..params.n_test_queries() {
			let index = transcript.sample_bits(params.index_bits()) as usize;
			query_prover.prove_query(index, transcript.decommitment())?;
		}

		Ok(())
	}
}

/// A prover for the FRI query phase.
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
where
	F: BinaryField,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	params: &'a FRIParams<F, FA>,
	codeword: &'a [P],
	codeword_committed: &'a MerkleProver::Committed,
	round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
	merkle_prover: &'a MerkleProver,
}

impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
where
	F: TowerField + ExtensionField<FA>,
	FA: BinaryField,
	P: PackedField<Scalar = F>,
	MerkleProver: MerkleTreeProver<F, Scheme = VCS>,
	VCS: MerkleTreeScheme<F>,
{
	/// Number of oracles sent during the fold rounds.
	pub fn n_oracles(&self) -> usize {
		self.params.n_oracles()
	}

	/// Proves a FRI challenge query.
	///
	/// ## Arguments
	///
	/// * `index` - an index into the original codeword domain
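	///
	/// The same index is reused across oracles, shifted down by each oracle's log coset size:
	/// with fold arities `[4, 2]`, for example (illustrative numbers), a query opens coset
	/// `index` of the original interleaved codeword and then coset `index >> 2` of the first
	/// folded codeword.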
	#[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")]
	pub fn prove_query<B>(
		&self,
		mut index: usize,
		mut advice: TranscriptWriter<B>,
	) -> Result<(), Error>
	where
		B: BufMut,
	{
		let mut arities_and_optimal_layers_depths = self
			.params
			.fold_arities()
			.iter()
			.copied()
			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()));

		let Some((first_fold_arity, first_optimal_layer_depth)) =
			arities_and_optimal_layers_depths.next()
		else {
			// If there are no query proofs, that means that no oracles were sent during the FRI
			// fold rounds. In that case, the original interleaved codeword is decommitted and
			// the only checks that need to be performed are in `verify_last_oracle`.
			return Ok(());
		};

		prove_coset_opening(
			self.merkle_prover,
			self.codeword,
			self.codeword_committed,
			index,
			first_fold_arity,
			first_optimal_layer_depth,
			&mut advice,
		)?;

		for ((codeword, committed), (arity, optimal_layer_depth)) in
			izip!(self.round_committed.iter(), arities_and_optimal_layers_depths)
		{
			index >>= arity;
			prove_coset_opening(
				self.merkle_prover,
				codeword,
				committed,
				index,
				arity,
				optimal_layer_depth,
				&mut advice,
			)?;
		}

		Ok(())
	}

	/// Returns, for each committed oracle, the digests of the Merkle tree layer at its optimal
	/// opening depth.
	pub fn vcs_optimal_layers(&self) -> Result<Vec<Vec<VCS::Digest>>, Error> {
		let committed_iter = std::iter::once(self.codeword_committed)
			.chain(self.round_committed.iter().map(|(_, committed)| committed));

		committed_iter
			.zip(vcs_optimal_layers_depths_iter(self.params, self.merkle_prover.scheme()))
			.map(|(committed, optimal_layer_depth)| {
				self.merkle_prover
					.layer(committed, optimal_layer_depth)
					.map(|layer| layer.to_vec())
					.map_err(|err| Error::VectorCommit(Box::new(err)))
			})
			.collect::<Result<Vec<_>, _>>()
	}
}

/// Writes the values of one coset of the codeword to the transcript, followed by a Merkle
/// opening proof for that coset.
fn prove_coset_opening<F, P, MTProver, B>(
	merkle_prover: &MTProver,
	codeword: &[P],
	committed: &MTProver::Committed,
	coset_index: usize,
	log_coset_size: usize,
	optimal_layer_depth: usize,
	advice: &mut TranscriptWriter<B>,
) -> Result<(), Error>
where
	F: TowerField,
	P: PackedField<Scalar = F>,
	MTProver: MerkleTreeProver<F>,
	B: BufMut,
{
	let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
		.take(1 << log_coset_size);
	advice.write_scalar_iter(values);

	merkle_prover
		.prove_opening(committed, optimal_layer_depth, coset_index, advice)
		.map_err(|err| Error::VectorCommit(Box::new(err)))?;

	Ok(())
}