binius_core/piop/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField,
5	PackedFieldIndexable, TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{
9	EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
10	MultilinearPoly,
11};
12use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
13use binius_ntt::{NTTOptions, ThreadingSettings};
14use binius_utils::{bail, sorting::is_sorted_ascending, SerializeBytes};
15use either::Either;
16use itertools::{chain, Itertools};
17
18use super::{
19	error::Error,
20	verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
21};
22use crate::{
23	fiat_shamir::{CanSample, Challenger},
24	merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
25	piop::CommitMeta,
26	protocols::{
27		fri,
28		fri::{FRIFolder, FRIParams, FoldRoundOutput},
29		sumcheck,
30		sumcheck::{
31			immediate_switchover_heuristic,
32			prove::{
33				front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
34				SumcheckProver,
35			},
36		},
37	},
38	reed_solomon::reed_solomon::ReedSolomonCode,
39	transcript::ProverTranscript,
40};
41
42// ## Preconditions
43//
44// * all multilinears in `multilins` have at least log_extension_degree packed variables
45// * all multilinears in `multilins` have `packed_evals()` is Some
46// * multilinears are sorted in ascending order by number of packed variables
47// * `message_buffer` is initialized to all zeros
48// * `message_buffer` is larger than the total number of scalars in the multilinears
49fn merge_multilins<P, M>(multilins: &[M], message_buffer: &mut [P])
50where
51	P: PackedField,
52	M: MultilinearPoly<P>,
53{
54	let mut mle_iter = multilins.iter().rev();
55
56	// First copy all the polynomials where the number of elements is a multiple of the packing
57	// width.
58	let get_n_packed_vars = |mle: &M| mle.n_vars() - mle.log_extension_degree();
59	let mut full_packed_mles = Vec::new(); // (evals, corresponding buffer where to copy)
60	let mut remaining_buffer = message_buffer;
61	for mle in mle_iter.peeking_take_while(|mle| get_n_packed_vars(mle) >= P::LOG_WIDTH) {
62		let evals = mle
63			.packed_evals()
64			.expect("guaranteed by function precondition");
65		let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
66		full_packed_mles.push((evals, chunk));
67		remaining_buffer = rest;
68	}
69	full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
70		chunk.copy_from_slice(evals);
71	});
72
73	// Now copy scalars from the remaining multilinears, which have too few elements to copy full
74	// packed elements.
75	let mut scalar_offset = 0;
76	for mle in mle_iter {
77		let evals = mle
78			.packed_evals()
79			.expect("guaranteed by function precondition");
80		let packed_eval = evals[0];
81		for i in 0..1 << mle.n_vars() {
82			set_packed_slice(remaining_buffer, scalar_offset, packed_eval.get(i));
83			scalar_offset += 1;
84		}
85	}
86}
87
88/// Commits a batch of multilinear polynomials.
89///
90/// The multilinears this function accepts as arguments may be defined over subfields of `F`. In
91/// this case, we commit to these multilinears by instead committing to their "packed"
92/// multilinears. These are the multilinear extensions of their packed coefficients over subcubes
93/// of the size of the extension degree.
94///
95/// ## Arguments
96///
97/// * `fri_params` - the FRI parameters for the commitment opening protocol
98/// * `merkle_prover` - the Merkle tree prover used in FRI
99/// * `multilins` - a batch of multilinear polynomials to commit. The multilinears provided may be
100///     defined over subfields of `F`. They must be in ascending order by the number of variables
101///     in the packed multilinear (ie. number of variables minus log extension degree).
102#[tracing::instrument("piop::commit", skip_all)]
103pub fn commit<F, FEncode, P, M, MTScheme, MTProver>(
104	fri_params: &FRIParams<F, FEncode>,
105	merkle_prover: &MTProver,
106	multilins: &[M],
107) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
108where
109	F: BinaryField,
110	FEncode: BinaryField,
111	P: PackedField<Scalar = F> + PackedExtension<FEncode>,
112	M: MultilinearPoly<P>,
113	MTScheme: MerkleTreeScheme<F>,
114	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
115{
116	for (i, multilin) in multilins.iter().enumerate() {
117		if multilin.n_vars() < multilin.log_extension_degree() {
118			return Err(Error::OracleTooSmall {
119				// i is not an OracleId, but whatever, that's a problem for whoever has to debug
120				// this
121				id: i,
122				min_vars: multilin.log_extension_degree(),
123			});
124		}
125		if multilin.packed_evals().is_none() {
126			return Err(Error::CommittedPackedEvaluationsMissing { id: i });
127		}
128	}
129
130	let n_packed_vars = multilins
131		.iter()
132		.map(|multilin| multilin.n_vars() - multilin.log_extension_degree());
133	if !is_sorted_ascending(n_packed_vars) {
134		return Err(Error::CommittedsNotSorted);
135	}
136
137	// TODO: this should be passed in to avoid recomputing twiddles
138	let rs_code = ReedSolomonCode::new(
139		fri_params.rs_code().log_dim(),
140		fri_params.rs_code().log_inv_rate(),
141		&NTTOptions {
142			precompute_twiddles: true,
143			thread_settings: ThreadingSettings::MultithreadedDefault,
144		},
145	)?;
146	let output =
147		fri::commit_interleaved_with(&rs_code, fri_params, merkle_prover, |message_buffer| {
148			merge_multilins(multilins, message_buffer)
149		})?;
150
151	Ok(output)
152}
153
154/// Proves a batch of sumcheck claims that are products of committed polynomials from a committed
155/// batch and transparent polynomials.
156///
157/// The arguments corresponding to the committed multilinears must be the output of [`commit`].
158#[allow(clippy::too_many_arguments)]
159#[tracing::instrument("piop::prove", skip_all)]
160pub fn prove<F, FDomain, FEncode, P, M, DomainFactory, MTScheme, MTProver, Challenger_, Backend>(
161	fri_params: &FRIParams<F, FEncode>,
162	merkle_prover: &MTProver,
163	domain_factory: DomainFactory,
164	commit_meta: &CommitMeta,
165	committed: MTProver::Committed,
166	codeword: &[P],
167	committed_multilins: &[M],
168	transparent_multilins: &[M],
169	claims: &[PIOPSumcheckClaim<F>],
170	transcript: &mut ProverTranscript<Challenger_>,
171	backend: &Backend,
172) -> Result<(), Error>
173where
174	F: TowerField,
175	FDomain: Field,
176	FEncode: BinaryField,
177	P: PackedFieldIndexable<Scalar = F>
178		+ PackedExtension<F, PackedSubfield = P>
179		+ PackedExtension<FDomain>
180		+ PackedExtension<FEncode>,
181	M: MultilinearPoly<P> + Send + Sync,
182	DomainFactory: EvaluationDomainFactory<FDomain>,
183	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
184	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
185	Challenger_: Challenger,
186	Backend: ComputationBackend,
187{
188	// Map of n_vars to sumcheck claim descriptions
189	let sumcheck_claim_descs = make_sumcheck_claim_descs(
190		commit_meta,
191		transparent_multilins.iter().map(|poly| poly.n_vars()),
192		claims,
193	)?;
194
195	// The committed multilinears provided by argument are committed *small field* multilinears.
196	// Create multilinears representing the packed polynomials here. Eventually, we would like to
197	// refactor the calling code so that the PIOP only handles *big field* multilinear witnesses.
198	let packed_committed_multilins = committed_multilins
199		.iter()
200		.enumerate()
201		.map(|(i, committed_multilin)| {
202			let packed_evals = committed_multilin
203				.packed_evals()
204				.ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
205			let packed_multilin = MultilinearExtension::from_values_slice(packed_evals)?;
206			Ok::<_, Error>(MLEDirectAdapter::from(packed_multilin))
207		})
208		.collect::<Result<Vec<_>, _>>()?;
209
210	let non_empty_sumcheck_descs = sumcheck_claim_descs
211		.iter()
212		.enumerate()
213		.filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
214	let sumcheck_provers = non_empty_sumcheck_descs
215		.clone()
216		.map(|(_n_vars, desc)| {
217			let multilins = chain!(
218				packed_committed_multilins[desc.committed_indices.clone()]
219					.iter()
220					.map(Either::Left),
221				transparent_multilins[desc.transparent_indices.clone()]
222					.iter()
223					.map(Either::Right),
224			)
225			.collect::<Vec<_>>();
226			RegularSumcheckProver::new(
227				EvaluationOrder::LowToHigh,
228				multilins,
229				desc.composite_sums.iter().cloned(),
230				&domain_factory,
231				immediate_switchover_heuristic,
232				backend,
233			)
234		})
235		.collect::<Result<Vec<_>, _>>()?;
236
237	prove_interleaved_fri_sumcheck(
238		commit_meta.total_vars(),
239		fri_params,
240		merkle_prover,
241		sumcheck_provers,
242		codeword,
243		&committed,
244		transcript,
245	)?;
246
247	Ok(())
248}
249
250fn prove_interleaved_fri_sumcheck<F, FEncode, P, MTScheme, MTProver, Challenger_>(
251	n_rounds: usize,
252	fri_params: &FRIParams<F, FEncode>,
253	merkle_prover: &MTProver,
254	sumcheck_provers: Vec<impl SumcheckProver<F>>,
255	codeword: &[P],
256	committed: &MTProver::Committed,
257	transcript: &mut ProverTranscript<Challenger_>,
258) -> Result<(), Error>
259where
260	F: TowerField,
261	FEncode: BinaryField,
262	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
263	MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
264	MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
265	Challenger_: Challenger,
266{
267	let mut fri_prover =
268		FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?;
269
270	let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
271
272	for _ in 0..n_rounds {
273		sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
274		let challenge = transcript.sample();
275		sumcheck_batch_prover.receive_challenge(challenge)?;
276
277		match fri_prover.execute_fold_round(challenge)? {
278			FoldRoundOutput::NoCommitment => {}
279			FoldRoundOutput::Commitment(round_commitment) => {
280				transcript.message().write(&round_commitment);
281			}
282		}
283	}
284
285	sumcheck_batch_prover.finish(&mut transcript.message())?;
286	fri_prover.finish_proof(transcript)?;
287	Ok(())
288}
289
290pub fn validate_sumcheck_witness<F, P, M>(
291	committed_multilins: &[M],
292	transparent_multilins: &[M],
293	claims: &[PIOPSumcheckClaim<F>],
294) -> Result<(), Error>
295where
296	F: TowerField,
297	P: PackedField<Scalar = F>,
298	M: MultilinearPoly<P> + Send + Sync,
299{
300	let packed_committed = committed_multilins
301		.iter()
302		.enumerate()
303		.map(|(i, unpacked_committed)| {
304			let packed_evals = unpacked_committed
305				.packed_evals()
306				.ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
307			let packed_committed = MultilinearExtension::from_values_slice(packed_evals)?;
308			Ok::<_, Error>(packed_committed)
309		})
310		.collect::<Result<Vec<_>, _>>()?;
311
312	for (i, claim) in claims.iter().enumerate() {
313		let committed = &packed_committed[claim.committed];
314		if committed.n_vars() != claim.n_vars {
315			bail!(sumcheck::Error::NumberOfVariablesMismatch);
316		}
317
318		let transparent = &transparent_multilins[claim.transparent];
319		if transparent.n_vars() != claim.n_vars {
320			bail!(sumcheck::Error::NumberOfVariablesMismatch);
321		}
322
323		let sum = (0..(1 << claim.n_vars))
324			.into_par_iter()
325			.map(|j| {
326				let committed_eval = committed
327					.evaluate_on_hypercube(j)
328					.expect("j is less than 1 << n_vars; committed.n_vars is checked above");
329				let transparent_eval = transparent
330					.evaluate_on_hypercube(j)
331					.expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
332				committed_eval * transparent_eval
333			})
334			.sum::<F>();
335
336		if sum != claim.sum {
337			bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
338				composition_index: i,
339			});
340		}
341	}
342	Ok(())
343}