1use std::collections::HashSet;
4
5use binius_field::{PackedField, TowerField};
6use binius_hal::ComputationBackend;
7use binius_math::MultilinearExtension;
8use binius_maybe_rayon::prelude::*;
9use getset::{Getters, MutGetters};
10use itertools::chain;
11use tracing::instrument;
12
13use super::{
14 error::Error,
15 evalcheck::{EvalcheckHint, EvalcheckMultilinearClaim},
16 serialize_evalcheck_proof,
17 subclaims::{
18 add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_mlecheck_meta,
19 fill_eq_witness_for_composites, MemoizedData, ProjectedBivariateMeta, SumcheckClaims,
20 },
21 EvalPoint, EvalPointOracleIdMap,
22};
23use crate::{
24 fiat_shamir::Challenger,
25 oracle::{
26 ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
27 MultilinearPolyOracle, MultilinearPolyVariant, OracleId,
28 },
29 protocols::evalcheck::{
30 logging::MLEFoldHighDimensionsData,
31 subclaims::{
32 packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
33 shifted_sumcheck_meta, CompositeMLECheckMeta,
34 },
35 },
36 transcript::ProverTranscript,
37 witness::MultilinearExtensionIndex,
38};
39
40#[derive(Getters, MutGetters)]
46pub struct EvalcheckProver<'a, 'b, F, P, Backend>
47where
48 P: PackedField<Scalar = F>,
49 F: TowerField,
50 Backend: ComputationBackend,
51{
52 pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
54 pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
56
57 #[getset(get = "pub", get_mut = "pub")]
59 committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
60
61 claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
63 claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
65 sumcheck_claims: Vec<SumcheckClaims<P::Scalar>>,
67
68 new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
70 pub memoized_data: MemoizedData<'b, P, Backend>,
72 backend: &'a Backend,
73
74 claim_to_index: EvalPointOracleIdMap<usize, F>,
76 visited_claims: EvalPointOracleIdMap<(), F>,
78 evals_memoization: EvalPointOracleIdMap<F, F>,
80 round_claim_index: usize,
82}
83
84impl<'a, 'b, F, P, Backend> EvalcheckProver<'a, 'b, F, P, Backend>
85where
86 P: PackedField<Scalar = F>,
87 F: TowerField,
88 Backend: ComputationBackend,
89{
90 pub fn new(
94 oracles: &'a mut MultilinearOracleSet<F>,
95 witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
96 backend: &'a Backend,
97 ) -> Self {
98 Self {
99 oracles,
100 witness_index,
101 committed_eval_claims: Vec::new(),
102 new_sumchecks_constraints: Vec::new(),
103 claims_queue: Vec::new(),
104 claims_without_evals: Vec::new(),
105 sumcheck_claims: Vec::new(),
106 memoized_data: MemoizedData::new(),
107 backend,
108
109 claim_to_index: EvalPointOracleIdMap::new(),
110 visited_claims: EvalPointOracleIdMap::new(),
111 evals_memoization: EvalPointOracleIdMap::new(),
112 round_claim_index: 0,
113 }
114 }
115
116 pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
118 self.new_sumchecks_constraints
119 .iter_mut()
120 .map(|builder| std::mem::take(builder).build_one(self.oracles))
121 .filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
122 .collect()
123 }
124
125 pub fn prove<Challenger_: Challenger>(
140 &mut self,
141 evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
142 transcript: &mut ProverTranscript<Challenger_>,
143 ) -> Result<(), Error> {
144 self.round_claim_index = 0;
146 self.visited_claims.clear();
147 self.claim_to_index.clear();
148 self.evals_memoization.clear();
149
150 for claim in &evalcheck_claims {
151 if self
152 .evals_memoization
153 .get(claim.id, &claim.eval_point)
154 .is_some()
155 {
156 continue;
157 }
158
159 self.evals_memoization
160 .insert(claim.id, claim.eval_point.clone(), claim.eval);
161 }
162
163 self.claims_queue.extend(evalcheck_claims.clone());
164
165 let mle_fold_full_span = tracing::debug_span!(
168 "[task] MLE Fold Full",
169 phase = "evalcheck",
170 perfetto_category = "task.main"
171 )
172 .entered();
173 while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
174 while !self.claims_queue.is_empty() {
175 std::mem::take(&mut self.claims_queue)
176 .into_iter()
177 .for_each(|claim| self.collect_subclaims_for_memoization(claim));
178 }
179
180 let mut deduplicated_claims_without_evals = HashSet::new();
181
182 for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
183 if self.evals_memoization.get(poly.id(), &eval_point).is_some() {
184 continue;
185 }
186
187 deduplicated_claims_without_evals.insert((poly.id(), eval_point.clone()));
188 }
189
190 let deduplicated_eval_points = deduplicated_claims_without_evals
191 .iter()
192 .map(|(_, eval_point)| eval_point.as_ref())
193 .collect::<Vec<_>>();
194
195 self.memoized_data
197 .memoize_query_par(deduplicated_eval_points.iter().copied(), self.backend)?;
198
199 let subclaims = deduplicated_claims_without_evals
201 .into_par_iter()
202 .map(|(id, eval_point)| {
203 Self::make_new_eval_claim(
204 id,
205 eval_point,
206 self.witness_index,
207 &self.memoized_data,
208 )
209 })
210 .collect::<Result<Vec<_>, Error>>()?;
211
212 for subclaim in &subclaims {
213 self.evals_memoization.insert(
214 subclaim.id,
215 subclaim.eval_point.clone(),
216 subclaim.eval,
217 );
218 }
219
220 subclaims
221 .into_iter()
222 .for_each(|claim| self.collect_subclaims_for_memoization(claim));
223 }
224 drop(mle_fold_full_span);
225
226 for claim in evalcheck_claims {
229 self.prove_multilinear(claim, transcript)?;
230 }
231
232 let mut projected_bivariate_metas = Vec::new();
234 let mut composite_mle_metas = Vec::new();
235 let mut projected_bivariate_claims = Vec::new();
236 let mut composite_mle_claims = Vec::new();
237
238 for claim in &self.sumcheck_claims {
239 match claim {
240 SumcheckClaims::Projected(claim) => {
241 let meta = Self::projected_bivariate_meta(self.oracles, claim)?;
242 projected_bivariate_metas.push(meta);
243 projected_bivariate_claims.push(claim.clone())
244 }
245 SumcheckClaims::Composite(claim) => {
246 let meta = composite_mlecheck_meta(self.oracles, &claim.eval_point)?;
247 composite_mle_metas.push(meta);
248 composite_mle_claims.push(claim.clone())
249 }
250 }
251 }
252 let dimensions_data = MLEFoldHighDimensionsData::new(projected_bivariate_claims.len());
253 let evalcheck_mle_fold_high_span = tracing::debug_span!(
254 "[task] (Evalcheck) MLE Fold High",
255 phase = "evalcheck",
256 perfetto_category = "task.main",
257 dimensions_data = ?dimensions_data,
258 )
259 .entered();
260
261 let projected_mles = calculate_projected_mles(
262 &projected_bivariate_metas,
263 &mut self.memoized_data,
264 &projected_bivariate_claims,
265 self.witness_index,
266 self.backend,
267 )?;
268 drop(evalcheck_mle_fold_high_span);
269
270 fill_eq_witness_for_composites(
271 &composite_mle_metas,
272 &mut self.memoized_data,
273 &composite_mle_claims,
274 self.witness_index,
275 self.backend,
276 )?;
277
278 let mut projected_index = 0;
279 let mut composite_index = 0;
280
281 for claim in std::mem::take(&mut self.sumcheck_claims) {
282 match claim {
283 SumcheckClaims::Projected(claim) => {
284 let meta = &projected_bivariate_metas[projected_index];
285 let projected = projected_mles[projected_index].clone();
286 self.process_bivariate_sumcheck(&claim, meta, projected)?;
287 projected_index += 1;
288 }
289 SumcheckClaims::Composite(claim) => {
290 let meta = composite_mle_metas[composite_index];
291 self.process_composite_mlecheck(&claim, meta)?;
292 composite_index += 1;
293 }
294 }
295 }
296
297 self.memoized_data.memoize_partial_evals(
298 &projected_bivariate_metas,
299 &projected_bivariate_claims,
300 self.oracles,
301 self.witness_index,
302 );
303
304 Ok(())
305 }
306
307 #[instrument(
308 skip_all,
309 name = "EvalcheckProverState::collect_subclaims_for_precompute",
310 level = "debug"
311 )]
312 fn collect_subclaims_for_memoization(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
313 let multilinear_id = evalcheck_claim.id;
314
315 let eval_point = evalcheck_claim.eval_point;
316
317 let eval = evalcheck_claim.eval;
318
319 if self
320 .visited_claims
321 .get(multilinear_id, &eval_point)
322 .is_some()
323 {
324 return;
325 }
326
327 self.visited_claims
328 .insert(multilinear_id, eval_point.clone(), ());
329
330 let multilinear = self.oracles.oracle(multilinear_id);
331
332 match multilinear.variant {
333 MultilinearPolyVariant::Repeating { id, .. } => {
334 let n_vars = self.oracles.n_vars(id);
335 let inner_eval_point = eval_point.slice(0..n_vars);
336 let subclaim = EvalcheckMultilinearClaim {
337 id,
338 eval_point: inner_eval_point,
339 eval,
340 };
341 self.claims_queue.push(subclaim);
342 }
343
344 MultilinearPolyVariant::Projected(projected) => {
345 let (id, values) = (projected.id(), projected.values());
346 let new_eval_point = {
347 let idx = projected.start_index();
348 let mut new_eval_point = eval_point[0..idx].to_vec();
349 new_eval_point.extend(values.clone());
350 new_eval_point.extend(eval_point[idx..].to_vec());
351 new_eval_point
352 };
353
354 let subclaim = EvalcheckMultilinearClaim {
355 id,
356 eval_point: new_eval_point.into(),
357 eval,
358 };
359 self.claims_queue.push(subclaim);
360 }
361
362 MultilinearPolyVariant::LinearCombination(linear_combination) => {
363 let n_polys = linear_combination.n_polys();
364
365 match linear_combination
366 .polys()
367 .zip(linear_combination.coefficients())
368 .next()
369 {
370 Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
371 let eval = if let Some(eval) =
372 self.evals_memoization.get(suboracle_id, &eval_point)
373 {
374 *eval
375 } else {
376 let eval = (eval - linear_combination.offset())
377 * coeff.invert().expect("not zero");
378 self.evals_memoization
379 .insert(suboracle_id, eval_point.clone(), eval);
380 eval
381 };
382
383 let subclaim = EvalcheckMultilinearClaim {
384 id: suboracle_id,
385 eval_point,
386 eval,
387 };
388 self.claims_queue.push(subclaim);
389 }
390 _ => {
391 for suboracle_id in linear_combination.polys() {
392 self.claims_without_evals
393 .push((self.oracles.oracle(suboracle_id), eval_point.clone()));
394 }
395 }
396 };
397 }
398
399 MultilinearPolyVariant::ZeroPadded(padded) => {
400 let id = padded.id();
401 let inner = self.oracles.oracle(id);
402 let inner_eval_point = chain!(
403 &eval_point[..padded.start_index()],
404 &eval_point[padded.start_index() + padded.n_pad_vars()..],
405 )
406 .copied()
407 .collect::<Vec<_>>();
408 self.claims_without_evals
409 .push((inner, inner_eval_point.into()));
410 }
411 _ => return,
412 };
413 }
414
415 #[instrument(
416 skip_all,
417 name = "EvalcheckProverState::prove_multilinear",
418 level = "debug"
419 )]
420 fn prove_multilinear<Challenger_: Challenger>(
421 &mut self,
422 evalcheck_claim: EvalcheckMultilinearClaim<F>,
423 transcript: &mut ProverTranscript<Challenger_>,
424 ) -> Result<(), Error> {
425 let EvalcheckMultilinearClaim { id, eval_point, .. } = &evalcheck_claim;
426
427 let claim_id = self.claim_to_index.get(*id, eval_point);
428
429 if let Some(claim_id) = claim_id {
430 serialize_evalcheck_proof(
431 &mut transcript.message(),
432 &EvalcheckHint::DuplicateClaim(*claim_id as u32),
433 );
434 return Ok(());
435 }
436 serialize_evalcheck_proof(&mut transcript.message(), &EvalcheckHint::NewClaim);
437
438 self.prove_multilinear_skip_duplicate_check(evalcheck_claim, transcript)
439 }
440
441 fn prove_multilinear_skip_duplicate_check<Challenger_: Challenger>(
442 &mut self,
443 evalcheck_claim: EvalcheckMultilinearClaim<F>,
444 transcript: &mut ProverTranscript<Challenger_>,
445 ) -> Result<(), Error> {
446 let EvalcheckMultilinearClaim {
447 id,
448 eval_point,
449 eval,
450 } = evalcheck_claim;
451
452 self.claim_to_index
453 .insert(id, eval_point.clone(), self.round_claim_index);
454
455 self.round_claim_index += 1;
456
457 let multilinear = self.oracles.oracle(id);
458
459 match multilinear.variant {
460 MultilinearPolyVariant::Transparent { .. } => {}
461 MultilinearPolyVariant::Committed => {
462 self.committed_eval_claims.push(EvalcheckMultilinearClaim {
463 id: multilinear.id,
464 eval_point,
465 eval,
466 });
467 }
468 MultilinearPolyVariant::Repeating {
469 id: inner_id,
470 log_count,
471 } => {
472 let n_vars = eval_point.len() - log_count;
473 self.prove_multilinear(
474 EvalcheckMultilinearClaim {
475 id: inner_id,
476 eval_point: eval_point.slice(0..n_vars),
477 eval,
478 },
479 transcript,
480 )?;
481 }
482 MultilinearPolyVariant::Projected(projected) => {
483 let new_eval_point = {
484 let (lo, hi) = eval_point.split_at(projected.start_index());
485 chain!(lo, projected.values(), hi)
486 .copied()
487 .collect::<Vec<_>>()
488 };
489
490 self.prove_multilinear(
491 EvalcheckMultilinearClaim {
492 id: projected.id(),
493 eval_point: new_eval_point.into(),
494 eval,
495 },
496 transcript,
497 )?;
498 }
499 MultilinearPolyVariant::Shifted { .. } | MultilinearPolyVariant::Packed { .. } => {
500 let claim = EvalcheckMultilinearClaim {
501 id,
502 eval_point,
503 eval,
504 };
505
506 self.sumcheck_claims.push(SumcheckClaims::Projected(claim));
507 }
508 MultilinearPolyVariant::Composite { .. } => {
509 let claim = EvalcheckMultilinearClaim {
510 id,
511 eval_point,
512 eval,
513 };
514
515 self.sumcheck_claims.push(SumcheckClaims::Composite(claim));
516 }
517 MultilinearPolyVariant::LinearCombination(linear_combination) => {
518 for suboracle_id in linear_combination.polys() {
519 if let Some(claim_index) = self.claim_to_index.get(suboracle_id, &eval_point) {
520 serialize_evalcheck_proof(
521 &mut transcript.message(),
522 &EvalcheckHint::DuplicateClaim(*claim_index as u32),
523 );
524 } else {
525 serialize_evalcheck_proof(
526 &mut transcript.message(),
527 &EvalcheckHint::NewClaim,
528 );
529
530 let eval = *self
531 .evals_memoization
532 .get(suboracle_id, &eval_point)
533 .expect("precomputed above");
534
535 transcript.message().write_scalar(eval);
536
537 self.prove_multilinear_skip_duplicate_check(
538 EvalcheckMultilinearClaim {
539 id: suboracle_id,
540 eval_point: eval_point.clone(),
541 eval,
542 },
543 transcript,
544 )?;
545 }
546 }
547 }
548 MultilinearPolyVariant::ZeroPadded(padded) => {
549 let inner_eval_point = chain!(
550 &eval_point[..padded.start_index()],
551 &eval_point[padded.start_index() + padded.n_pad_vars()..],
552 )
553 .copied()
554 .collect::<Vec<_>>();
555
556 let inner_eval = *self
557 .evals_memoization
558 .get(padded.id(), &inner_eval_point)
559 .expect("precomputed above");
560
561 self.prove_multilinear(
562 EvalcheckMultilinearClaim {
563 id: padded.id(),
564 eval_point: inner_eval_point.into(),
565 eval: inner_eval,
566 },
567 transcript,
568 )?;
569 }
570 }
571 Ok(())
572 }
573
574 fn projected_bivariate_meta(
575 oracles: &mut MultilinearOracleSet<F>,
576 evalcheck_claim: &EvalcheckMultilinearClaim<F>,
577 ) -> Result<ProjectedBivariateMeta, Error> {
578 let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
579
580 match &oracles.oracle(*id).variant {
581 MultilinearPolyVariant::Shifted(shifted) => {
582 shifted_sumcheck_meta(oracles, shifted, eval_point)
583 }
584 MultilinearPolyVariant::Packed(packed) => {
585 packed_sumcheck_meta(oracles, packed, eval_point)
586 }
587 _ => unreachable!(),
588 }
589 }
590
591 fn process_bivariate_sumcheck(
592 &mut self,
593 evalcheck_claim: &EvalcheckMultilinearClaim<F>,
594 meta: &ProjectedBivariateMeta,
595 projected: Option<MultilinearExtension<P>>,
596 ) -> Result<(), Error> {
597 let EvalcheckMultilinearClaim {
598 id,
599 eval_point,
600 eval,
601 } = evalcheck_claim;
602
603 match self.oracles.oracle(*id).variant {
604 MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
605 &shifted,
606 meta,
607 eval_point,
608 *eval,
609 self.witness_index,
610 &mut self.new_sumchecks_constraints,
611 projected,
612 ),
613
614 MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
615 self.oracles,
616 &packed,
617 meta,
618 eval_point,
619 *eval,
620 self.witness_index,
621 &mut self.new_sumchecks_constraints,
622 projected,
623 ),
624
625 _ => unreachable!(),
626 }
627 }
628
629 fn process_composite_mlecheck(
630 &mut self,
631 evalcheck_claim: &EvalcheckMultilinearClaim<F>,
632 meta: CompositeMLECheckMeta,
633 ) -> Result<(), Error> {
634 let EvalcheckMultilinearClaim {
635 id,
636 eval_point: _,
637 eval,
638 } = evalcheck_claim;
639
640 match self.oracles.oracle(*id).variant {
641 MultilinearPolyVariant::Composite(composite) => {
642 add_composite_sumcheck_to_constraints(
644 meta,
645 &mut self.new_sumchecks_constraints,
646 &composite,
647 *eval,
648 );
649 Ok(())
650 }
651 _ => unreachable!(),
652 }
653 }
654
655 #[instrument(
657 skip_all,
658 name = "EvalcheckProverState::make_new_eval_claim",
659 level = "debug"
660 )]
661 fn make_new_eval_claim(
662 oracle_id: OracleId,
663 eval_point: EvalPoint<F>,
664 witness_index: &MultilinearExtensionIndex<P>,
665 memoized_queries: &MemoizedData<P, Backend>,
666 ) -> Result<EvalcheckMultilinearClaim<F>, Error> {
667 let eval_query = memoized_queries
668 .full_query_readonly(&eval_point)
669 .ok_or(Error::MissingQuery)?;
670
671 let witness_poly = witness_index
672 .get_multilin_poly(oracle_id)
673 .map_err(Error::Witness)?;
674
675 let eval = witness_poly
676 .evaluate(eval_query.to_ref())
677 .map_err(Error::from)?;
678
679 Ok(EvalcheckMultilinearClaim {
680 id: oracle_id,
681 eval_point,
682 eval,
683 })
684 }
685}