binius_core/piop/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{borrow::Cow, ops::Deref};
4
5use binius_field::{
6	BinaryField, Field, PackedExtension, PackedField, TowerField, packed::PackedSliceMut,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{
10	EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
11	MultilinearPoly,
12};
13use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
14use binius_ntt::AdditiveNTT;
15use binius_utils::{
16	SerializeBytes, bail,
17	checked_arithmetics::checked_log_2,
18	random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
19	sorting::is_sorted_ascending,
20};
21use either::Either;
22use itertools::{Itertools, chain};
23
24use super::{
25	error::Error,
26	verify::{PIOPSumcheckClaim, make_sumcheck_claim_descs},
27};
28use crate::{
29	fiat_shamir::{CanSample, Challenger},
30	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
31	oracle::OracleId,
32	piop::{
33		CommitMeta,
34		logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
35	},
36	protocols::{
37		fri::{self, FRIFolder, FRIParams, FoldRoundOutput},
38		sumcheck::{
39			self, immediate_switchover_heuristic,
40			prove::{
41				RegularSumcheckProver, SumcheckProver,
42				front_loaded::BatchProver as SumcheckBatchProver,
43			},
44		},
45	},
46	transcript::ProverTranscript,
47};
48
/// Reverses the order of the lowest `log_len` bits of `x`.
///
/// Bits of `x` at positions `>= log_len` are discarded, except when `log_len == 0`: there the
/// shift amount wraps to 0 under `wrapping_shr` and the full-width reversal is returned (the only
/// in-range index, 0, still maps to 0). `log_len` must not exceed `usize::BITS`.
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
	let shift = usize::BITS as usize - log_len;
	let fully_reversed = x.reverse_bits();
	fully_reversed.wrapping_shr(shift as u32)
}
54
55/// Reorders the scalars in a slice of packed field elements by reversing the bits of their indices.
56/// TODO: investigate if we can optimize this.
57fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
58	let log_len = checked_log_2(collection.len());
59	for i in 0..collection.len() {
60		let bit_reversed_index = reverse_bits(i, log_len);
61		if i < bit_reversed_index {
62			// Safety: `i` and `j` are guaranteed to be in bounds of the slice
63			unsafe {
64				let tmp = collection.get_unchecked(i);
65				collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
66				collection.set_unchecked(bit_reversed_index, tmp);
67			}
68		}
69	}
70}
71
72// ## Preconditions
73//
74// * all multilinears in `multilins` have at least log_extension_degree packed variables
75// * all multilinears in `multilins` have `packed_evals()` is Some
76// * multilinears are sorted in ascending order by number of packed variables
77// * `message_buffer` is initialized to all zeros
78// * `message_buffer` is larger than the total number of scalars in the multilinears
79fn merge_multilins<F, P, Data>(
80	multilins: &[MultilinearExtension<P, Data>],
81	message_buffer: &mut [P],
82) where
83	F: TowerField,
84	P: PackedField<Scalar = F>,
85	Data: Deref<Target = [P]>,
86{
87	let mut mle_iter = multilins.iter().rev();
88
89	// First copy all the polynomials where the number of elements is a multiple of the packing
90	// width.
91	let mut full_packed_mles = Vec::new(); // (evals, corresponding buffer where to copy)
92	let mut remaining_buffer = message_buffer;
93	for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
94		let evals = mle.evals();
95		let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
96		full_packed_mles.push((evals, chunk));
97		remaining_buffer = rest;
98	}
99	full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
100		chunk.copy_from_slice(evals);
101		reverse_index_bits(&mut PackedSliceMut::new(chunk));
102	});
103
104	// Now copy scalars from the remaining multilinears, which have too few elements to copy full
105	// packed elements.
106	let mut scalar_offset = 0;
107	let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
108	for mle in mle_iter {
109		let packed_eval = mle.evals()[0];
110		let len = 1 << mle.n_vars();
111		let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
112		for i in 0..len {
113			packed_chunk.set(i, packed_eval.get(i));
114		}
115		reverse_index_bits(&mut packed_chunk);
116
117		scalar_offset += len;
118	}
119}
120
121/// Commits a batch of multilinear polynomials.
122///
123/// The multilinears this function accepts as arguments may be defined over subfields of `F`. In
124/// this case, we commit to these multilinears by instead committing to their "packed"
125/// multilinears. These are the multilinear extensions of their packed coefficients over subcubes
126/// of the size of the extension degree.
127///
128/// ## Arguments
129///
130/// * `fri_params` - the FRI parameters for the commitment opening protocol
131/// * `merkle_prover` - the Merkle tree prover used in FRI
132/// * `multilins` - a batch of multilinear polynomials to commit. The multilinears provided may be
133///   defined over subfields of `F`. They must be in ascending order by the number of variables in
134///   the packed multilinear (ie. number of variables minus log extension degree).
135pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
136	fri_params: &FRIParams<F, FEncode>,
137	ntt: &NTT,
138	merkle_prover: &MTProver,
139	multilins: &[M],
140) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
141where
142	F: TowerField,
143	FEncode: BinaryField,
144	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
145	M: MultilinearPoly<P>,
146	NTT: AdditiveNTT<FEncode> + Sync,
147	MTScheme: MerkleTreeScheme<F>,
148	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
149{
150	let packed_multilins = multilins
151		.iter()
152		.enumerate()
153		.map(|(i, unpacked_committed)| {
154			packed_committed(OracleId::from_index(i), unpacked_committed)
155		})
156		.collect::<Result<Vec<_>, _>>()?;
157	if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
158		return Err(Error::CommittedsNotSorted);
159	}
160
161	let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
162		merge_multilins(&packed_multilins, message_buffer)
163	})?;
164
165	Ok(output)
166}
167
168/// Proves a batch of sumcheck claims that are products of committed polynomials from a committed
169/// batch and transparent polynomials.
170///
171/// The arguments corresponding to the committed multilinears must be the output of [`commit`].
172#[allow(clippy::too_many_arguments)]
173pub fn prove<
174	F,
175	FDomain,
176	FEncode,
177	P,
178	M,
179	NTT,
180	DomainFactory,
181	MTScheme,
182	MTProver,
183	Challenger_,
184	Backend,
185>(
186	fri_params: &FRIParams<F, FEncode>,
187	ntt: &NTT,
188	merkle_prover: &MTProver,
189	domain_factory: DomainFactory,
190	commit_meta: &CommitMeta,
191	committed: MTProver::Committed,
192	codeword: &[P],
193	committed_multilins: &[M],
194	transparent_multilins: &[M],
195	claims: &[PIOPSumcheckClaim<F>],
196	transcript: &mut ProverTranscript<Challenger_>,
197	backend: &Backend,
198) -> Result<(), Error>
199where
200	F: TowerField,
201	FDomain: Field,
202	FEncode: BinaryField,
203	P: PackedField<Scalar = F>
204		+ PackedExtension<F, PackedSubfield = P>
205		+ PackedExtension<FDomain>
206		+ PackedExtension<FEncode>,
207	M: MultilinearPoly<P> + Send + Sync,
208	NTT: AdditiveNTT<FEncode> + Sync,
209	DomainFactory: EvaluationDomainFactory<FDomain>,
210	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212	Challenger_: Challenger,
213	Backend: ComputationBackend,
214{
215	// Map of n_vars to sumcheck claim descriptions
216	let sumcheck_claim_descs = make_sumcheck_claim_descs(
217		commit_meta,
218		transparent_multilins.iter().map(|poly| poly.n_vars()),
219		claims,
220	)?;
221
222	// The committed multilinears provided by argument are committed *small field* multilinears.
223	// Create multilinears representing the packed polynomials here. Eventually, we would like to
224	// refactor the calling code so that the PIOP only handles *big field* multilinear witnesses.
225	let packed_committed_multilins = committed_multilins
226		.iter()
227		.enumerate()
228		.map(|(i, unpacked_committed)| {
229			packed_committed(OracleId::from_index(i), unpacked_committed)
230				.map(MLEDirectAdapter::from)
231		})
232		.collect::<Result<Vec<_>, _>>()?;
233
234	let non_empty_sumcheck_descs = sumcheck_claim_descs
235		.iter()
236		.enumerate()
237		// Keep sumcheck claims with >0 committed multilinears, even with 0 composite claims. This
238		// indicates unconstrained columns, but we still need the final evaluations from the
239		// sumcheck prover in order to derive the final FRI value.
240		.filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
241	let sumcheck_provers = non_empty_sumcheck_descs
242		.map(|(_n_vars, desc)| {
243			let multilins = chain!(
244				packed_committed_multilins[desc.committed_indices.clone()]
245					.iter()
246					.map(Either::Left),
247				transparent_multilins[desc.transparent_indices.clone()]
248					.iter()
249					.map(Either::Right),
250			)
251			.collect::<Vec<_>>();
252			RegularSumcheckProver::new(
253				EvaluationOrder::HighToLow,
254				multilins,
255				desc.composite_sums.iter().cloned(),
256				&domain_factory,
257				immediate_switchover_heuristic,
258				backend,
259			)
260		})
261		.collect::<Result<Vec<_>, _>>()?;
262
263	prove_interleaved_fri_sumcheck(
264		commit_meta.total_vars(),
265		fri_params,
266		ntt,
267		merkle_prover,
268		sumcheck_provers,
269		codeword,
270		&committed,
271		transcript,
272	)?;
273
274	Ok(())
275}
276
/// Runs the batched sumcheck prover and the FRI folder in lockstep for `n_rounds` rounds, writing
/// every prover message to `transcript`.
///
/// Each round:
/// 1. the batched sumcheck prover writes its round message to the transcript;
/// 2. one challenge is sampled from the transcript;
/// 3. that same challenge is passed to the sumcheck provers (`receive_challenge`) and to one FRI
///    fold round (`execute_fold_round`). If the fold round yields a commitment, the commitment is
///    written to the transcript.
///
/// After the last round, the sumcheck batch prover's final message is written and the FRI proof
/// is finished on the transcript.
///
/// The `tracing::debug_span!` spans only annotate each round and its sub-steps for
/// instrumentation (they carry `perfetto_category` fields). They do not touch the transcript.
///
/// ## Errors
///
/// Propagates errors from constructing or driving the FRI folder and the sumcheck batch prover.
#[allow(clippy::too_many_arguments)]
fn prove_interleaved_fri_sumcheck<F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
	n_rounds: usize,
	fri_params: &FRIParams<F, FEncode>,
	ntt: &NTT,
	merkle_prover: &MTProver,
	sumcheck_provers: Vec<impl SumcheckProver<F>>,
	codeword: &[P],
	committed: &MTProver::Committed,
	transcript: &mut ProverTranscript<Challenger_>,
) -> Result<(), Error>
where
	F: TowerField,
	FEncode: BinaryField,
	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
	NTT: AdditiveNTT<FEncode> + Sync,
	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
	Challenger_: Challenger,
{
	let mut fri_prover = FRIFolder::new(fri_params, ntt, merkle_prover, codeword, committed)?;

	let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;

	for round in 0..n_rounds {
		let _span =
			tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
				.entered();

		let bivariate_sumcheck_span = tracing::debug_span!(
			"[step] Bivariate Sumcheck",
			phase = "piop_compiler",
			round = round,
			perfetto_category = "phase.sub"
		)
		.entered();
		// Snapshot of the provers' dimensions for this round, attached to the two task spans below.
		let provers_dimensions_data =
			SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
		let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
			"[task] (PIOP Compiler) Calculate Coeffs",
			phase = "piop_compiler",
			round = round,
			perfetto_category = "task.main",
			dimensions_data = ?provers_dimensions_data,
		)
		.entered();
		// Write this round's sumcheck message to the transcript.
		sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
		drop(bivariate_sumcheck_calculate_coeffs_span);

		// The challenge is sampled only after this round's sumcheck message has been written. The
		// same challenge drives both the sumcheck fold and the FRI fold round below.
		let challenge = transcript.sample();
		let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
			"[task] (PIOP Compiler) Fold (All Rounds)",
			phase = "piop_compiler",
			round = round,
			perfetto_category = "task.main",
			dimensions_data = ?provers_dimensions_data,
		)
		.entered();
		sumcheck_batch_prover.receive_challenge(challenge)?;
		// Close the fold span and then the enclosing sumcheck step span before FRI folding starts.
		drop(bivariate_sumcheck_all_folds_span);
		drop(bivariate_sumcheck_span);

		let dimensions_data = FriFoldRoundsData::new(
			round,
			fri_params.log_batch_size(),
			fri_prover.current_codeword_len(),
		);
		let fri_fold_rounds_span = tracing::debug_span!(
			"[step] FRI Fold Rounds",
			phase = "piop_compiler",
			round = round,
			perfetto_category = "phase.sub",
			?dimensions_data,
		)
		.entered();
		// Only some fold rounds produce a new commitment. When one does, it goes into the
		// transcript.
		match fri_prover.execute_fold_round(challenge)? {
			FoldRoundOutput::NoCommitment => {}
			FoldRoundOutput::Commitment(round_commitment) => {
				transcript.message().write(&round_commitment);
			}
		}
		drop(fri_fold_rounds_span);
	}

	// Write the sumcheck batch prover's final message, then complete the FRI proof.
	sumcheck_batch_prover.finish(&mut transcript.message())?;
	fri_prover.finish_proof(transcript)?;
	Ok(())
}
365
366pub fn validate_sumcheck_witness<F, P, M>(
367	committed_multilins: &[M],
368	transparent_multilins: &[M],
369	claims: &[PIOPSumcheckClaim<F>],
370) -> Result<(), Error>
371where
372	F: TowerField,
373	P: PackedField<Scalar = F>,
374	M: MultilinearPoly<P> + Send + Sync,
375{
376	let packed_committed = committed_multilins
377		.iter()
378		.enumerate()
379		.map(|(i, unpacked_committed)| {
380			packed_committed(OracleId::from_index(i), unpacked_committed)
381		})
382		.collect::<Result<Vec<_>, _>>()?;
383
384	for (i, claim) in claims.iter().enumerate() {
385		let committed = &packed_committed[claim.committed];
386		if committed.n_vars() != claim.n_vars {
387			bail!(sumcheck::Error::NumberOfVariablesMismatch);
388		}
389
390		let transparent = &transparent_multilins[claim.transparent];
391		if transparent.n_vars() != claim.n_vars {
392			bail!(sumcheck::Error::NumberOfVariablesMismatch);
393		}
394
395		let sum = (0..(1 << claim.n_vars))
396			.into_par_iter()
397			.map(|j| {
398				let committed_eval = committed
399					.evaluate_on_hypercube(j)
400					.expect("j is less than 1 << n_vars; committed.n_vars is checked above");
401				let transparent_eval = transparent
402					.evaluate_on_hypercube(j)
403					.expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
404				committed_eval * transparent_eval
405			})
406			.sum::<F>();
407
408		if sum != claim.sum {
409			bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
410				composition_index: i,
411			});
412		}
413	}
414	Ok(())
415}
416
417/// Creates a multilinear extension of the packed evaluations of a small-field multilinear.
418///
419/// Given a multilinear $P \in T_{\iota}[X_0, \ldots, X_{n-1}]$, this creates the multilinear
420/// extension $\hat{P} \in T_{\tau}[X_0, \ldots, X_{n - \kappa - 1}]$. In the case where
421/// $n < \kappa$, which is when a polynomial is too full to have even a single packed evaluation,
422/// the polynomial is extended by padding with more variables, which corresponds to repeating its
423/// subcube evaluations.
424fn packed_committed<F, P, M>(
425	id: OracleId,
426	unpacked_committed: &M,
427) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
428where
429	F: TowerField,
430	P: PackedField<Scalar = F>,
431	M: MultilinearPoly<P>,
432{
433	let unpacked_n_vars = unpacked_committed.n_vars();
434	let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
435		let packed_eval = padded_packed_eval(unpacked_committed);
436		MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
437	} else {
438		let packed_evals = unpacked_committed
439			.packed_evals()
440			.ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
441
442		MultilinearExtension::new(
443			unpacked_n_vars - unpacked_committed.log_extension_degree(),
444			Cow::Borrowed(packed_evals),
445		)
446	}?;
447	Ok(packed_committed)
448}
449
450#[inline]
451fn padded_packed_eval<F, P, M>(multilin: &M) -> F
452where
453	F: TowerField,
454	P: PackedField<Scalar = F>,
455	M: MultilinearPoly<P>,
456{
457	let n_vars = multilin.n_vars();
458	let kappa = multilin.log_extension_degree();
459	assert!(n_vars < kappa);
460
461	(0..1 << kappa)
462		.map(|i| {
463			let iota = F::TOWER_LEVEL - kappa;
464			let scalar = <F as TowerField>::basis(iota, i)
465				.expect("i is in range 0..1 << log_extension_degree");
466			multilin
467				.evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
468				.expect("i is in range 0..1 << n_vars")
469		})
470		.sum()
471}
472
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::PackedBinaryField2x128b;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	type Packed = PackedBinaryField2x128b;

	/// Checks that `merge_multilins` lays multilinears out largest first, each in bit-reversed
	/// scalar order. This includes multilinears smaller than a single packed element.
	#[test]
	fn test_merge_multilins() {
		let mut rng = StdRng::seed_from_u64(0);

		let multilins = (0usize..8)
			.map(|n_vars| {
				let n_packed = 1 << n_vars.saturating_sub(Packed::LOG_WIDTH);
				let data = repeat_with(|| Packed::random(&mut rng))
					.take(n_packed)
					.collect::<Vec<_>>();
				MultilinearExtension::new(n_vars, data).unwrap()
			})
			.collect::<Vec<_>>();

		let total_scalars = multilins
			.iter()
			.map(|mle| 1usize << mle.n_vars())
			.sum::<usize>();
		let mut buffer = vec![Packed::zero(); total_scalars.div_ceil(Packed::WIDTH)];
		merge_multilins(&multilins, &mut buffer);

		let merged = PackedField::iter_slice(&buffer)
			.take(total_scalars)
			.collect_vec();
		let mut offset = 0;
		for mle in multilins.iter().rev() {
			let n_vars = mle.n_vars();
			for (i, expected) in PackedField::iter_slice(mle.evals())
				.take(1 << n_vars)
				.enumerate()
			{
				assert_eq!(merged[offset + reverse_bits(i, n_vars)], expected);
			}
			offset += 1 << n_vars;
		}
	}
}