1use std::{borrow::Cow, ops::Deref};
4
5use binius_field::{
6 BinaryField, Field, PackedExtension, PackedField, TowerField, packed::PackedSliceMut,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{
10 EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
11 MultilinearPoly,
12};
13use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
14use binius_ntt::AdditiveNTT;
15use binius_utils::{
16 SerializeBytes, bail,
17 checked_arithmetics::checked_log_2,
18 random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
19 sorting::is_sorted_ascending,
20};
21use either::Either;
22use itertools::{Itertools, chain};
23
24use super::{
25 error::Error,
26 verify::{PIOPSumcheckClaim, make_sumcheck_claim_descs},
27};
28use crate::{
29 fiat_shamir::{CanSample, Challenger},
30 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
31 oracle::OracleId,
32 piop::{
33 CommitMeta,
34 logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
35 },
36 protocols::{
37 fri::{self, FRIFolder, FRIParams, FoldRoundOutput},
38 sumcheck::{
39 self, immediate_switchover_heuristic,
40 prove::{
41 RegularSumcheckProver, SumcheckProver,
42 front_loaded::BatchProver as SumcheckBatchProver,
43 },
44 },
45 },
46 transcript::ProverTranscript,
47};
48
/// Reverses the lowest `log_len` bits of `x`.
///
/// Bits of `x` at positions `>= log_len` are discarded. When `log_len == 0` the shift amount
/// equals `usize::BITS`, which `wrapping_shr` masks down to a shift of zero, so the result is
/// `x.reverse_bits()` (i.e. `0` for the only in-range input `x == 0`).
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
    // Number of high bits of the fully reversed word that must be shifted away.
    let discarded_bits = usize::BITS as usize - log_len;
    let reversed = x.reverse_bits();
    reversed.wrapping_shr(discarded_bits as u32)
}
54
55fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
58 let log_len = checked_log_2(collection.len());
59 for i in 0..collection.len() {
60 let bit_reversed_index = reverse_bits(i, log_len);
61 if i < bit_reversed_index {
62 unsafe {
64 let tmp = collection.get_unchecked(i);
65 collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
66 collection.set_unchecked(bit_reversed_index, tmp);
67 }
68 }
69 }
70}
71
/// Writes the evaluations of all `multilins` into `message_buffer` back to back, visiting the
/// input slice in reverse order, with each multilinear's scalars permuted into bit-reversed
/// index order.
///
/// Multilinears with at least `P::LOG_WIDTH` variables occupy whole packed elements. Each one is
/// given a disjoint chunk of the buffer, and the chunks are filled and bit-reversed in parallel.
/// The remaining, smaller multilinears are then written scalar by scalar at consecutive scalar
/// offsets into the rest of the buffer.
///
/// NOTE(review): `peeking_take_while` stops at the first multilinear smaller than one packed
/// element, so this assumes `multilins` is sorted ascending by `n_vars` (which `commit` checks
/// before calling). Confirm this for any other caller.
///
/// # Panics
///
/// `split_at_mut` panics if `message_buffer` is too small for the multilinears that fill whole
/// packed elements.
fn merge_multilins<F, P, Data>(
    multilins: &[MultilinearExtension<P, Data>],
    message_buffer: &mut [P],
) where
    F: TowerField,
    P: PackedField<Scalar = F>,
    Data: Deref<Target = [P]>,
{
    // Largest multilinears first (given ascending input order).
    let mut mle_iter = multilins.iter().rev();

    // Phase 1: carve off one disjoint packed chunk per multilinear that fills whole packed
    // elements, so the chunks can be written independently in parallel.
    let mut full_packed_mles = Vec::new();
    let mut remaining_buffer = message_buffer;
    for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
        let evals = mle.evals();
        let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
        full_packed_mles.push((evals, chunk));
        remaining_buffer = rest;
    }
    full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
        chunk.copy_from_slice(evals);
        reverse_index_bits(&mut PackedSliceMut::new(chunk));
    });

    // Phase 2: the remaining multilinears have fewer scalars than one packed element. Write
    // them at consecutive scalar offsets into whatever buffer is left.
    let mut scalar_offset = 0;
    let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
    for mle in mle_iter {
        // All `1 << n_vars` scalars are read from the first packed element.
        let packed_eval = mle.evals()[0];
        let len = 1 << mle.n_vars();
        let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
        for i in 0..len {
            packed_chunk.set(i, packed_eval.get(i));
        }
        reverse_index_bits(&mut packed_chunk);

        scalar_offset += len;
    }
}
120
121pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
136 fri_params: &FRIParams<F, FEncode>,
137 ntt: &NTT,
138 merkle_prover: &MTProver,
139 multilins: &[M],
140) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
141where
142 F: TowerField,
143 FEncode: BinaryField,
144 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
145 M: MultilinearPoly<P>,
146 NTT: AdditiveNTT<FEncode> + Sync,
147 MTScheme: MerkleTreeScheme<F>,
148 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
149{
150 let packed_multilins = multilins
151 .iter()
152 .enumerate()
153 .map(|(i, unpacked_committed)| {
154 packed_committed(OracleId::from_index(i), unpacked_committed)
155 })
156 .collect::<Result<Vec<_>, _>>()?;
157 if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
158 return Err(Error::CommittedsNotSorted);
159 }
160
161 let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
162 merge_multilins(&packed_multilins, message_buffer)
163 })?;
164
165 Ok(output)
166}
167
168#[allow(clippy::too_many_arguments)]
173pub fn prove<
174 F,
175 FDomain,
176 FEncode,
177 P,
178 M,
179 NTT,
180 DomainFactory,
181 MTScheme,
182 MTProver,
183 Challenger_,
184 Backend,
185>(
186 fri_params: &FRIParams<F, FEncode>,
187 ntt: &NTT,
188 merkle_prover: &MTProver,
189 domain_factory: DomainFactory,
190 commit_meta: &CommitMeta,
191 committed: MTProver::Committed,
192 codeword: &[P],
193 committed_multilins: &[M],
194 transparent_multilins: &[M],
195 claims: &[PIOPSumcheckClaim<F>],
196 transcript: &mut ProverTranscript<Challenger_>,
197 backend: &Backend,
198) -> Result<(), Error>
199where
200 F: TowerField,
201 FDomain: Field,
202 FEncode: BinaryField,
203 P: PackedField<Scalar = F>
204 + PackedExtension<F, PackedSubfield = P>
205 + PackedExtension<FDomain>
206 + PackedExtension<FEncode>,
207 M: MultilinearPoly<P> + Send + Sync,
208 NTT: AdditiveNTT<FEncode> + Sync,
209 DomainFactory: EvaluationDomainFactory<FDomain>,
210 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212 Challenger_: Challenger,
213 Backend: ComputationBackend,
214{
215 let sumcheck_claim_descs = make_sumcheck_claim_descs(
217 commit_meta,
218 transparent_multilins.iter().map(|poly| poly.n_vars()),
219 claims,
220 )?;
221
222 let packed_committed_multilins = committed_multilins
226 .iter()
227 .enumerate()
228 .map(|(i, unpacked_committed)| {
229 packed_committed(OracleId::from_index(i), unpacked_committed)
230 .map(MLEDirectAdapter::from)
231 })
232 .collect::<Result<Vec<_>, _>>()?;
233
234 let non_empty_sumcheck_descs = sumcheck_claim_descs
235 .iter()
236 .enumerate()
237 .filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
241 let sumcheck_provers = non_empty_sumcheck_descs
242 .map(|(_n_vars, desc)| {
243 let multilins = chain!(
244 packed_committed_multilins[desc.committed_indices.clone()]
245 .iter()
246 .map(Either::Left),
247 transparent_multilins[desc.transparent_indices.clone()]
248 .iter()
249 .map(Either::Right),
250 )
251 .collect::<Vec<_>>();
252 RegularSumcheckProver::new(
253 EvaluationOrder::HighToLow,
254 multilins,
255 desc.composite_sums.iter().cloned(),
256 &domain_factory,
257 immediate_switchover_heuristic,
258 backend,
259 )
260 })
261 .collect::<Result<Vec<_>, _>>()?;
262
263 prove_interleaved_fri_sumcheck(
264 commit_meta.total_vars(),
265 fri_params,
266 ntt,
267 merkle_prover,
268 sumcheck_provers,
269 codeword,
270 &committed,
271 transcript,
272 )?;
273
274 Ok(())
275}
276
/// Runs the batched sumcheck and FRI folding in lockstep for `n_rounds` rounds, then finishes
/// both proofs on `transcript`.
///
/// Each round does the following:
/// 1. The batched sumcheck prover writes its round message to the transcript.
/// 2. One challenge is sampled from the transcript.
/// 3. The sumcheck provers fold with that challenge.
/// 4. The FRI folder executes one fold round with the *same* challenge. Any round commitment it
///    returns is written to the transcript.
///
/// After the loop, the sumcheck batch prover writes its final message and the FRI proof is
/// finished.
///
/// NOTE(review): the exact interleaving of transcript writes and samples presumably mirrors the
/// verifier. Treat the statement order here as part of the protocol.
///
/// The `tracing` spans only annotate the work for profiling; they do not touch the transcript.
///
/// # Errors
///
/// Propagates errors from constructing the FRI folder or the sumcheck batch prover, from any
/// sumcheck or FRI round, and from finishing either proof.
// All FRI and sumcheck inputs are needed here; grouping them into a struct is not worth it.
#[allow(clippy::too_many_arguments)]
fn prove_interleaved_fri_sumcheck<F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
    n_rounds: usize,
    fri_params: &FRIParams<F, FEncode>,
    ntt: &NTT,
    merkle_prover: &MTProver,
    sumcheck_provers: Vec<impl SumcheckProver<F>>,
    codeword: &[P],
    committed: &MTProver::Committed,
    transcript: &mut ProverTranscript<Challenger_>,
) -> Result<(), Error>
where
    F: TowerField,
    FEncode: BinaryField,
    P: PackedField<Scalar = F> + PackedExtension<FEncode>,
    NTT: AdditiveNTT<FEncode> + Sync,
    MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
    MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
    Challenger_: Challenger,
{
    let mut fri_prover = FRIFolder::new(fri_params, ntt, merkle_prover, codeword, committed)?;

    let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;

    for round in 0..n_rounds {
        let _span =
            tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
                .entered();

        // Sumcheck half of the round: send the round message, then sample and apply a challenge.
        let bivariate_sumcheck_span = tracing::debug_span!(
            "[step] Bivariate Sumcheck",
            phase = "piop_compiler",
            round = round,
            perfetto_category = "phase.sub"
        )
        .entered();
        let provers_dimensions_data =
            SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
        let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
            "[task] (PIOP Compiler) Calculate Coeffs",
            phase = "piop_compiler",
            round = round,
            perfetto_category = "task.main",
            dimensions_data = ?provers_dimensions_data,
        )
        .entered();
        sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
        drop(bivariate_sumcheck_calculate_coeffs_span);

        // This single challenge drives both the sumcheck fold and the FRI fold of this round.
        let challenge = transcript.sample();
        let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
            "[task] (PIOP Compiler) Fold (All Rounds)",
            phase = "piop_compiler",
            round = round,
            perfetto_category = "task.main",
            dimensions_data = ?provers_dimensions_data,
        )
        .entered();
        sumcheck_batch_prover.receive_challenge(challenge)?;
        drop(bivariate_sumcheck_all_folds_span);
        drop(bivariate_sumcheck_span);

        // FRI half of the round: fold with the same challenge and publish any new commitment.
        let dimensions_data = FriFoldRoundsData::new(
            round,
            fri_params.log_batch_size(),
            fri_prover.current_codeword_len(),
        );
        let fri_fold_rounds_span = tracing::debug_span!(
            "[step] FRI Fold Rounds",
            phase = "piop_compiler",
            round = round,
            perfetto_category = "phase.sub",
            ?dimensions_data,
        )
        .entered();
        match fri_prover.execute_fold_round(challenge)? {
            FoldRoundOutput::NoCommitment => {}
            FoldRoundOutput::Commitment(round_commitment) => {
                transcript.message().write(&round_commitment);
            }
        }
        drop(fri_fold_rounds_span);
    }

    sumcheck_batch_prover.finish(&mut transcript.message())?;
    fri_prover.finish_proof(transcript)?;
    Ok(())
}
365
366pub fn validate_sumcheck_witness<F, P, M>(
367 committed_multilins: &[M],
368 transparent_multilins: &[M],
369 claims: &[PIOPSumcheckClaim<F>],
370) -> Result<(), Error>
371where
372 F: TowerField,
373 P: PackedField<Scalar = F>,
374 M: MultilinearPoly<P> + Send + Sync,
375{
376 let packed_committed = committed_multilins
377 .iter()
378 .enumerate()
379 .map(|(i, unpacked_committed)| {
380 packed_committed(OracleId::from_index(i), unpacked_committed)
381 })
382 .collect::<Result<Vec<_>, _>>()?;
383
384 for (i, claim) in claims.iter().enumerate() {
385 let committed = &packed_committed[claim.committed];
386 if committed.n_vars() != claim.n_vars {
387 bail!(sumcheck::Error::NumberOfVariablesMismatch);
388 }
389
390 let transparent = &transparent_multilins[claim.transparent];
391 if transparent.n_vars() != claim.n_vars {
392 bail!(sumcheck::Error::NumberOfVariablesMismatch);
393 }
394
395 let sum = (0..(1 << claim.n_vars))
396 .into_par_iter()
397 .map(|j| {
398 let committed_eval = committed
399 .evaluate_on_hypercube(j)
400 .expect("j is less than 1 << n_vars; committed.n_vars is checked above");
401 let transparent_eval = transparent
402 .evaluate_on_hypercube(j)
403 .expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
404 committed_eval * transparent_eval
405 })
406 .sum::<F>();
407
408 if sum != claim.sum {
409 bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
410 composition_index: i,
411 });
412 }
413 }
414 Ok(())
415}
416
417fn packed_committed<F, P, M>(
425 id: OracleId,
426 unpacked_committed: &M,
427) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
428where
429 F: TowerField,
430 P: PackedField<Scalar = F>,
431 M: MultilinearPoly<P>,
432{
433 let unpacked_n_vars = unpacked_committed.n_vars();
434 let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
435 let packed_eval = padded_packed_eval(unpacked_committed);
436 MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
437 } else {
438 let packed_evals = unpacked_committed
439 .packed_evals()
440 .ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
441
442 MultilinearExtension::new(
443 unpacked_n_vars - unpacked_committed.log_extension_degree(),
444 Cow::Borrowed(packed_evals),
445 )
446 }?;
447 Ok(packed_committed)
448}
449
450#[inline]
451fn padded_packed_eval<F, P, M>(multilin: &M) -> F
452where
453 F: TowerField,
454 P: PackedField<Scalar = F>,
455 M: MultilinearPoly<P>,
456{
457 let n_vars = multilin.n_vars();
458 let kappa = multilin.log_extension_degree();
459 assert!(n_vars < kappa);
460
461 (0..1 << kappa)
462 .map(|i| {
463 let iota = F::TOWER_LEVEL - kappa;
464 let scalar = <F as TowerField>::basis(iota, i)
465 .expect("i is in range 0..1 << log_extension_degree");
466 multilin
467 .evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
468 .expect("i is in range 0..1 << n_vars")
469 })
470 .sum()
471}
472
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::PackedBinaryField2x128b;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Packed field used to build the test multilinears.
    type Packed = PackedBinaryField2x128b;

    /// Builds multilinears with 0 through 7 variables and merges them. Then checks that each one
    /// appears in the merged scalars in bit-reversed order, largest multilinear first.
    #[test]
    fn test_merge_multilins() {
        const N_MULTILINS: usize = 8;
        let mut rng = StdRng::seed_from_u64(0);

        let mut multilins = Vec::with_capacity(N_MULTILINS);
        for n_vars in 0..N_MULTILINS {
            let n_packed = 1 << n_vars.saturating_sub(Packed::LOG_WIDTH);
            let data = repeat_with(|| Packed::random(&mut rng))
                .take(n_packed)
                .collect::<Vec<_>>();
            multilins.push(MultilinearExtension::new(n_vars, data).unwrap());
        }

        let n_scalars: usize = (0..N_MULTILINS).map(|n_vars| 1usize << n_vars).sum();
        let mut buffer = vec![Packed::zero(); n_scalars.div_ceil(Packed::WIDTH)];
        merge_multilins(&multilins, &mut buffer);

        let merged = PackedField::iter_slice(&buffer).take(n_scalars).collect_vec();
        let mut offset = 0;
        for multilin in multilins.iter().rev() {
            let n_vars = multilin.n_vars();
            let segment = &merged[offset..];
            let evals = PackedField::iter_slice(multilin.evals()).take(1 << n_vars);
            for (i, expected) in evals.enumerate() {
                assert_eq!(segment[reverse_bits(i, n_vars)], expected);
            }
            offset += 1 << n_vars;
        }
    }
}
513}