binius_core/piop/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{borrow::Cow, ops::Deref};
4
5use binius_field::{
6	packed::PackedSliceMut, BinaryField, Field, PackedExtension, PackedField, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{
10	EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
11	MultilinearPoly,
12};
13use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
14use binius_ntt::AdditiveNTT;
15use binius_utils::{
16	bail,
17	checked_arithmetics::checked_log_2,
18	random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
19	sorting::is_sorted_ascending,
20	SerializeBytes,
21};
22use either::Either;
23use itertools::{chain, Itertools};
24
25use super::{
26	error::Error,
27	verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
28};
29use crate::{
30	fiat_shamir::{CanSample, Challenger},
31	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
32	piop::{
33		logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
34		CommitMeta,
35	},
36	protocols::{
37		fri,
38		fri::{FRIFolder, FRIParams, FoldRoundOutput},
39		sumcheck,
40		sumcheck::{
41			immediate_switchover_heuristic,
42			prove::{
43				front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
44				SumcheckProver,
45			},
46		},
47	},
48	transcript::ProverTranscript,
49};
50
/// Reverses the low `log_len` bits of `x`; any higher bits of `x` are discarded.
///
/// For `log_len == 0` the result is always 0. The previous `wrapping_shr` form masked the shift
/// amount of `usize::BITS` down to 0 and returned the full-width reversal in that case.
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
	debug_assert!(log_len <= usize::BITS as usize);
	// Reversing moves bit `k` to bit `BITS - 1 - k`; shifting right by `BITS - log_len` brings the
	// reversed low `log_len` bits back down. A shift of the full width (log_len == 0) must yield 0,
	// which `checked_shr` reports as `None`.
	let shift = usize::BITS - log_len as u32;
	x.reverse_bits().checked_shr(shift).unwrap_or(0)
}
56
57/// Reorders the scalars in a slice of packed field elements by reversing the bits of their indices.
58/// TODO: investigate if we can optimize this.
59fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
60	let log_len = checked_log_2(collection.len());
61	for i in 0..collection.len() {
62		let bit_reversed_index = reverse_bits(i, log_len);
63		if i < bit_reversed_index {
64			// Safety: `i` and `j` are guaranteed to be in bounds of the slice
65			unsafe {
66				let tmp = collection.get_unchecked(i);
67				collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
68				collection.set_unchecked(bit_reversed_index, tmp);
69			}
70		}
71	}
72}
73
74// ## Preconditions
75//
76// * all multilinears in `multilins` have at least log_extension_degree packed variables
77// * all multilinears in `multilins` have `packed_evals()` is Some
78// * multilinears are sorted in ascending order by number of packed variables
79// * `message_buffer` is initialized to all zeros
80// * `message_buffer` is larger than the total number of scalars in the multilinears
81fn merge_multilins<F, P, Data>(
82	multilins: &[MultilinearExtension<P, Data>],
83	message_buffer: &mut [P],
84) where
85	F: TowerField,
86	P: PackedField<Scalar = F>,
87	Data: Deref<Target = [P]>,
88{
89	let mut mle_iter = multilins.iter().rev();
90
91	// First copy all the polynomials where the number of elements is a multiple of the packing
92	// width.
93	let mut full_packed_mles = Vec::new(); // (evals, corresponding buffer where to copy)
94	let mut remaining_buffer = message_buffer;
95	for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
96		let evals = mle.evals();
97		let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
98		full_packed_mles.push((evals, chunk));
99		remaining_buffer = rest;
100	}
101	full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
102		chunk.copy_from_slice(evals);
103		reverse_index_bits(&mut PackedSliceMut::new(chunk));
104	});
105
106	// Now copy scalars from the remaining multilinears, which have too few elements to copy full
107	// packed elements.
108	let mut scalar_offset = 0;
109	let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
110	for mle in mle_iter {
111		let packed_eval = mle.evals()[0];
112		let len = 1 << mle.n_vars();
113		let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
114		for i in 0..len {
115			packed_chunk.set(i, packed_eval.get(i));
116		}
117		reverse_index_bits(&mut packed_chunk);
118
119		scalar_offset += len;
120	}
121}
122
123/// Commits a batch of multilinear polynomials.
124///
125/// The multilinears this function accepts as arguments may be defined over subfields of `F`. In
126/// this case, we commit to these multilinears by instead committing to their "packed"
127/// multilinears. These are the multilinear extensions of their packed coefficients over subcubes
128/// of the size of the extension degree.
129///
130/// ## Arguments
131///
132/// * `fri_params` - the FRI parameters for the commitment opening protocol
133/// * `merkle_prover` - the Merkle tree prover used in FRI
134/// * `multilins` - a batch of multilinear polynomials to commit. The multilinears provided may be
135///   defined over subfields of `F`. They must be in ascending order by the number of variables
136///   in the packed multilinear (ie. number of variables minus log extension degree).
137pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
138	fri_params: &FRIParams<F, FEncode>,
139	ntt: &NTT,
140	merkle_prover: &MTProver,
141	multilins: &[M],
142) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
143where
144	F: TowerField,
145	FEncode: BinaryField,
146	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
147	M: MultilinearPoly<P>,
148	NTT: AdditiveNTT<FEncode> + Sync,
149	MTScheme: MerkleTreeScheme<F>,
150	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
151{
152	let packed_multilins = multilins
153		.iter()
154		.enumerate()
155		.map(|(i, unpacked_committed)| packed_committed(i, unpacked_committed))
156		.collect::<Result<Vec<_>, _>>()?;
157	if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
158		return Err(Error::CommittedsNotSorted);
159	}
160
161	let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
162		merge_multilins(&packed_multilins, message_buffer)
163	})?;
164
165	Ok(output)
166}
167
168/// Proves a batch of sumcheck claims that are products of committed polynomials from a committed
169/// batch and transparent polynomials.
170///
171/// The arguments corresponding to the committed multilinears must be the output of [`commit`].
172#[allow(clippy::too_many_arguments)]
173pub fn prove<
174	F,
175	FDomain,
176	FEncode,
177	P,
178	M,
179	NTT,
180	DomainFactory,
181	MTScheme,
182	MTProver,
183	Challenger_,
184	Backend,
185>(
186	fri_params: &FRIParams<F, FEncode>,
187	ntt: &NTT,
188	merkle_prover: &MTProver,
189	domain_factory: DomainFactory,
190	commit_meta: &CommitMeta,
191	committed: MTProver::Committed,
192	codeword: &[P],
193	committed_multilins: &[M],
194	transparent_multilins: &[M],
195	claims: &[PIOPSumcheckClaim<F>],
196	transcript: &mut ProverTranscript<Challenger_>,
197	backend: &Backend,
198) -> Result<(), Error>
199where
200	F: TowerField,
201	FDomain: Field,
202	FEncode: BinaryField,
203	P: PackedField<Scalar = F>
204		+ PackedExtension<F, PackedSubfield = P>
205		+ PackedExtension<FDomain>
206		+ PackedExtension<FEncode>,
207	M: MultilinearPoly<P> + Send + Sync,
208	NTT: AdditiveNTT<FEncode> + Sync,
209	DomainFactory: EvaluationDomainFactory<FDomain>,
210	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212	Challenger_: Challenger,
213	Backend: ComputationBackend,
214{
215	// Map of n_vars to sumcheck claim descriptions
216	let sumcheck_claim_descs = make_sumcheck_claim_descs(
217		commit_meta,
218		transparent_multilins.iter().map(|poly| poly.n_vars()),
219		claims,
220	)?;
221
222	// The committed multilinears provided by argument are committed *small field* multilinears.
223	// Create multilinears representing the packed polynomials here. Eventually, we would like to
224	// refactor the calling code so that the PIOP only handles *big field* multilinear witnesses.
225	let packed_committed_multilins = committed_multilins
226		.iter()
227		.enumerate()
228		.map(|(i, unpacked_committed)| {
229			packed_committed(i, unpacked_committed).map(MLEDirectAdapter::from)
230		})
231		.collect::<Result<Vec<_>, _>>()?;
232
233	let non_empty_sumcheck_descs = sumcheck_claim_descs
234		.iter()
235		.enumerate()
236		// Keep sumcheck claims with >0 committed multilinears, even with 0 composite claims. This
237		// indicates unconstrained columns, but we still need the final evaluations from the
238		// sumcheck prover in order to derive the final FRI value.
239		.filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
240	let sumcheck_provers = non_empty_sumcheck_descs
241		.map(|(_n_vars, desc)| {
242			let multilins = chain!(
243				packed_committed_multilins[desc.committed_indices.clone()]
244					.iter()
245					.map(Either::Left),
246				transparent_multilins[desc.transparent_indices.clone()]
247					.iter()
248					.map(Either::Right),
249			)
250			.collect::<Vec<_>>();
251			RegularSumcheckProver::new(
252				EvaluationOrder::HighToLow,
253				multilins,
254				desc.composite_sums.iter().cloned(),
255				&domain_factory,
256				immediate_switchover_heuristic,
257				backend,
258			)
259		})
260		.collect::<Result<Vec<_>, _>>()?;
261
262	prove_interleaved_fri_sumcheck(
263		commit_meta.total_vars(),
264		fri_params,
265		ntt,
266		merkle_prover,
267		sumcheck_provers,
268		codeword,
269		&committed,
270		transcript,
271	)?;
272
273	Ok(())
274}
275
276#[allow(clippy::too_many_arguments)]
277fn prove_interleaved_fri_sumcheck<F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
278	n_rounds: usize,
279	fri_params: &FRIParams<F, FEncode>,
280	ntt: &NTT,
281	merkle_prover: &MTProver,
282	sumcheck_provers: Vec<impl SumcheckProver<F>>,
283	codeword: &[P],
284	committed: &MTProver::Committed,
285	transcript: &mut ProverTranscript<Challenger_>,
286) -> Result<(), Error>
287where
288	F: TowerField,
289	FEncode: BinaryField,
290	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
291	NTT: AdditiveNTT<FEncode> + Sync,
292	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
293	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
294	Challenger_: Challenger,
295{
296	let mut fri_prover = FRIFolder::new(fri_params, ntt, merkle_prover, codeword, committed)?;
297
298	let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
299
300	for round in 0..n_rounds {
301		let _span =
302			tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
303				.entered();
304
305		let bivariate_sumcheck_span = tracing::debug_span!(
306			"[step] Bivariate Sumcheck",
307			phase = "piop_compiler",
308			round = round,
309			perfetto_category = "phase.sub"
310		)
311		.entered();
312		let provers_dimensions_data =
313			SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
314		let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
315			"[task] (PIOP Compiler) Calculate Coeffs",
316			phase = "piop_compiler",
317			round = round,
318			perfetto_category = "task.main",
319			dimensions_data = ?provers_dimensions_data,
320		)
321		.entered();
322		sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
323		drop(bivariate_sumcheck_calculate_coeffs_span);
324
325		let challenge = transcript.sample();
326		let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
327			"[task] (PIOP Compiler) Fold (All Rounds)",
328			phase = "piop_compiler",
329			round = round,
330			perfetto_category = "task.main",
331			dimensions_data = ?provers_dimensions_data,
332		)
333		.entered();
334		sumcheck_batch_prover.receive_challenge(challenge)?;
335		drop(bivariate_sumcheck_all_folds_span);
336		drop(bivariate_sumcheck_span);
337
338		let dimensions_data = FriFoldRoundsData::new(
339			round,
340			fri_params.log_batch_size(),
341			fri_prover.current_codeword_len(),
342		);
343		let fri_fold_rounds_span = tracing::debug_span!(
344			"[step] FRI Fold Rounds",
345			phase = "piop_compiler",
346			round = round,
347			perfetto_category = "phase.sub",
348			dimensions_data = ?dimensions_data,
349		)
350		.entered();
351		match fri_prover.execute_fold_round(challenge)? {
352			FoldRoundOutput::NoCommitment => {}
353			FoldRoundOutput::Commitment(round_commitment) => {
354				transcript.message().write(&round_commitment);
355			}
356		}
357		drop(fri_fold_rounds_span);
358	}
359
360	sumcheck_batch_prover.finish(&mut transcript.message())?;
361	fri_prover.finish_proof(transcript)?;
362	Ok(())
363}
364
365pub fn validate_sumcheck_witness<F, P, M>(
366	committed_multilins: &[M],
367	transparent_multilins: &[M],
368	claims: &[PIOPSumcheckClaim<F>],
369) -> Result<(), Error>
370where
371	F: TowerField,
372	P: PackedField<Scalar = F>,
373	M: MultilinearPoly<P> + Send + Sync,
374{
375	let packed_committed = committed_multilins
376		.iter()
377		.enumerate()
378		.map(|(i, unpacked_committed)| packed_committed(i, unpacked_committed))
379		.collect::<Result<Vec<_>, _>>()?;
380
381	for (i, claim) in claims.iter().enumerate() {
382		let committed = &packed_committed[claim.committed];
383		if committed.n_vars() != claim.n_vars {
384			bail!(sumcheck::Error::NumberOfVariablesMismatch);
385		}
386
387		let transparent = &transparent_multilins[claim.transparent];
388		if transparent.n_vars() != claim.n_vars {
389			bail!(sumcheck::Error::NumberOfVariablesMismatch);
390		}
391
392		let sum = (0..(1 << claim.n_vars))
393			.into_par_iter()
394			.map(|j| {
395				let committed_eval = committed
396					.evaluate_on_hypercube(j)
397					.expect("j is less than 1 << n_vars; committed.n_vars is checked above");
398				let transparent_eval = transparent
399					.evaluate_on_hypercube(j)
400					.expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
401				committed_eval * transparent_eval
402			})
403			.sum::<F>();
404
405		if sum != claim.sum {
406			bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
407				composition_index: i,
408			});
409		}
410	}
411	Ok(())
412}
413
414/// Creates a multilinear extension of the packed evalations of a small-field multilinear.
415///
416/// Given a multilinear $P \in T_{\iota}[X_0, \ldots, X_{n-1}]$, this creates the multilinear
417/// extension $\hat{P} \in T_{\tau}[X_0, \ldots, X_{n - \kappa - 1}]$. In the case where
418/// $n < \kappa$, which is when a polynomial is too full to have even a single packed evaluation,
419/// the polynomial is extended by padding with more variables, which corresponds to repeating its
420/// subcube evaluations.
421fn packed_committed<F, P, M>(
422	id: usize,
423	unpacked_committed: &M,
424) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
425where
426	F: TowerField,
427	P: PackedField<Scalar = F>,
428	M: MultilinearPoly<P>,
429{
430	let unpacked_n_vars = unpacked_committed.n_vars();
431	let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
432		let packed_eval = padded_packed_eval(unpacked_committed);
433		MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
434	} else {
435		let packed_evals = unpacked_committed
436			.packed_evals()
437			.ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
438
439		MultilinearExtension::new(
440			unpacked_n_vars - unpacked_committed.log_extension_degree(),
441			Cow::Borrowed(packed_evals),
442		)
443	}?;
444	Ok(packed_committed)
445}
446
447#[inline]
448fn padded_packed_eval<F, P, M>(multilin: &M) -> F
449where
450	F: TowerField,
451	P: PackedField<Scalar = F>,
452	M: MultilinearPoly<P>,
453{
454	let n_vars = multilin.n_vars();
455	let kappa = multilin.log_extension_degree();
456	assert!(n_vars < kappa);
457
458	(0..1 << kappa)
459		.map(|i| {
460			let iota = F::TOWER_LEVEL - kappa;
461			let scalar = <F as TowerField>::basis(iota, i)
462				.expect("i is in range 0..1 << log_extension_degree");
463			multilin
464				.evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
465				.expect("i is in range 0..1 << n_vars")
466		})
467		.sum()
468}
469
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::PackedBinaryField2x128b;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	type Packed = PackedBinaryField2x128b;

	/// `merge_multilins` should place the multilinears largest-first, one after another. Within
	/// each multilinear's region, its scalars should sit at bit-reversed indices.
	#[test]
	fn test_merge_multilins() {
		let mut rng = StdRng::seed_from_u64(0);

		// One multilinear per variable count in 0..8. This exercises both the whole-packed-element
		// path (n_vars >= LOG_WIDTH) and the per-scalar path.
		let multilins = (0usize..8)
			.map(|n_vars| {
				let n_packed = 1 << n_vars.saturating_sub(Packed::LOG_WIDTH);
				let evals = repeat_with(|| Packed::random(&mut rng))
					.take(n_packed)
					.collect::<Vec<_>>();
				MultilinearExtension::new(n_vars, evals).unwrap()
			})
			.collect::<Vec<_>>();

		let total_scalars = (0..8).map(|n_vars| 1usize << n_vars).sum::<usize>();
		let mut buffer = vec![Packed::zero(); total_scalars.div_ceil(Packed::WIDTH)];
		merge_multilins(&multilins, &mut buffer);

		let merged = PackedField::iter_slice(&buffer)
			.take(total_scalars)
			.collect_vec();
		let mut region_start = 0;
		for multilin in multilins.iter().rev() {
			let n_vars = multilin.n_vars();
			let region = &merged[region_start..];
			for (index, expected) in PackedField::iter_slice(multilin.evals())
				.take(1 << n_vars)
				.enumerate()
			{
				assert_eq!(region[reverse_bits(index, n_vars)], expected);
			}
			region_start += 1 << n_vars;
		}
	}
}