1use std::{borrow::Cow, ops::Deref};
4
5use binius_compute::{
6 ComputeData, ComputeLayer, ComputeMemory, FSlice, SizedSlice, alloc::ComputeAllocator,
7 cpu::CpuMemory,
8};
9use binius_field::{
10 BinaryField, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
11 packed::PackedSliceMut,
12};
13use binius_math::{MLEDirectAdapter, MultilinearExtension, MultilinearPoly};
14use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
15use binius_ntt::AdditiveNTT;
16use binius_utils::{
17 SerializeBytes, bail,
18 checked_arithmetics::checked_log_2,
19 random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
20 sorting::is_sorted_ascending,
21};
22use bytemuck::zeroed_vec;
23use itertools::{Itertools, chain};
24
25use super::{
26 error::Error,
27 verify::{PIOPSumcheckClaim, make_sumcheck_claim_descs},
28};
29use crate::{
30 fiat_shamir::{CanSample, Challenger},
31 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
32 oracle::OracleId,
33 piop::{
34 CommitMeta,
35 logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
36 },
37 protocols::{
38 fri::{self, FRIFolder, FRIParams, FoldRoundOutput},
39 sumcheck::{
40 self, SumcheckClaim,
41 prove::{SumcheckProver, front_loaded::BatchProver as SumcheckBatchProver},
42 v3::bivariate_product::BivariateSumcheckProver,
43 },
44 },
45 transcript::ProverTranscript,
46};
47
/// Reverses the low `log_len` bits of `x`.
///
/// The full-width bit reversal of `x` is shifted right so that only the `log_len` reversed bits
/// remain in the low positions. `log_len` must not exceed `usize::BITS`. Because the shift amount
/// wraps modulo `usize::BITS`, `log_len == 0` is only meaningful for `x == 0`.
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
    let shift = (usize::BITS as usize - log_len) as u32;
    x.reverse_bits().wrapping_shr(shift)
}
53
54fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
57 let log_len = checked_log_2(collection.len());
58 for i in 0..collection.len() {
59 let bit_reversed_index = reverse_bits(i, log_len);
60 if i < bit_reversed_index {
61 unsafe {
63 let tmp = collection.get_unchecked(i);
64 collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
65 collection.set_unchecked(bit_reversed_index, tmp);
66 }
67 }
68 }
69}
70
71fn merge_multilins<F, P, Data>(
79 multilins: &[MultilinearExtension<P, Data>],
80 message_buffer: &mut [P],
81) where
82 F: TowerField,
83 P: PackedField<Scalar = F>,
84 Data: Deref<Target = [P]>,
85{
86 let mut mle_iter = multilins.iter().rev();
87
88 let mut full_packed_mles = Vec::new(); let mut remaining_buffer = message_buffer;
92 for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
93 let evals = mle.evals();
94 let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
95 full_packed_mles.push((evals, chunk));
96 remaining_buffer = rest;
97 }
98 full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
99 chunk.copy_from_slice(evals);
100 reverse_index_bits(&mut PackedSliceMut::new(chunk));
101 });
102
103 let mut scalar_offset = 0;
106 let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
107 for mle in mle_iter {
108 let packed_eval = mle.evals()[0];
109 let len = 1 << mle.n_vars();
110 let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
111 for i in 0..len {
112 packed_chunk.set(i, packed_eval.get(i));
113 }
114 reverse_index_bits(&mut packed_chunk);
115
116 scalar_offset += len;
117 }
118}
119
120pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
135 fri_params: &FRIParams<F, FEncode>,
136 ntt: &NTT,
137 merkle_prover: &MTProver,
138 multilins: &[M],
139) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
140where
141 F: TowerField,
142 FEncode: BinaryField,
143 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
144 M: MultilinearPoly<P>,
145 NTT: AdditiveNTT<FEncode> + Sync,
146 MTScheme: MerkleTreeScheme<F>,
147 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
148{
149 let packed_multilins = multilins
150 .iter()
151 .enumerate()
152 .map(|(i, unpacked_committed)| {
153 packed_committed(OracleId::from_index(i), unpacked_committed)
154 })
155 .collect::<Result<Vec<_>, _>>()?;
156 if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
157 return Err(Error::CommittedsNotSorted);
158 }
159
160 let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
161 merge_multilins(&packed_multilins, message_buffer)
162 })?;
163
164 Ok(output)
165}
166
167#[allow(clippy::too_many_arguments)]
172pub fn prove<
173 Hal,
174 F,
175 FEncode,
176 P,
177 M,
178 NTT,
179 MTScheme,
180 MTProver,
181 Challenger_,
182 HostComputeAllocatorType,
183 DeviceComputeAllocatorType,
184>(
185 compute_data: &ComputeData<F, Hal, HostComputeAllocatorType, DeviceComputeAllocatorType>,
186 fri_params: &FRIParams<F, FEncode>,
187 ntt: &NTT,
188 merkle_prover: &MTProver,
189 commit_meta: &CommitMeta,
190 committed: MTProver::Committed,
191 codeword: &[P],
192 committed_multilins: &[M],
193 transparent_multilins: Vec<FSlice<'_, F, Hal>>,
194 claims: &[PIOPSumcheckClaim<F>],
195 transcript: &mut ProverTranscript<Challenger_>,
196) -> Result<(), Error>
197where
198 F: TowerField,
199 FEncode: BinaryField,
200 P: PackedField<Scalar = F>
201 + PackedExtension<F, PackedSubfield = P>
202 + PackedExtension<FEncode>
203 + PackedFieldIndexable<Scalar = F>,
204 M: MultilinearPoly<P> + Send + Sync,
205 NTT: AdditiveNTT<FEncode> + Sync,
206 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
207 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
208 Challenger_: Challenger,
209 Hal: ComputeLayer<F>,
210 HostComputeAllocatorType: ComputeAllocator<F, CpuMemory>,
211 DeviceComputeAllocatorType: ComputeAllocator<F, Hal::DevMem>,
212{
213 let host_alloc = &compute_data.host_alloc;
214 let dev_alloc = &compute_data.dev_alloc;
215 let hal = compute_data.hal;
216
217 let sumcheck_claim_descs = make_sumcheck_claim_descs(
218 commit_meta,
219 transparent_multilins
220 .iter()
221 .map(|poly| checked_log_2(poly.len())),
222 claims,
223 )?;
224
225 let packed_committed_multilins = committed_multilins
229 .iter()
230 .enumerate()
231 .map(|(i, unpacked_committed)| {
232 packed_committed(OracleId::from_index(i), unpacked_committed)
233 .map(MLEDirectAdapter::from)
234 })
235 .collect::<Result<Vec<_>, _>>()?;
236
237 let copy_span = tracing::debug_span!(
238 "[task] Copy polynomials to device memory",
239 phase = "piop_compiler",
240 perfetto_category = "phase.sub",
241 )
242 .entered();
243 let packed_committed_fslices = packed_committed_multilins
244 .iter()
245 .map(|packed_committed_multilin| {
246 let hypercube_evals = packed_committed_multilin
247 .packed_evals()
248 .expect("Prover should always populate witnesses");
249 let unpacked_hypercube_evals = P::unpack_scalars(hypercube_evals);
250 let mut allocated_mem = dev_alloc.alloc(1 << packed_committed_multilin.n_vars())?;
251
252 hal.copy_h2d(
253 &unpacked_hypercube_evals[0..1 << packed_committed_multilin.n_vars()],
254 &mut allocated_mem,
255 )?;
256 Ok(Hal::DevMem::to_const(allocated_mem))
257 })
258 .collect::<Result<Vec<_>, Error>>()?;
259
260 drop(copy_span);
261
262 let non_empty_sumcheck_descs = sumcheck_claim_descs
263 .iter()
264 .enumerate()
265 .filter(|(_, desc)| !desc.committed_indices.is_empty());
269
270 let mut sumcheck_provers = vec![];
271
272 for (n_vars, desc) in non_empty_sumcheck_descs {
273 let multilins = chain!(
274 packed_committed_fslices[desc.committed_indices.clone()]
275 .iter()
276 .map(|fslice| Hal::DevMem::narrow(fslice)),
277 transparent_multilins[desc.transparent_indices.clone()]
278 .iter()
279 .map(|fslice| Hal::DevMem::narrow(fslice))
280 )
281 .collect::<Vec<_>>();
282
283 let claim = SumcheckClaim::new(n_vars, multilins.len(), desc.composite_sums.clone())?;
284
285 sumcheck_provers
286 .push(BivariateSumcheckProver::new(hal, dev_alloc, host_alloc, &claim, multilins)?);
287 }
288
289 prove_interleaved_fri_sumcheck(
290 hal,
291 dev_alloc,
292 commit_meta.total_vars(),
293 fri_params,
294 ntt,
295 merkle_prover,
296 sumcheck_provers,
297 codeword,
298 &committed,
299 transcript,
300 )?;
301
302 Ok(())
303}
304
305#[allow(clippy::too_many_arguments)]
306fn prove_interleaved_fri_sumcheck<Hal, F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
307 hal: &Hal,
308 dev_alloc: &impl ComputeAllocator<F, Hal::DevMem>,
309 n_rounds: usize,
310 fri_params: &FRIParams<F, FEncode>,
311 ntt: &NTT,
312 merkle_prover: &MTProver,
313 sumcheck_provers: Vec<impl SumcheckProver<F>>,
314 codeword: &[P],
315 committed: &MTProver::Committed,
316 transcript: &mut ProverTranscript<Challenger_>,
317) -> Result<(), Error>
318where
319 Hal: ComputeLayer<F>,
320 F: TowerField,
321 FEncode: BinaryField,
322 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
323 NTT: AdditiveNTT<FEncode> + Sync,
324 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
325 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
326 Challenger_: Challenger,
327{
328 let mut fri_prover = FRIFolder::new(hal, fri_params, ntt, merkle_prover, codeword, committed)?;
329
330 let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
331
332 for round in 0..n_rounds {
333 let _span =
334 tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
335 .entered();
336
337 let bivariate_sumcheck_span = tracing::debug_span!(
338 "[step] Bivariate Sumcheck",
339 phase = "piop_compiler",
340 round = round,
341 perfetto_category = "phase.sub"
342 )
343 .entered();
344 let provers_dimensions_data =
345 SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
346 let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
347 "[task] (PIOP Compiler) Calculate Coeffs",
348 phase = "piop_compiler",
349 round = round,
350 perfetto_category = "task.main",
351 dimensions_data = ?provers_dimensions_data,
352 )
353 .entered();
354
355 sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
356 drop(bivariate_sumcheck_calculate_coeffs_span);
357
358 let challenge = transcript.sample();
359 let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
360 "[task] (PIOP Compiler) Fold (All Rounds)",
361 phase = "piop_compiler",
362 round = round,
363 perfetto_category = "task.main",
364 dimensions_data = ?provers_dimensions_data,
365 )
366 .entered();
367 sumcheck_batch_prover.receive_challenge(challenge)?;
368 drop(bivariate_sumcheck_all_folds_span);
369 drop(bivariate_sumcheck_span);
370
371 let dimensions_data = FriFoldRoundsData::new(
372 round,
373 fri_params.log_batch_size(),
374 fri_prover.current_codeword_len(),
375 );
376 let fri_fold_rounds_span = tracing::debug_span!(
377 "[step] FRI Fold Rounds",
378 phase = "piop_compiler",
379 round = round,
380 perfetto_category = "phase.sub",
381 ?dimensions_data,
382 )
383 .entered();
384 match fri_prover.execute_fold_round(dev_alloc, challenge)? {
385 FoldRoundOutput::NoCommitment => {}
386 FoldRoundOutput::Commitment(round_commitment) => {
387 transcript.message().write(&round_commitment);
388 }
389 }
390 drop(fri_fold_rounds_span);
391 }
392
393 sumcheck_batch_prover.finish(&mut transcript.message())?;
394 fri_prover.finish_proof(transcript)?;
395 Ok(())
396}
397
398pub fn validate_sumcheck_witness<'a, F, P, M, Hal: ComputeLayer<F>>(
399 committed_multilins: &[M],
400 transparent_multilins: &[FSlice<'a, F, Hal>],
401 claims: &[PIOPSumcheckClaim<F>],
402 hal: &Hal,
403) -> Result<(), Error>
404where
405 F: TowerField,
406 P: PackedField<Scalar = F>,
407 M: MultilinearPoly<P> + Send + Sync,
408{
409 let packed_committed = committed_multilins
410 .iter()
411 .enumerate()
412 .map(|(i, unpacked_committed)| {
413 packed_committed(OracleId::from_index(i), unpacked_committed)
414 })
415 .collect::<Result<Vec<_>, _>>()?;
416
417 for (i, claim) in claims.iter().enumerate() {
418 let committed = &packed_committed[claim.committed];
419 if committed.n_vars() != claim.n_vars {
420 bail!(sumcheck::Error::NumberOfVariablesMismatch);
421 }
422
423 let transparent = &transparent_multilins[claim.transparent];
424 if transparent.len() != 1 << claim.n_vars {
425 bail!(sumcheck::Error::NumberOfVariablesMismatch);
426 }
427
428 let mut transparent_evals = zeroed_vec(transparent.len());
429
430 hal.copy_d2h(*transparent, &mut transparent_evals)?;
431
432 let sum = (0..(1 << claim.n_vars))
433 .into_par_iter()
434 .map(|j| {
435 let committed_eval = committed
436 .evaluate_on_hypercube(j)
437 .expect("j is less than 1 << n_vars; committed.n_vars is checked above");
438 let transparent_eval = transparent_evals
439 .get(j)
440 .expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
441 committed_eval * transparent_eval
442 })
443 .sum::<F>();
444
445 if sum != claim.sum {
446 bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
447 composition_index: i,
448 });
449 }
450 }
451 Ok(())
452}
453
454fn packed_committed<F, P, M>(
462 id: OracleId,
463 unpacked_committed: &M,
464) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
465where
466 F: TowerField,
467 P: PackedField<Scalar = F>,
468 M: MultilinearPoly<P>,
469{
470 let unpacked_n_vars = unpacked_committed.n_vars();
471 let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
472 let packed_eval = padded_packed_eval(unpacked_committed);
473 MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
474 } else {
475 let packed_evals = unpacked_committed
476 .packed_evals()
477 .ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
478
479 MultilinearExtension::new(
480 unpacked_n_vars - unpacked_committed.log_extension_degree(),
481 Cow::Borrowed(packed_evals),
482 )
483 }?;
484 Ok(packed_committed)
485}
486
487#[inline]
488fn padded_packed_eval<F, P, M>(multilin: &M) -> F
489where
490 F: TowerField,
491 P: PackedField<Scalar = F>,
492 M: MultilinearPoly<P>,
493{
494 let n_vars = multilin.n_vars();
495 let kappa = multilin.log_extension_degree();
496 assert!(n_vars < kappa);
497
498 (0..1 << kappa)
499 .map(|i| {
500 let iota = F::TOWER_LEVEL - kappa;
501 let scalar = <F as TowerField>::basis(iota, i)
502 .expect("i is in range 0..1 << log_extension_degree");
503 multilin
504 .evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
505 .expect("i is in range 0..1 << n_vars")
506 })
507 .sum()
508}
509
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::PackedBinaryField2x128b;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Packed field used throughout the tests.
    type Packed = PackedBinaryField2x128b;

    /// Merges random multilinears with 0 through 7 variables and checks that every multilinear
    /// appears in the buffer in reverse input order with its scalars in bit-reversed positions.
    #[test]
    fn test_merge_multilins() {
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = (0usize..8)
            .map(|n_vars| {
                // Multilinears smaller than one packed element still occupy a full element.
                let n_packed = 1 << n_vars.saturating_sub(Packed::LOG_WIDTH);
                let evals = repeat_with(|| Packed::random(&mut rng))
                    .take(n_packed)
                    .collect::<Vec<_>>();

                MultilinearExtension::new(n_vars, evals).unwrap()
            })
            .collect::<Vec<_>>();

        let total_scalars = (0..8).map(|i| 1usize << i).sum::<usize>();
        let mut buffer = vec![Packed::zero(); total_scalars.div_ceil(Packed::WIDTH)];
        merge_multilins(&multilins, &mut buffer);

        let merged = PackedField::iter_slice(&buffer)
            .take(total_scalars)
            .collect_vec();
        let mut offset = 0;
        for multilin in multilins.iter().rev() {
            let n_vars = multilin.n_vars();
            let window = &merged[offset..];
            for (i, expected) in PackedField::iter_slice(multilin.evals())
                .take(1 << n_vars)
                .enumerate()
            {
                assert_eq!(window[reverse_bits(i, n_vars)], expected);
            }
            offset += 1 << n_vars;
        }
    }
}