binius_core/piop/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{borrow::Cow, ops::Deref};
4
5use binius_compute::{
6	ComputeData, ComputeLayer, ComputeMemory, FSlice, SizedSlice, alloc::ComputeAllocator,
7	cpu::CpuMemory,
8};
9use binius_field::{
10	BinaryField, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
11	packed::PackedSliceMut,
12};
13use binius_math::{MLEDirectAdapter, MultilinearExtension, MultilinearPoly};
14use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
15use binius_ntt::AdditiveNTT;
16use binius_utils::{
17	SerializeBytes, bail,
18	checked_arithmetics::checked_log_2,
19	random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
20	sorting::is_sorted_ascending,
21};
22use bytemuck::zeroed_vec;
23use itertools::{Itertools, chain};
24
25use super::{
26	error::Error,
27	verify::{PIOPSumcheckClaim, make_sumcheck_claim_descs},
28};
29use crate::{
30	fiat_shamir::{CanSample, Challenger},
31	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
32	oracle::OracleId,
33	piop::{
34		CommitMeta,
35		logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
36	},
37	protocols::{
38		fri::{self, FRIFolder, FRIParams, FoldRoundOutput},
39		sumcheck::{
40			self, SumcheckClaim,
41			prove::{SumcheckProver, front_loaded::BatchProver as SumcheckBatchProver},
42			v3::bivariate_product::BivariateSumcheckProver,
43		},
44	},
45	transcript::ProverTranscript,
46};
47
/// Returns `x` with its lowest `log_len` bits mirrored.
///
/// The full machine word is reversed first, then shifted right so that only the `log_len`
/// reversed bits remain. `x` is expected to be smaller than `1 << log_len`.
///
/// When `log_len` is 0 the shift amount equals the word width, which `wrapping_shr` reduces to a
/// shift by 0; for the only in-range input (`x == 0`) the result is still 0.
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
	let shift = usize::BITS as usize - log_len;
	let mirrored = x.reverse_bits();
	mirrored.wrapping_shr(shift as u32)
}
53
54/// Reorders the scalars in a slice of packed field elements by reversing the bits of their indices.
55/// TODO: investigate if we can optimize this.
56fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
57	let log_len = checked_log_2(collection.len());
58	for i in 0..collection.len() {
59		let bit_reversed_index = reverse_bits(i, log_len);
60		if i < bit_reversed_index {
61			// Safety: `i` and `j` are guaranteed to be in bounds of the slice
62			unsafe {
63				let tmp = collection.get_unchecked(i);
64				collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
65				collection.set_unchecked(bit_reversed_index, tmp);
66			}
67		}
68	}
69}
70
71// ## Preconditions
72//
73// * all multilinears in `multilins` have at least log_extension_degree packed variables
74// * all multilinears in `multilins` have `packed_evals()` is Some
75// * multilinears are sorted in ascending order by number of packed variables
76// * `message_buffer` is initialized to all zeros
77// * `message_buffer` is larger than the total number of scalars in the multilinears
78fn merge_multilins<F, P, Data>(
79	multilins: &[MultilinearExtension<P, Data>],
80	message_buffer: &mut [P],
81) where
82	F: TowerField,
83	P: PackedField<Scalar = F>,
84	Data: Deref<Target = [P]>,
85{
86	let mut mle_iter = multilins.iter().rev();
87
88	// First copy all the polynomials where the number of elements is a multiple of the packing
89	// width.
90	let mut full_packed_mles = Vec::new(); // (evals, corresponding buffer where to copy)
91	let mut remaining_buffer = message_buffer;
92	for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
93		let evals = mle.evals();
94		let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
95		full_packed_mles.push((evals, chunk));
96		remaining_buffer = rest;
97	}
98	full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
99		chunk.copy_from_slice(evals);
100		reverse_index_bits(&mut PackedSliceMut::new(chunk));
101	});
102
103	// Now copy scalars from the remaining multilinears, which have too few elements to copy full
104	// packed elements.
105	let mut scalar_offset = 0;
106	let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
107	for mle in mle_iter {
108		let packed_eval = mle.evals()[0];
109		let len = 1 << mle.n_vars();
110		let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
111		for i in 0..len {
112			packed_chunk.set(i, packed_eval.get(i));
113		}
114		reverse_index_bits(&mut packed_chunk);
115
116		scalar_offset += len;
117	}
118}
119
120/// Commits a batch of multilinear polynomials.
121///
122/// The multilinears this function accepts as arguments may be defined over subfields of `F`. In
123/// this case, we commit to these multilinears by instead committing to their "packed"
124/// multilinears. These are the multilinear extensions of their packed coefficients over subcubes
125/// of the size of the extension degree.
126///
127/// ## Arguments
128///
129/// * `fri_params` - the FRI parameters for the commitment opening protocol
130/// * `merkle_prover` - the Merkle tree prover used in FRI
131/// * `multilins` - a batch of multilinear polynomials to commit. The multilinears provided may be
132///   defined over subfields of `F`. They must be in ascending order by the number of variables in
133///   the packed multilinear (ie. number of variables minus log extension degree).
134pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
135	fri_params: &FRIParams<F, FEncode>,
136	ntt: &NTT,
137	merkle_prover: &MTProver,
138	multilins: &[M],
139) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
140where
141	F: TowerField,
142	FEncode: BinaryField,
143	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
144	M: MultilinearPoly<P>,
145	NTT: AdditiveNTT<FEncode> + Sync,
146	MTScheme: MerkleTreeScheme<F>,
147	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
148{
149	let packed_multilins = multilins
150		.iter()
151		.enumerate()
152		.map(|(i, unpacked_committed)| {
153			packed_committed(OracleId::from_index(i), unpacked_committed)
154		})
155		.collect::<Result<Vec<_>, _>>()?;
156	if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
157		return Err(Error::CommittedsNotSorted);
158	}
159
160	let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
161		merge_multilins(&packed_multilins, message_buffer)
162	})?;
163
164	Ok(output)
165}
166
167/// Proves a batch of sumcheck claims that are products of committed polynomials from a committed
168/// batch and transparent polynomials.
169///
170/// The arguments corresponding to the committed multilinears must be the output of [`commit`].
171#[allow(clippy::too_many_arguments)]
172pub fn prove<
173	Hal,
174	F,
175	FEncode,
176	P,
177	M,
178	NTT,
179	MTScheme,
180	MTProver,
181	Challenger_,
182	HostComputeAllocatorType,
183	DeviceComputeAllocatorType,
184>(
185	compute_data: &ComputeData<F, Hal, HostComputeAllocatorType, DeviceComputeAllocatorType>,
186	fri_params: &FRIParams<F, FEncode>,
187	ntt: &NTT,
188	merkle_prover: &MTProver,
189	commit_meta: &CommitMeta,
190	committed: MTProver::Committed,
191	codeword: &[P],
192	committed_multilins: &[M],
193	transparent_multilins: Vec<FSlice<'_, F, Hal>>,
194	claims: &[PIOPSumcheckClaim<F>],
195	transcript: &mut ProverTranscript<Challenger_>,
196) -> Result<(), Error>
197where
198	F: TowerField,
199	FEncode: BinaryField,
200	P: PackedField<Scalar = F>
201		+ PackedExtension<F, PackedSubfield = P>
202		+ PackedExtension<FEncode>
203		+ PackedFieldIndexable<Scalar = F>,
204	M: MultilinearPoly<P> + Send + Sync,
205	NTT: AdditiveNTT<FEncode> + Sync,
206	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
207	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
208	Challenger_: Challenger,
209	Hal: ComputeLayer<F>,
210	HostComputeAllocatorType: ComputeAllocator<F, CpuMemory>,
211	DeviceComputeAllocatorType: ComputeAllocator<F, Hal::DevMem>,
212{
213	let host_alloc = &compute_data.host_alloc;
214	let dev_alloc = &compute_data.dev_alloc;
215	let hal = compute_data.hal;
216
217	let sumcheck_claim_descs = make_sumcheck_claim_descs(
218		commit_meta,
219		transparent_multilins
220			.iter()
221			.map(|poly| checked_log_2(poly.len())),
222		claims,
223	)?;
224
225	// The committed multilinears provided by argument are committed *small field* multilinears.
226	// Create multilinears representing the packed polynomials here. Eventually, we would like to
227	// refactor the calling code so that the PIOP only handles *big field* multilinear witnesses.
228	let packed_committed_multilins = committed_multilins
229		.iter()
230		.enumerate()
231		.map(|(i, unpacked_committed)| {
232			packed_committed(OracleId::from_index(i), unpacked_committed)
233				.map(MLEDirectAdapter::from)
234		})
235		.collect::<Result<Vec<_>, _>>()?;
236
237	let copy_span = tracing::debug_span!(
238		"[task] Copy polynomials to device memory",
239		phase = "piop_compiler",
240		perfetto_category = "phase.sub",
241	)
242	.entered();
243	let packed_committed_fslices = packed_committed_multilins
244		.iter()
245		.map(|packed_committed_multilin| {
246			let hypercube_evals = packed_committed_multilin
247				.packed_evals()
248				.expect("Prover should always populate witnesses");
249			let unpacked_hypercube_evals = P::unpack_scalars(hypercube_evals);
250			let mut allocated_mem = dev_alloc.alloc(1 << packed_committed_multilin.n_vars())?;
251
252			hal.copy_h2d(
253				&unpacked_hypercube_evals[0..1 << packed_committed_multilin.n_vars()],
254				&mut allocated_mem,
255			)?;
256			Ok(Hal::DevMem::to_const(allocated_mem))
257		})
258		.collect::<Result<Vec<_>, Error>>()?;
259
260	drop(copy_span);
261
262	let non_empty_sumcheck_descs = sumcheck_claim_descs
263		.iter()
264		.enumerate()
265		// Keep sumcheck claims with >0 committed multilinears, even with 0 composite claims. This
266		// indicates unconstrained columns, but we still need the final evaluations from the
267		// sumcheck prover in order to derive the final FRI value.
268		.filter(|(_, desc)| !desc.committed_indices.is_empty());
269
270	let mut sumcheck_provers = vec![];
271
272	for (n_vars, desc) in non_empty_sumcheck_descs {
273		let multilins = chain!(
274			packed_committed_fslices[desc.committed_indices.clone()]
275				.iter()
276				.map(|fslice| Hal::DevMem::narrow(fslice)),
277			transparent_multilins[desc.transparent_indices.clone()]
278				.iter()
279				.map(|fslice| Hal::DevMem::narrow(fslice))
280		)
281		.collect::<Vec<_>>();
282
283		let claim = SumcheckClaim::new(n_vars, multilins.len(), desc.composite_sums.clone())?;
284
285		sumcheck_provers
286			.push(BivariateSumcheckProver::new(hal, dev_alloc, host_alloc, &claim, multilins)?);
287	}
288
289	prove_interleaved_fri_sumcheck(
290		hal,
291		dev_alloc,
292		commit_meta.total_vars(),
293		fri_params,
294		ntt,
295		merkle_prover,
296		sumcheck_provers,
297		codeword,
298		&committed,
299		transcript,
300	)?;
301
302	Ok(())
303}
304
305#[allow(clippy::too_many_arguments)]
306fn prove_interleaved_fri_sumcheck<Hal, F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
307	hal: &Hal,
308	dev_alloc: &impl ComputeAllocator<F, Hal::DevMem>,
309	n_rounds: usize,
310	fri_params: &FRIParams<F, FEncode>,
311	ntt: &NTT,
312	merkle_prover: &MTProver,
313	sumcheck_provers: Vec<impl SumcheckProver<F>>,
314	codeword: &[P],
315	committed: &MTProver::Committed,
316	transcript: &mut ProverTranscript<Challenger_>,
317) -> Result<(), Error>
318where
319	Hal: ComputeLayer<F>,
320	F: TowerField,
321	FEncode: BinaryField,
322	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
323	NTT: AdditiveNTT<FEncode> + Sync,
324	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
325	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
326	Challenger_: Challenger,
327{
328	let mut fri_prover = FRIFolder::new(hal, fri_params, ntt, merkle_prover, codeword, committed)?;
329
330	let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
331
332	for round in 0..n_rounds {
333		let _span =
334			tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
335				.entered();
336
337		let bivariate_sumcheck_span = tracing::debug_span!(
338			"[step] Bivariate Sumcheck",
339			phase = "piop_compiler",
340			round = round,
341			perfetto_category = "phase.sub"
342		)
343		.entered();
344		let provers_dimensions_data =
345			SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
346		let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
347			"[task] (PIOP Compiler) Calculate Coeffs",
348			phase = "piop_compiler",
349			round = round,
350			perfetto_category = "task.main",
351			dimensions_data = ?provers_dimensions_data,
352		)
353		.entered();
354
355		sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
356		drop(bivariate_sumcheck_calculate_coeffs_span);
357
358		let challenge = transcript.sample();
359		let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
360			"[task] (PIOP Compiler) Fold (All Rounds)",
361			phase = "piop_compiler",
362			round = round,
363			perfetto_category = "task.main",
364			dimensions_data = ?provers_dimensions_data,
365		)
366		.entered();
367		sumcheck_batch_prover.receive_challenge(challenge)?;
368		drop(bivariate_sumcheck_all_folds_span);
369		drop(bivariate_sumcheck_span);
370
371		let dimensions_data = FriFoldRoundsData::new(
372			round,
373			fri_params.log_batch_size(),
374			fri_prover.current_codeword_len(),
375		);
376		let fri_fold_rounds_span = tracing::debug_span!(
377			"[step] FRI Fold Rounds",
378			phase = "piop_compiler",
379			round = round,
380			perfetto_category = "phase.sub",
381			?dimensions_data,
382		)
383		.entered();
384		match fri_prover.execute_fold_round(dev_alloc, challenge)? {
385			FoldRoundOutput::NoCommitment => {}
386			FoldRoundOutput::Commitment(round_commitment) => {
387				transcript.message().write(&round_commitment);
388			}
389		}
390		drop(fri_fold_rounds_span);
391	}
392
393	sumcheck_batch_prover.finish(&mut transcript.message())?;
394	fri_prover.finish_proof(transcript)?;
395	Ok(())
396}
397
398pub fn validate_sumcheck_witness<'a, F, P, M, Hal: ComputeLayer<F>>(
399	committed_multilins: &[M],
400	transparent_multilins: &[FSlice<'a, F, Hal>],
401	claims: &[PIOPSumcheckClaim<F>],
402	hal: &Hal,
403) -> Result<(), Error>
404where
405	F: TowerField,
406	P: PackedField<Scalar = F>,
407	M: MultilinearPoly<P> + Send + Sync,
408{
409	let packed_committed = committed_multilins
410		.iter()
411		.enumerate()
412		.map(|(i, unpacked_committed)| {
413			packed_committed(OracleId::from_index(i), unpacked_committed)
414		})
415		.collect::<Result<Vec<_>, _>>()?;
416
417	for (i, claim) in claims.iter().enumerate() {
418		let committed = &packed_committed[claim.committed];
419		if committed.n_vars() != claim.n_vars {
420			bail!(sumcheck::Error::NumberOfVariablesMismatch);
421		}
422
423		let transparent = &transparent_multilins[claim.transparent];
424		if transparent.len() != 1 << claim.n_vars {
425			bail!(sumcheck::Error::NumberOfVariablesMismatch);
426		}
427
428		let mut transparent_evals = zeroed_vec(transparent.len());
429
430		hal.copy_d2h(*transparent, &mut transparent_evals)?;
431
432		let sum = (0..(1 << claim.n_vars))
433			.into_par_iter()
434			.map(|j| {
435				let committed_eval = committed
436					.evaluate_on_hypercube(j)
437					.expect("j is less than 1 << n_vars; committed.n_vars is checked above");
438				let transparent_eval = transparent_evals
439					.get(j)
440					.expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
441				committed_eval * transparent_eval
442			})
443			.sum::<F>();
444
445		if sum != claim.sum {
446			bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
447				composition_index: i,
448			});
449		}
450	}
451	Ok(())
452}
453
454/// Creates a multilinear extension of the packed evaluations of a small-field multilinear.
455///
456/// Given a multilinear $P \in T_{\iota}[X_0, \ldots, X_{n-1}]$, this creates the multilinear
457/// extension $\hat{P} \in T_{\tau}[X_0, \ldots, X_{n - \kappa - 1}]$. In the case where
458/// $n < \kappa$, which is when a polynomial is too full to have even a single packed evaluation,
459/// the polynomial is extended by padding with more variables, which corresponds to repeating its
460/// subcube evaluations.
461fn packed_committed<F, P, M>(
462	id: OracleId,
463	unpacked_committed: &M,
464) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
465where
466	F: TowerField,
467	P: PackedField<Scalar = F>,
468	M: MultilinearPoly<P>,
469{
470	let unpacked_n_vars = unpacked_committed.n_vars();
471	let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
472		let packed_eval = padded_packed_eval(unpacked_committed);
473		MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
474	} else {
475		let packed_evals = unpacked_committed
476			.packed_evals()
477			.ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
478
479		MultilinearExtension::new(
480			unpacked_n_vars - unpacked_committed.log_extension_degree(),
481			Cow::Borrowed(packed_evals),
482		)
483	}?;
484	Ok(packed_committed)
485}
486
487#[inline]
488fn padded_packed_eval<F, P, M>(multilin: &M) -> F
489where
490	F: TowerField,
491	P: PackedField<Scalar = F>,
492	M: MultilinearPoly<P>,
493{
494	let n_vars = multilin.n_vars();
495	let kappa = multilin.log_extension_degree();
496	assert!(n_vars < kappa);
497
498	(0..1 << kappa)
499		.map(|i| {
500			let iota = F::TOWER_LEVEL - kappa;
501			let scalar = <F as TowerField>::basis(iota, i)
502				.expect("i is in range 0..1 << log_extension_degree");
503			multilin
504				.evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
505				.expect("i is in range 0..1 << n_vars")
506		})
507		.sum()
508}
509
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::PackedBinaryField2x128b;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	type Packed = PackedBinaryField2x128b;

	/// Merges random multilinears with 0 through 7 variables and checks that each one appears in
	/// the buffer, largest first, with its scalars at bit-reversed positions.
	#[test]
	fn test_merge_multilins() {
		let mut rng = StdRng::seed_from_u64(0);

		// One polynomial per variable count; those smaller than the packing width still occupy
		// a single packed element.
		let multilins = (0usize..8)
			.map(|n_vars| {
				let n_packed = 1 << n_vars.saturating_sub(Packed::LOG_WIDTH);
				let evals = repeat_with(|| Packed::random(&mut rng))
					.take(n_packed)
					.collect::<Vec<_>>();

				MultilinearExtension::new(n_vars, evals).unwrap()
			})
			.collect::<Vec<_>>();

		let total_scalars = (0..8).map(|i| 1usize << i).sum::<usize>();
		let mut buffer = vec![Packed::zero(); total_scalars.div_ceil(Packed::WIDTH)];
		merge_multilins(&multilins, &mut buffer);

		let merged = PackedField::iter_slice(&buffer)
			.take(total_scalars)
			.collect::<Vec<_>>();

		let mut offset = 0;
		for multilin in multilins.iter().rev() {
			let n_vars = multilin.n_vars();
			let segment = &merged[offset..];
			let expected = PackedField::iter_slice(multilin.evals()).take(1 << n_vars);
			for (i, value) in expected.enumerate() {
				assert_eq!(segment[reverse_bits(i, n_vars)], value);
			}
			offset += 1 << n_vars;
		}
	}
}