1use std::{borrow::Cow, ops::Deref};
4
5use binius_field::{
6 packed::PackedSliceMut, BinaryField, Field, PackedExtension, PackedField, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::{
10 EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
11 MultilinearPoly,
12};
13use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
14use binius_ntt::AdditiveNTT;
15use binius_utils::{
16 bail,
17 checked_arithmetics::checked_log_2,
18 random_access_sequence::{RandomAccessSequenceMut, SequenceSubrangeMut},
19 sorting::is_sorted_ascending,
20 SerializeBytes,
21};
22use either::Either;
23use itertools::{chain, Itertools};
24
25use super::{
26 error::Error,
27 verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
28};
29use crate::{
30 fiat_shamir::{CanSample, Challenger},
31 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
32 piop::{
33 logging::{FriFoldRoundsData, SumcheckBatchProverDimensionsData},
34 CommitMeta,
35 },
36 protocols::{
37 fri,
38 fri::{FRIFolder, FRIParams, FoldRoundOutput},
39 sumcheck,
40 sumcheck::{
41 immediate_switchover_heuristic,
42 prove::{
43 front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
44 SumcheckProver,
45 },
46 },
47 },
48 transcript::ProverTranscript,
49};
50
/// Reverses the low `log_len` bits of `x`; bits of `x` at or above position
/// `log_len` are discarded.
///
/// For `log_len == 0` the result is always `0`. The shift amount is then
/// `usize::BITS`, which `checked_shr` rejects, so we fall back to 0 instead of
/// relying on `wrapping_shr` masking the shift down to 0. That masking would
/// return the fully reversed word.
#[inline(always)]
fn reverse_bits(x: usize, log_len: usize) -> usize {
    debug_assert!(log_len <= usize::BITS as usize);
    x.reverse_bits()
        .checked_shr(usize::BITS - log_len as u32)
        .unwrap_or(0)
}
56
57fn reverse_index_bits<T: Copy>(collection: &mut impl RandomAccessSequenceMut<T>) {
60 let log_len = checked_log_2(collection.len());
61 for i in 0..collection.len() {
62 let bit_reversed_index = reverse_bits(i, log_len);
63 if i < bit_reversed_index {
64 unsafe {
66 let tmp = collection.get_unchecked(i);
67 collection.set_unchecked(i, collection.get_unchecked(bit_reversed_index));
68 collection.set_unchecked(bit_reversed_index, tmp);
69 }
70 }
71 }
72}
73
74fn merge_multilins<F, P, Data>(
82 multilins: &[MultilinearExtension<P, Data>],
83 message_buffer: &mut [P],
84) where
85 F: TowerField,
86 P: PackedField<Scalar = F>,
87 Data: Deref<Target = [P]>,
88{
89 let mut mle_iter = multilins.iter().rev();
90
91 let mut full_packed_mles = Vec::new(); let mut remaining_buffer = message_buffer;
95 for mle in mle_iter.peeking_take_while(|mle| mle.n_vars() >= P::LOG_WIDTH) {
96 let evals = mle.evals();
97 let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
98 full_packed_mles.push((evals, chunk));
99 remaining_buffer = rest;
100 }
101 full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
102 chunk.copy_from_slice(evals);
103 reverse_index_bits(&mut PackedSliceMut::new(chunk));
104 });
105
106 let mut scalar_offset = 0;
109 let mut remaining_buffer = PackedSliceMut::new(remaining_buffer);
110 for mle in mle_iter {
111 let packed_eval = mle.evals()[0];
112 let len = 1 << mle.n_vars();
113 let mut packed_chunk = SequenceSubrangeMut::new(&mut remaining_buffer, scalar_offset, len);
114 for i in 0..len {
115 packed_chunk.set(i, packed_eval.get(i));
116 }
117 reverse_index_bits(&mut packed_chunk);
118
119 scalar_offset += len;
120 }
121}
122
123pub fn commit<F, FEncode, P, M, NTT, MTScheme, MTProver>(
138 fri_params: &FRIParams<F, FEncode>,
139 ntt: &NTT,
140 merkle_prover: &MTProver,
141 multilins: &[M],
142) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
143where
144 F: TowerField,
145 FEncode: BinaryField,
146 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
147 M: MultilinearPoly<P>,
148 NTT: AdditiveNTT<FEncode> + Sync,
149 MTScheme: MerkleTreeScheme<F>,
150 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
151{
152 let packed_multilins = multilins
153 .iter()
154 .enumerate()
155 .map(|(i, unpacked_committed)| packed_committed(i, unpacked_committed))
156 .collect::<Result<Vec<_>, _>>()?;
157 if !is_sorted_ascending(packed_multilins.iter().map(|mle| mle.n_vars())) {
158 return Err(Error::CommittedsNotSorted);
159 }
160
161 let output = fri::commit_interleaved_with(fri_params, ntt, merkle_prover, |message_buffer| {
162 merge_multilins(&packed_multilins, message_buffer)
163 })?;
164
165 Ok(output)
166}
167
168#[allow(clippy::too_many_arguments)]
173pub fn prove<
174 F,
175 FDomain,
176 FEncode,
177 P,
178 M,
179 NTT,
180 DomainFactory,
181 MTScheme,
182 MTProver,
183 Challenger_,
184 Backend,
185>(
186 fri_params: &FRIParams<F, FEncode>,
187 ntt: &NTT,
188 merkle_prover: &MTProver,
189 domain_factory: DomainFactory,
190 commit_meta: &CommitMeta,
191 committed: MTProver::Committed,
192 codeword: &[P],
193 committed_multilins: &[M],
194 transparent_multilins: &[M],
195 claims: &[PIOPSumcheckClaim<F>],
196 transcript: &mut ProverTranscript<Challenger_>,
197 backend: &Backend,
198) -> Result<(), Error>
199where
200 F: TowerField,
201 FDomain: Field,
202 FEncode: BinaryField,
203 P: PackedField<Scalar = F>
204 + PackedExtension<F, PackedSubfield = P>
205 + PackedExtension<FDomain>
206 + PackedExtension<FEncode>,
207 M: MultilinearPoly<P> + Send + Sync,
208 NTT: AdditiveNTT<FEncode> + Sync,
209 DomainFactory: EvaluationDomainFactory<FDomain>,
210 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212 Challenger_: Challenger,
213 Backend: ComputationBackend,
214{
215 let sumcheck_claim_descs = make_sumcheck_claim_descs(
217 commit_meta,
218 transparent_multilins.iter().map(|poly| poly.n_vars()),
219 claims,
220 )?;
221
222 let packed_committed_multilins = committed_multilins
226 .iter()
227 .enumerate()
228 .map(|(i, unpacked_committed)| {
229 packed_committed(i, unpacked_committed).map(MLEDirectAdapter::from)
230 })
231 .collect::<Result<Vec<_>, _>>()?;
232
233 let non_empty_sumcheck_descs = sumcheck_claim_descs
234 .iter()
235 .enumerate()
236 .filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
240 let sumcheck_provers = non_empty_sumcheck_descs
241 .map(|(_n_vars, desc)| {
242 let multilins = chain!(
243 packed_committed_multilins[desc.committed_indices.clone()]
244 .iter()
245 .map(Either::Left),
246 transparent_multilins[desc.transparent_indices.clone()]
247 .iter()
248 .map(Either::Right),
249 )
250 .collect::<Vec<_>>();
251 RegularSumcheckProver::new(
252 EvaluationOrder::HighToLow,
253 multilins,
254 desc.composite_sums.iter().cloned(),
255 &domain_factory,
256 immediate_switchover_heuristic,
257 backend,
258 )
259 })
260 .collect::<Result<Vec<_>, _>>()?;
261
262 prove_interleaved_fri_sumcheck(
263 commit_meta.total_vars(),
264 fri_params,
265 ntt,
266 merkle_prover,
267 sumcheck_provers,
268 codeword,
269 &committed,
270 transcript,
271 )?;
272
273 Ok(())
274}
275
276#[allow(clippy::too_many_arguments)]
277fn prove_interleaved_fri_sumcheck<F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
278 n_rounds: usize,
279 fri_params: &FRIParams<F, FEncode>,
280 ntt: &NTT,
281 merkle_prover: &MTProver,
282 sumcheck_provers: Vec<impl SumcheckProver<F>>,
283 codeword: &[P],
284 committed: &MTProver::Committed,
285 transcript: &mut ProverTranscript<Challenger_>,
286) -> Result<(), Error>
287where
288 F: TowerField,
289 FEncode: BinaryField,
290 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
291 NTT: AdditiveNTT<FEncode> + Sync,
292 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
293 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
294 Challenger_: Challenger,
295{
296 let mut fri_prover = FRIFolder::new(fri_params, ntt, merkle_prover, codeword, committed)?;
297
298 let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
299
300 for round in 0..n_rounds {
301 let _span =
302 tracing::debug_span!("PIOP Compiler Round", phase = "piop_compiler", round = round)
303 .entered();
304
305 let bivariate_sumcheck_span = tracing::debug_span!(
306 "[step] Bivariate Sumcheck",
307 phase = "piop_compiler",
308 round = round,
309 perfetto_category = "phase.sub"
310 )
311 .entered();
312 let provers_dimensions_data =
313 SumcheckBatchProverDimensionsData::new(round, sumcheck_batch_prover.provers());
314 let bivariate_sumcheck_calculate_coeffs_span = tracing::debug_span!(
315 "[task] (PIOP Compiler) Calculate Coeffs",
316 phase = "piop_compiler",
317 round = round,
318 perfetto_category = "task.main",
319 dimensions_data = ?provers_dimensions_data,
320 )
321 .entered();
322 sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
323 drop(bivariate_sumcheck_calculate_coeffs_span);
324
325 let challenge = transcript.sample();
326 let bivariate_sumcheck_all_folds_span = tracing::debug_span!(
327 "[task] (PIOP Compiler) Fold (All Rounds)",
328 phase = "piop_compiler",
329 round = round,
330 perfetto_category = "task.main",
331 dimensions_data = ?provers_dimensions_data,
332 )
333 .entered();
334 sumcheck_batch_prover.receive_challenge(challenge)?;
335 drop(bivariate_sumcheck_all_folds_span);
336 drop(bivariate_sumcheck_span);
337
338 let dimensions_data = FriFoldRoundsData::new(
339 round,
340 fri_params.log_batch_size(),
341 fri_prover.current_codeword_len(),
342 );
343 let fri_fold_rounds_span = tracing::debug_span!(
344 "[step] FRI Fold Rounds",
345 phase = "piop_compiler",
346 round = round,
347 perfetto_category = "phase.sub",
348 dimensions_data = ?dimensions_data,
349 )
350 .entered();
351 match fri_prover.execute_fold_round(challenge)? {
352 FoldRoundOutput::NoCommitment => {}
353 FoldRoundOutput::Commitment(round_commitment) => {
354 transcript.message().write(&round_commitment);
355 }
356 }
357 drop(fri_fold_rounds_span);
358 }
359
360 sumcheck_batch_prover.finish(&mut transcript.message())?;
361 fri_prover.finish_proof(transcript)?;
362 Ok(())
363}
364
365pub fn validate_sumcheck_witness<F, P, M>(
366 committed_multilins: &[M],
367 transparent_multilins: &[M],
368 claims: &[PIOPSumcheckClaim<F>],
369) -> Result<(), Error>
370where
371 F: TowerField,
372 P: PackedField<Scalar = F>,
373 M: MultilinearPoly<P> + Send + Sync,
374{
375 let packed_committed = committed_multilins
376 .iter()
377 .enumerate()
378 .map(|(i, unpacked_committed)| packed_committed(i, unpacked_committed))
379 .collect::<Result<Vec<_>, _>>()?;
380
381 for (i, claim) in claims.iter().enumerate() {
382 let committed = &packed_committed[claim.committed];
383 if committed.n_vars() != claim.n_vars {
384 bail!(sumcheck::Error::NumberOfVariablesMismatch);
385 }
386
387 let transparent = &transparent_multilins[claim.transparent];
388 if transparent.n_vars() != claim.n_vars {
389 bail!(sumcheck::Error::NumberOfVariablesMismatch);
390 }
391
392 let sum = (0..(1 << claim.n_vars))
393 .into_par_iter()
394 .map(|j| {
395 let committed_eval = committed
396 .evaluate_on_hypercube(j)
397 .expect("j is less than 1 << n_vars; committed.n_vars is checked above");
398 let transparent_eval = transparent
399 .evaluate_on_hypercube(j)
400 .expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
401 committed_eval * transparent_eval
402 })
403 .sum::<F>();
404
405 if sum != claim.sum {
406 bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
407 composition_index: i,
408 });
409 }
410 }
411 Ok(())
412}
413
414fn packed_committed<F, P, M>(
422 id: usize,
423 unpacked_committed: &M,
424) -> Result<MultilinearExtension<P, Cow<'_, [P]>>, Error>
425where
426 F: TowerField,
427 P: PackedField<Scalar = F>,
428 M: MultilinearPoly<P>,
429{
430 let unpacked_n_vars = unpacked_committed.n_vars();
431 let packed_committed = if unpacked_n_vars < unpacked_committed.log_extension_degree() {
432 let packed_eval = padded_packed_eval(unpacked_committed);
433 MultilinearExtension::new(0, Cow::Owned(vec![P::set_single(packed_eval)]))
434 } else {
435 let packed_evals = unpacked_committed
436 .packed_evals()
437 .ok_or(Error::CommittedPackedEvaluationsMissing { id })?;
438
439 MultilinearExtension::new(
440 unpacked_n_vars - unpacked_committed.log_extension_degree(),
441 Cow::Borrowed(packed_evals),
442 )
443 }?;
444 Ok(packed_committed)
445}
446
447#[inline]
448fn padded_packed_eval<F, P, M>(multilin: &M) -> F
449where
450 F: TowerField,
451 P: PackedField<Scalar = F>,
452 M: MultilinearPoly<P>,
453{
454 let n_vars = multilin.n_vars();
455 let kappa = multilin.log_extension_degree();
456 assert!(n_vars < kappa);
457
458 (0..1 << kappa)
459 .map(|i| {
460 let iota = F::TOWER_LEVEL - kappa;
461 let scalar = <F as TowerField>::basis(iota, i)
462 .expect("i is in range 0..1 << log_extension_degree");
463 multilin
464 .evaluate_on_hypercube_and_scale(i % (1 << n_vars), scalar)
465 .expect("i is in range 0..1 << n_vars")
466 })
467 .sum()
468}
469
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::PackedBinaryField2x128b;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    type P = PackedBinaryField2x128b;

    /// `merge_multilins` must place multilinears of 0..8 variables largest
    /// first, each in bit-reversed order. This includes those smaller than
    /// one packed element.
    #[test]
    fn test_merge_multilins() {
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = (0usize..8)
            .map(|n_vars| {
                let n_packed = 1 << n_vars.saturating_sub(P::LOG_WIDTH);
                let data = repeat_with(|| P::random(&mut rng))
                    .take(n_packed)
                    .collect::<Vec<_>>();
                MultilinearExtension::new(n_vars, data).unwrap()
            })
            .collect::<Vec<_>>();

        let total_scalars = (0..8).map(|i| 1usize << i).sum::<usize>();
        let mut buffer = vec![P::zero(); total_scalars.div_ceil(P::WIDTH)];
        merge_multilins(&multilins, &mut buffer);

        let merged = PackedField::iter_slice(&buffer)
            .take(total_scalars)
            .collect_vec();

        let mut offset = 0;
        for multilin in multilins.iter().rev() {
            let n_vars = multilin.n_vars();
            let window = &merged[offset..];
            let originals = PackedField::iter_slice(multilin.evals()).take(1 << n_vars);
            for (i, expected) in originals.enumerate() {
                assert_eq!(window[reverse_bits(i, n_vars)], expected);
            }
            offset += 1 << n_vars;
        }
    }
}
510}