binius_core/piop/prove.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	packed::{get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked},
5	BinaryField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{
9	EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
10	MultilinearPoly,
11};
12use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
13use binius_ntt::{NTTOptions, ThreadingSettings};
14use binius_utils::{
15	bail, checked_arithmetics::checked_log_2, sorting::is_sorted_ascending, SerializeBytes,
16};
17use either::Either;
18use itertools::{chain, Itertools};
19
20use super::{
21	error::Error,
22	verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
23};
24use crate::{
25	fiat_shamir::{CanSample, Challenger},
26	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
27	piop::CommitMeta,
28	protocols::{
29		fri,
30		fri::{FRIFolder, FRIParams, FoldRoundOutput},
31		sumcheck,
32		sumcheck::{
33			immediate_switchover_heuristic,
34			prove::{
35				front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
36				SumcheckProver,
37			},
38		},
39	},
40	reed_solomon::reed_solomon::ReedSolomonCode,
41	transcript::ProverTranscript,
42};
43
44/// Reorders the scalars in a slice of packed field elements by reversing the bits of their indices.
45/// TODO: investigate if we can optimize this.
46fn reverse_slice_index_bits<P: PackedField>(slice: &mut [P]) {
47	let log_len = checked_log_2(slice.len()) + P::LOG_WIDTH;
48	for i in 0..slice.len() << P::LOG_WIDTH {
49		let bit_reversed_index = i
50			.reverse_bits()
51			.wrapping_shr((usize::BITS as usize - log_len) as _);
52		if i < bit_reversed_index {
53			// Safety: `i` and `j` are guaranteed to be in bounds of the slice
54			unsafe {
55				let tmp = get_packed_slice_unchecked(slice, i);
56				set_packed_slice_unchecked(
57					slice,
58					i,
59					get_packed_slice_unchecked(slice, bit_reversed_index),
60				);
61				set_packed_slice_unchecked(slice, bit_reversed_index, tmp);
62			}
63		}
64	}
65}
66
67// ## Preconditions
68//
69// * all multilinears in `multilins` have at least log_extension_degree packed variables
70// * all multilinears in `multilins` have `packed_evals()` is Some
71// * multilinears are sorted in ascending order by number of packed variables
72// * `message_buffer` is initialized to all zeros
73// * `message_buffer` is larger than the total number of scalars in the multilinears
74fn merge_multilins<P, M>(multilins: &[M], message_buffer: &mut [P])
75where
76	P: PackedField,
77	M: MultilinearPoly<P>,
78{
79	let mut mle_iter = multilins.iter().rev();
80
81	// First copy all the polynomials where the number of elements is a multiple of the packing
82	// width.
83	let get_n_packed_vars = |mle: &M| mle.n_vars() - mle.log_extension_degree();
84	let mut full_packed_mles = Vec::new(); // (evals, corresponding buffer where to copy)
85	let mut remaining_buffer = message_buffer;
86	for mle in mle_iter.peeking_take_while(|mle| get_n_packed_vars(mle) >= P::LOG_WIDTH) {
87		let evals = mle
88			.packed_evals()
89			.expect("guaranteed by function precondition");
90		let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
91		full_packed_mles.push((evals, chunk));
92		remaining_buffer = rest;
93	}
94	full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
95		chunk.copy_from_slice(evals);
96		reverse_slice_index_bits(chunk);
97	});
98
99	// Now copy scalars from the remaining multilinears, which have too few elements to copy full
100	// packed elements.
101	let mut scalar_offset = 0;
102	for mle in mle_iter {
103		let evals = mle
104			.packed_evals()
105			.expect("guaranteed by function precondition");
106		let packed_eval = evals[0];
107		for i in 0..1 << mle.n_vars() {
108			set_packed_slice(remaining_buffer, scalar_offset, packed_eval.get(i));
109			scalar_offset += 1;
110		}
111	}
112}
113
114/// Commits a batch of multilinear polynomials.
115///
116/// The multilinears this function accepts as arguments may be defined over subfields of `F`. In
117/// this case, we commit to these multilinears by instead committing to their "packed"
118/// multilinears. These are the multilinear extensions of their packed coefficients over subcubes
119/// of the size of the extension degree.
120///
121/// ## Arguments
122///
123/// * `fri_params` - the FRI parameters for the commitment opening protocol
124/// * `merkle_prover` - the Merkle tree prover used in FRI
125/// * `multilins` - a batch of multilinear polynomials to commit. The multilinears provided may be
126///     defined over subfields of `F`. They must be in ascending order by the number of variables
127///     in the packed multilinear (ie. number of variables minus log extension degree).
128#[tracing::instrument("piop::commit", skip_all)]
129pub fn commit<F, FEncode, P, M, MTScheme, MTProver>(
130	fri_params: &FRIParams<F, FEncode>,
131	merkle_prover: &MTProver,
132	multilins: &[M],
133) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
134where
135	F: BinaryField,
136	FEncode: BinaryField,
137	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
138	M: MultilinearPoly<P>,
139	MTScheme: MerkleTreeScheme<F>,
140	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
141{
142	for (i, multilin) in multilins.iter().enumerate() {
143		if multilin.n_vars() < multilin.log_extension_degree() {
144			return Err(Error::OracleTooSmall {
145				// i is not an OracleId, but whatever, that's a problem for whoever has to debug
146				// this
147				id: i,
148				n_vars: multilin.n_vars(),
149				min_vars: multilin.log_extension_degree(),
150			});
151		}
152		if multilin.packed_evals().is_none() {
153			return Err(Error::CommittedPackedEvaluationsMissing { id: i });
154		}
155	}
156
157	let n_packed_vars = multilins
158		.iter()
159		.map(|multilin| multilin.n_vars() - multilin.log_extension_degree());
160	if !is_sorted_ascending(n_packed_vars) {
161		return Err(Error::CommittedsNotSorted);
162	}
163
164	// TODO: this should be passed in to avoid recomputing twiddles
165	let rs_code = ReedSolomonCode::new(
166		fri_params.rs_code().log_dim(),
167		fri_params.rs_code().log_inv_rate(),
168		&NTTOptions {
169			precompute_twiddles: true,
170			thread_settings: ThreadingSettings::MultithreadedDefault,
171		},
172	)?;
173	let output =
174		fri::commit_interleaved_with(&rs_code, fri_params, merkle_prover, |message_buffer| {
175			merge_multilins(multilins, message_buffer)
176		})?;
177
178	Ok(output)
179}
180
181/// Proves a batch of sumcheck claims that are products of committed polynomials from a committed
182/// batch and transparent polynomials.
183///
184/// The arguments corresponding to the committed multilinears must be the output of [`commit`].
185#[allow(clippy::too_many_arguments)]
186#[tracing::instrument("piop::prove", skip_all)]
187pub fn prove<F, FDomain, FEncode, P, M, DomainFactory, MTScheme, MTProver, Challenger_, Backend>(
188	fri_params: &FRIParams<F, FEncode>,
189	merkle_prover: &MTProver,
190	domain_factory: DomainFactory,
191	commit_meta: &CommitMeta,
192	committed: MTProver::Committed,
193	codeword: &[P],
194	committed_multilins: &[M],
195	transparent_multilins: &[M],
196	claims: &[PIOPSumcheckClaim<F>],
197	transcript: &mut ProverTranscript<Challenger_>,
198	backend: &Backend,
199) -> Result<(), Error>
200where
201	F: TowerField,
202	FDomain: Field,
203	FEncode: BinaryField,
204	P: PackedFieldIndexable<Scalar = F>
205		+ PackedExtension<F, PackedSubfield = P>
206		+ PackedExtension<FDomain>
207		+ PackedExtension<FEncode>,
208	M: MultilinearPoly<P> + Send + Sync,
209	DomainFactory: EvaluationDomainFactory<FDomain>,
210	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212	Challenger_: Challenger,
213	Backend: ComputationBackend,
214{
215	// Map of n_vars to sumcheck claim descriptions
216	let sumcheck_claim_descs = make_sumcheck_claim_descs(
217		commit_meta,
218		transparent_multilins.iter().map(|poly| poly.n_vars()),
219		claims,
220	)?;
221
222	// The committed multilinears provided by argument are committed *small field* multilinears.
223	// Create multilinears representing the packed polynomials here. Eventually, we would like to
224	// refactor the calling code so that the PIOP only handles *big field* multilinear witnesses.
225	let packed_committed_multilins = committed_multilins
226		.iter()
227		.enumerate()
228		.map(|(i, committed_multilin)| {
229			let packed_evals = committed_multilin
230				.packed_evals()
231				.ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
232			let packed_multilin = MultilinearExtension::from_values_slice(packed_evals)?;
233			Ok::<_, Error>(MLEDirectAdapter::from(packed_multilin))
234		})
235		.collect::<Result<Vec<_>, _>>()?;
236
237	let non_empty_sumcheck_descs = sumcheck_claim_descs
238		.iter()
239		.enumerate()
240		.filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
241	let sumcheck_provers = non_empty_sumcheck_descs
242		.clone()
243		.map(|(_n_vars, desc)| {
244			let multilins = chain!(
245				packed_committed_multilins[desc.committed_indices.clone()]
246					.iter()
247					.map(Either::Left),
248				transparent_multilins[desc.transparent_indices.clone()]
249					.iter()
250					.map(Either::Right),
251			)
252			.collect::<Vec<_>>();
253			RegularSumcheckProver::new(
254				EvaluationOrder::HighToLow,
255				multilins,
256				desc.composite_sums.iter().cloned(),
257				&domain_factory,
258				immediate_switchover_heuristic,
259				backend,
260			)
261		})
262		.collect::<Result<Vec<_>, _>>()?;
263
264	prove_interleaved_fri_sumcheck(
265		commit_meta.total_vars(),
266		fri_params,
267		merkle_prover,
268		sumcheck_provers,
269		codeword,
270		&committed,
271		transcript,
272	)?;
273
274	Ok(())
275}
276
277fn prove_interleaved_fri_sumcheck<F, FEncode, P, MTScheme, MTProver, Challenger_>(
278	n_rounds: usize,
279	fri_params: &FRIParams<F, FEncode>,
280	merkle_prover: &MTProver,
281	sumcheck_provers: Vec<impl SumcheckProver<F>>,
282	codeword: &[P],
283	committed: &MTProver::Committed,
284	transcript: &mut ProverTranscript<Challenger_>,
285) -> Result<(), Error>
286where
287	F: TowerField,
288	FEncode: BinaryField,
289	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
290	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
291	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
292	Challenger_: Challenger,
293{
294	let mut fri_prover =
295		FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?;
296
297	let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
298
299	for _ in 0..n_rounds {
300		sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
301		let challenge = transcript.sample();
302		sumcheck_batch_prover.receive_challenge(challenge)?;
303
304		match fri_prover.execute_fold_round(challenge)? {
305			FoldRoundOutput::NoCommitment => {}
306			FoldRoundOutput::Commitment(round_commitment) => {
307				transcript.message().write(&round_commitment);
308			}
309		}
310	}
311
312	sumcheck_batch_prover.finish(&mut transcript.message())?;
313	fri_prover.finish_proof(transcript)?;
314	Ok(())
315}
316
317pub fn validate_sumcheck_witness<F, P, M>(
318	committed_multilins: &[M],
319	transparent_multilins: &[M],
320	claims: &[PIOPSumcheckClaim<F>],
321) -> Result<(), Error>
322where
323	F: TowerField,
324	P: PackedField<Scalar = F>,
325	M: MultilinearPoly<P> + Send + Sync,
326{
327	let packed_committed = committed_multilins
328		.iter()
329		.enumerate()
330		.map(|(i, unpacked_committed)| {
331			let packed_evals = unpacked_committed
332				.packed_evals()
333				.ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
334			let packed_committed = MultilinearExtension::from_values_slice(packed_evals)?;
335			Ok::<_, Error>(packed_committed)
336		})
337		.collect::<Result<Vec<_>, _>>()?;
338
339	for (i, claim) in claims.iter().enumerate() {
340		let committed = &packed_committed[claim.committed];
341		if committed.n_vars() != claim.n_vars {
342			bail!(sumcheck::Error::NumberOfVariablesMismatch);
343		}
344
345		let transparent = &transparent_multilins[claim.transparent];
346		if transparent.n_vars() != claim.n_vars {
347			bail!(sumcheck::Error::NumberOfVariablesMismatch);
348		}
349
350		let sum = (0..(1 << claim.n_vars))
351			.into_par_iter()
352			.map(|j| {
353				let committed_eval = committed
354					.evaluate_on_hypercube(j)
355					.expect("j is less than 1 << n_vars; committed.n_vars is checked above");
356				let transparent_eval = transparent
357					.evaluate_on_hypercube(j)
358					.expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
359				committed_eval * transparent_eval
360			})
361			.sum::<F>();
362
363		if sum != claim.sum {
364			bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
365				composition_index: i,
366			});
367		}
368	}
369	Ok(())
370}