1use binius_field::{
4 packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField,
5 PackedFieldIndexable, TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{
9 EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
10 MultilinearPoly,
11};
12use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*};
13use binius_ntt::{NTTOptions, ThreadingSettings};
14use binius_utils::{bail, sorting::is_sorted_ascending, SerializeBytes};
15use either::Either;
16use itertools::{chain, Itertools};
17
18use super::{
19 error::Error,
20 verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim},
21};
22use crate::{
23 fiat_shamir::{CanSample, Challenger},
24 merkle_tree::{MerkleTreeProver, MerkleTreeScheme},
25 piop::CommitMeta,
26 protocols::{
27 fri,
28 fri::{FRIFolder, FRIParams, FoldRoundOutput},
29 sumcheck,
30 sumcheck::{
31 immediate_switchover_heuristic,
32 prove::{
33 front_loaded::BatchProver as SumcheckBatchProver, RegularSumcheckProver,
34 SumcheckProver,
35 },
36 },
37 },
38 reed_solomon::reed_solomon::ReedSolomonCode,
39 transcript::ProverTranscript,
40};
41
42fn merge_multilins<P, M>(multilins: &[M], message_buffer: &mut [P])
50where
51 P: PackedField,
52 M: MultilinearPoly<P>,
53{
54 let mut mle_iter = multilins.iter().rev();
55
56 let get_n_packed_vars = |mle: &M| mle.n_vars() - mle.log_extension_degree();
59 let mut full_packed_mles = Vec::new(); let mut remaining_buffer = message_buffer;
61 for mle in mle_iter.peeking_take_while(|mle| get_n_packed_vars(mle) >= P::LOG_WIDTH) {
62 let evals = mle
63 .packed_evals()
64 .expect("guaranteed by function precondition");
65 let (chunk, rest) = remaining_buffer.split_at_mut(evals.len());
66 full_packed_mles.push((evals, chunk));
67 remaining_buffer = rest;
68 }
69 full_packed_mles.into_par_iter().for_each(|(evals, chunk)| {
70 chunk.copy_from_slice(evals);
71 });
72
73 let mut scalar_offset = 0;
76 for mle in mle_iter {
77 let evals = mle
78 .packed_evals()
79 .expect("guaranteed by function precondition");
80 let packed_eval = evals[0];
81 for i in 0..1 << mle.n_vars() {
82 set_packed_slice(remaining_buffer, scalar_offset, packed_eval.get(i));
83 scalar_offset += 1;
84 }
85 }
86}
87
88#[tracing::instrument("piop::commit", skip_all)]
103pub fn commit<F, FEncode, P, M, MTScheme, MTProver>(
104 fri_params: &FRIParams<F, FEncode>,
105 merkle_prover: &MTProver,
106 multilins: &[M],
107) -> Result<fri::CommitOutput<P, MTScheme::Digest, MTProver::Committed>, Error>
108where
109 F: BinaryField,
110 FEncode: BinaryField,
111 P: PackedField<Scalar = F> + PackedExtension<FEncode>,
112 M: MultilinearPoly<P>,
113 MTScheme: MerkleTreeScheme<F>,
114 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
115{
116 for (i, multilin) in multilins.iter().enumerate() {
117 if multilin.n_vars() < multilin.log_extension_degree() {
118 return Err(Error::OracleTooSmall {
119 id: i,
122 min_vars: multilin.log_extension_degree(),
123 });
124 }
125 if multilin.packed_evals().is_none() {
126 return Err(Error::CommittedPackedEvaluationsMissing { id: i });
127 }
128 }
129
130 let n_packed_vars = multilins
131 .iter()
132 .map(|multilin| multilin.n_vars() - multilin.log_extension_degree());
133 if !is_sorted_ascending(n_packed_vars) {
134 return Err(Error::CommittedsNotSorted);
135 }
136
137 let rs_code = ReedSolomonCode::new(
139 fri_params.rs_code().log_dim(),
140 fri_params.rs_code().log_inv_rate(),
141 &NTTOptions {
142 precompute_twiddles: true,
143 thread_settings: ThreadingSettings::MultithreadedDefault,
144 },
145 )?;
146 let output =
147 fri::commit_interleaved_with(&rs_code, fri_params, merkle_prover, |message_buffer| {
148 merge_multilins(multilins, message_buffer)
149 })?;
150
151 Ok(output)
152}
153
154#[allow(clippy::too_many_arguments)]
159#[tracing::instrument("piop::prove", skip_all)]
160pub fn prove<F, FDomain, FEncode, P, M, DomainFactory, MTScheme, MTProver, Challenger_, Backend>(
161 fri_params: &FRIParams<F, FEncode>,
162 merkle_prover: &MTProver,
163 domain_factory: DomainFactory,
164 commit_meta: &CommitMeta,
165 committed: MTProver::Committed,
166 codeword: &[P],
167 committed_multilins: &[M],
168 transparent_multilins: &[M],
169 claims: &[PIOPSumcheckClaim<F>],
170 transcript: &mut ProverTranscript<Challenger_>,
171 backend: &Backend,
172) -> Result<(), Error>
173where
174 F: TowerField,
175 FDomain: Field,
176 FEncode: BinaryField,
177 P: PackedFieldIndexable<Scalar = F>
178 + PackedExtension<F, PackedSubfield = P>
179 + PackedExtension<FDomain>
180 + PackedExtension<FEncode>,
181 M: MultilinearPoly<P> + Send + Sync,
182 DomainFactory: EvaluationDomainFactory<FDomain>,
183 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
184 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
185 Challenger_: Challenger,
186 Backend: ComputationBackend,
187{
188 let sumcheck_claim_descs = make_sumcheck_claim_descs(
190 commit_meta,
191 transparent_multilins.iter().map(|poly| poly.n_vars()),
192 claims,
193 )?;
194
195 let packed_committed_multilins = committed_multilins
199 .iter()
200 .enumerate()
201 .map(|(i, committed_multilin)| {
202 let packed_evals = committed_multilin
203 .packed_evals()
204 .ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
205 let packed_multilin = MultilinearExtension::from_values_slice(packed_evals)?;
206 Ok::<_, Error>(MLEDirectAdapter::from(packed_multilin))
207 })
208 .collect::<Result<Vec<_>, _>>()?;
209
210 let non_empty_sumcheck_descs = sumcheck_claim_descs
211 .iter()
212 .enumerate()
213 .filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
214 let sumcheck_provers = non_empty_sumcheck_descs
215 .clone()
216 .map(|(_n_vars, desc)| {
217 let multilins = chain!(
218 packed_committed_multilins[desc.committed_indices.clone()]
219 .iter()
220 .map(Either::Left),
221 transparent_multilins[desc.transparent_indices.clone()]
222 .iter()
223 .map(Either::Right),
224 )
225 .collect::<Vec<_>>();
226 RegularSumcheckProver::new(
227 EvaluationOrder::LowToHigh,
228 multilins,
229 desc.composite_sums.iter().cloned(),
230 &domain_factory,
231 immediate_switchover_heuristic,
232 backend,
233 )
234 })
235 .collect::<Result<Vec<_>, _>>()?;
236
237 prove_interleaved_fri_sumcheck(
238 commit_meta.total_vars(),
239 fri_params,
240 merkle_prover,
241 sumcheck_provers,
242 codeword,
243 &committed,
244 transcript,
245 )?;
246
247 Ok(())
248}
249
250fn prove_interleaved_fri_sumcheck<F, FEncode, P, MTScheme, MTProver, Challenger_>(
251 n_rounds: usize,
252 fri_params: &FRIParams<F, FEncode>,
253 merkle_prover: &MTProver,
254 sumcheck_provers: Vec<impl SumcheckProver<F>>,
255 codeword: &[P],
256 committed: &MTProver::Committed,
257 transcript: &mut ProverTranscript<Challenger_>,
258) -> Result<(), Error>
259where
260 F: TowerField,
261 FEncode: BinaryField,
262 P: PackedFieldIndexable<Scalar = F> + PackedExtension<FEncode>,
263 MTScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
264 MTProver: MerkleTreeProver<F, Scheme = MTScheme>,
265 Challenger_: Challenger,
266{
267 let mut fri_prover =
268 FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?;
269
270 let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?;
271
272 for _ in 0..n_rounds {
273 sumcheck_batch_prover.send_round_proof(&mut transcript.message())?;
274 let challenge = transcript.sample();
275 sumcheck_batch_prover.receive_challenge(challenge)?;
276
277 match fri_prover.execute_fold_round(challenge)? {
278 FoldRoundOutput::NoCommitment => {}
279 FoldRoundOutput::Commitment(round_commitment) => {
280 transcript.message().write(&round_commitment);
281 }
282 }
283 }
284
285 sumcheck_batch_prover.finish(&mut transcript.message())?;
286 fri_prover.finish_proof(transcript)?;
287 Ok(())
288}
289
290pub fn validate_sumcheck_witness<F, P, M>(
291 committed_multilins: &[M],
292 transparent_multilins: &[M],
293 claims: &[PIOPSumcheckClaim<F>],
294) -> Result<(), Error>
295where
296 F: TowerField,
297 P: PackedField<Scalar = F>,
298 M: MultilinearPoly<P> + Send + Sync,
299{
300 let packed_committed = committed_multilins
301 .iter()
302 .enumerate()
303 .map(|(i, unpacked_committed)| {
304 let packed_evals = unpacked_committed
305 .packed_evals()
306 .ok_or(Error::CommittedPackedEvaluationsMissing { id: i })?;
307 let packed_committed = MultilinearExtension::from_values_slice(packed_evals)?;
308 Ok::<_, Error>(packed_committed)
309 })
310 .collect::<Result<Vec<_>, _>>()?;
311
312 for (i, claim) in claims.iter().enumerate() {
313 let committed = &packed_committed[claim.committed];
314 if committed.n_vars() != claim.n_vars {
315 bail!(sumcheck::Error::NumberOfVariablesMismatch);
316 }
317
318 let transparent = &transparent_multilins[claim.transparent];
319 if transparent.n_vars() != claim.n_vars {
320 bail!(sumcheck::Error::NumberOfVariablesMismatch);
321 }
322
323 let sum = (0..(1 << claim.n_vars))
324 .into_par_iter()
325 .map(|j| {
326 let committed_eval = committed
327 .evaluate_on_hypercube(j)
328 .expect("j is less than 1 << n_vars; committed.n_vars is checked above");
329 let transparent_eval = transparent
330 .evaluate_on_hypercube(j)
331 .expect("j is less than 1 << n_vars; transparent.n_vars is checked above");
332 committed_eval * transparent_eval
333 })
334 .sum::<F>();
335
336 if sum != claim.sum {
337 bail!(sumcheck::Error::SumcheckNaiveValidationFailure {
338 composition_index: i,
339 });
340 }
341 }
342 Ok(())
343}