1use binius_field::{
4 packed::{get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked},
5 BinaryField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{
9 EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
10 MultilinearPoly,
11};
12use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
13use binius_ntt::{NTTOptions, ThreadingSettings};
14use binius_utils::{
15 bail, checked_arithmetics::checked_log_2, sorting::is_sorted_ascending, SerializeBytes,
16};
17use either::Either;
18use itertools::{chain, Itertools};
19
20use super::{
21 error::Error,
22 verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
23};
24use crate::{
25 fiat_shamir::{CanSample, Challenger},
26 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
27 piop::CommitMeta,
28 protocols::{
29 fri,
30 fri::{FRIFolder, FRIParams, FoldRoundOutput},
31 sumcheck,
32 sumcheck::{
33 immediate_switchover_heuristic,
34 prove::{
35 front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
36 SumcheckProver,
37 },
38 },
39 },
40 reed_solomon::reed_solomon::ReedSolomonCode,
41 transcript::ProverTranscript,
42};
43
44fn reverse_slice_index_bits<P: PackedField>(slice: &mut [P]) {
47 let log_len = checked_log_2(slice.len()) + P::LOG_WIDTH;
48 for i in 0..slice.len() << P::LOG_WIDTH {
49 let bit_reversed_index = i
50 .reverse_bits()
51 .wrapping_shr((usize::BITS as usize - log_len) as _);
52 if i < bit_reversed_index {
53 unsafe {
55 let tmp = get_packed_slice_unchecked(slice, i);
56 set_packed_slice_unchecked(
57 slice,
58 i,
59 get_packed_slice_unchecked(slice, bit_reversed_index),
60 );
61 set_packed_slice_unchecked(slice, bit_reversed_index, tmp);
62 }
63 }
64 }
65}
66
67fn merge_multilins<P, M>(multilins: &[M], message_buffer: &mut [P])
75where
76 P: PackedField,
77 M: MultilinearPoly<P>,
78{
79 let mut mle_iter = multilins.iter().rev();
80
81 let get_n_packed_vars = |mle: &M| mle.n_vars() - mle.log_extension_degree();
84 let mut full_packed_mles = Vec::new(); let mut remaining_buffer = message_buffer;
86 for mle in mle_iter.peeking_take_while(|mle| get_n_packed_vars(mle) >= P::LOG_WIDTH) {
87 let evals = mle
88 .packed_evals()
89 .expect("guaranteed by function precondition");
90 let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
91 full_packed_mles.push((evals, chunk));
92 remaining_buffer = rest;
93 }
94 full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
95 chunk.copy_from_slice(evals);
96 reverse_slice_index_bits(chunk);
97 });
98
99 let mut scalar_offset = 0;
102 for mle in mle_iter {
103 let evals = mle
104 .packed_evals()
105 .expect("guaranteed by function precondition");
106 let packed_eval = evals[0];
107 for i in 0..1 << mle.n_vars() {
108 set_packed_slice(remaining_buffer, scalar_offset, packed_eval.get(i));
109 scalar_offset += 1;
110 }
111 }
112}
113
114#[tracing::instrument("piop::commit", skip_all)]
129pub fn commit<F, FEncode, P, M, MTScheme, MTProver>(
130 fri_params: &FRIParams<F, FEncode>,
131 merkle_prover: &MTProver,
132 multilins: &[M],
133) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
134where
135 F: BinaryField,
136 FEncode: BinaryField,
137 P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
138 M: MultilinearPoly<P>,
139 MTScheme: MerkleTreeScheme<F>,
140 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
141{
142 for (i, multilin) in multilins.iter().enumerate() {
143 if multilin.n_vars() < multilin.log_extension_degree() {
144 return Err(Error::OracleTooSmall {
145 id: i,
148 n_vars: multilin.n_vars(),
149 min_vars: multilin.log_extension_degree(),
150 });
151 }
152 if multilin.packed_evals().is_none() {
153 return Err(Error::CommittedPackedEvaluationsMissing { id: i });
154 }
155 }
156
157 let n_packed_vars = multilins
158 .iter()
159 .map(|multilin| multilin.n_vars() - multilin.log_extension_degree());
160 if !is_sorted_ascending(n_packed_vars) {
161 return Err(Error::CommittedsNotSorted);
162 }
163
164 let rs_code = ReedSolomonCode::new(
166 fri_params.rs_code().log_dim(),
167 fri_params.rs_code().log_inv_rate(),
168 &NTTOptions {
169 precompute_twiddles: true,
170 thread_settings: ThreadingSettings::MultithreadedDefault,
171 },
172 )?;
173 let output =
174 fri::commit_interleaved_with(&rs_code, fri_params, merkle_prover, |message_buffer| {
175 merge_multilins(multilins, message_buffer)
176 })?;
177
178 Ok(output)
179}
180
181#[allow(clippy::too_many_arguments)]
186#[tracing::instrument("piop::prove", skip_all)]
187pub fn prove<F, FDomain, FEncode, P, M, DomainFactory, MTScheme, MTProver, Challenger_, Backend>(
188 fri_params: &FRIParams<F, FEncode>,
189 merkle_prover: &MTProver,
190 domain_factory: DomainFactory,
191 commit_meta: &CommitMeta,
192 committed: MTProver::Committed,
193 codeword: &[P],
194 committed_multilins: &[M],
195 transparent_multilins: &[M],
196 claims: &[PIOPSumcheckClaim<F>],
197 transcript: &mut ProverTranscript<Challenger_>,
198 backend: &Backend,
199) -> Result<(), Error>
200where
201 F: TowerField,
202 FDomain: Field,
203 FEncode: BinaryField,
204 P: PackedFieldIndexable<Scalar = F>
205 + PackedExtension<F, PackedSubfield = P>
206 + PackedExtension<FDomain>
207 + PackedExtension<FEncode>,
208 M: MultilinearPoly<P> + Send + Sync,
209 DomainFactory: EvaluationDomainFactory<FDomain>,
210 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
211 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
212 Challenger_: Challenger,
213 Backend: ComputationBackend,
214{
215 let sumcheck_claim_descs = make_sumcheck_claim_descs(
217 commit_meta,
218 transparent_multilins.iter().map(|poly| poly.n_vars()),
219 claims,
220 )?;
221
222 let packed_committed_multilins = committed_multilins
226 .iter()
227 .enumerate()
228 .map(|(i, committed_multilin)| {
229 let packed_evals = committed_multilin
230 .packed_evals()
231 .ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
232 let packed_multilin = MultilinearExtension::from_values_slice(packed_evals)?;
233 Ok::<_, Error>(MLEDirectAdapter::from(packed_multilin))
234 })
235 .collect::<Result<Vec<_>, _>>()?;
236
237 let non_empty_sumcheck_descs = sumcheck_claim_descs
238 .iter()
239 .enumerate()
240 .filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
241 let sumcheck_provers = non_empty_sumcheck_descs
242 .clone()
243 .map(|(_n_vars, desc)| {
244 let multilins = chain!(
245 packed_committed_multilins[desc.committed_indices.clone()]
246 .iter()
247 .map(Either::Left),
248 transparent_multilins[desc.transparent_indices.clone()]
249 .iter()
250 .map(Either::Right),
251 )
252 .collect::<Vec<_>>();
253 RegularSumcheckProver::new(
254 EvaluationOrder::HighToLow,
255 multilins,
256 desc.composite_sums.iter().cloned(),
257 &domain_factory,
258 immediate_switchover_heuristic,
259 backend,
260 )
261 })
262 .collect::<Result<Vec<_>, _>>()?;
263
264 prove_interleaved_fri_sumcheck(
265 commit_meta.total_vars(),
266 fri_params,
267 merkle_prover,
268 sumcheck_provers,
269 codeword,
270 &committed,
271 transcript,
272 )?;
273
274 Ok(())
275}
276
277fn prove_interleaved_fri_sumcheck<F, FEncode, P, MTScheme, MTProver, Challenger_>(
278 n_rounds: usize,
279 fri_params: &FRIParams<F, FEncode>,
280 merkle_prover: &MTProver,
281 sumcheck_provers: Vec<impl SumcheckProver<F>>,
282 codeword: &[P],
283 committed: &MTProver::Committed,
284 transcript: &mut ProverTranscript<Challenger_>,
285) -> Result<(), Error>
286where
287 F: TowerField,
288 FEncode: BinaryField,
289 P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
290 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
291 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
292 Challenger_: Challenger,
293{
294 let mut fri_prover =
295 FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?;
296
297 let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
298
299 for _ in 0..n_rounds {
300 sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
301 let challenge = transcript.sample();
302 sumcheck_batch_prover.receive_challenge(challenge)?;
303
304 match fri_prover.execute_fold_round(challenge)? {
305 FoldRoundOutput::NoCommitment => {}
306 FoldRoundOutput::Commitment(round_commitment) => {
307 transcript.message().write(&round_commitment);
308 }
309 }
310 }
311
312 sumcheck_batch_prover.finish(&mut transcript.message())?;
313 fri_prover.finish_proof(transcript)?;
314 Ok(())
315}
316
317pub fn validate_sumcheck_witness<F, P, M>(
318 committed_multilins: &[M],
319 transparent_multilins: &[M],
320 claims: &[PIOPSumcheckClaim<F>],
321) -> Result<(), Error>
322where
323 F: TowerField,
324 P: PackedField<Scalar = F>,
325 M: MultilinearPoly<P> + Send + Sync,
326{
327 let packed_committed = committed_multilins
328 .iter()
329 .enumerate()
330 .map(|(i, unpacked_committed)| {
331 let packed_evals = unpacked_committed
332 .packed_evals()
333 .ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
334 let packed_committed = MultilinearExtension::from_values_slice(packed_evals)?;
335 Ok::<_, Error>(packed_committed)
336 })
337 .collect::<Result<Vec<_>, _>>()?;
338
339 for (i, claim) in claims.iter().enumerate() {
340 let committed = &packed_committed[claim.committed];
341 if committed.n_vars() != claim.n_vars {
342 bail!(sumcheck::Error::NumberOfVariablesMismatch);
343 }
344
345 let transparent = &transparent_multilins[claim.transparent];
346 if transparent.n_vars() != claim.n_vars {
347 bail!(sumcheck::Error::NumberOfVariablesMismatch);
348 }
349
350 let sum = (0..(1 << claim.n_vars))
351 .into_par_iter()
352 .map(|j| {
353 let committed_eval = committed
354 .evaluate_on_hypercube(j)
355 .expect("j is less than 1 << n_vars; committed.n_vars is checked above");
356 let transparent_eval = transparent
357 .evaluate_on_hypercube(j)
358 .expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
359 committed_eval * transparent_eval
360 })
361 .sum::<F>();
362
363 if sum != claim.sum {
364 bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
365 composition_index: i,
366 });
367 }
368 }
369 Ok(())
370}