1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 underlier::UnderlierType,
6 PackedFieldIndexable, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::MultilinearExtension;
10use binius_maybe_rayon::prelude::*;
11use getset::{Getters, MutGetters};
12use itertools::izip;
13use tracing::instrument;
14
15use super::{
16 error::Error,
17 evalcheck::{EvalcheckMultilinearClaim, EvalcheckProof},
18 subclaims::{
19 add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_sumcheck_meta,
20 fill_eq_witness_for_composites, MemoizedQueries, ProjectedBivariateMeta,
21 },
22 EvalPoint, EvalPointOracleIdMap,
23};
24use crate::{
25 oracle::{
26 ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
27 MultilinearPolyOracle, MultilinearPolyVariant, OracleId, ProjectionVariant,
28 },
29 protocols::evalcheck::subclaims::{
30 packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
31 shifted_sumcheck_meta,
32 },
33 witness::MultilinearExtensionIndex,
34};
35
36#[derive(Getters, MutGetters)]
42pub struct EvalcheckProver<'a, 'b, U, F, Backend>
43where
44 U: UnderlierType + PackScalar<F>,
45 F: TowerField,
46 Backend: ComputationBackend,
47{
48 pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
49 pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
50
51 #[getset(get = "pub", get_mut = "pub")]
52 committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
53
54 finalized_proofs: EvalPointOracleIdMap<(F, EvalcheckProof<F>), F>,
55
56 claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
57 incomplete_proof_claims: EvalPointOracleIdMap<EvalcheckMultilinearClaim<F>, F>,
58 #[allow(clippy::type_complexity)]
59 claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
60 claims_without_evals_dedup: EvalPointOracleIdMap<(), F>,
61 projected_bivariate_claims: Vec<EvalcheckMultilinearClaim<F>>,
62
63 new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
64 memoized_queries: MemoizedQueries<PackedType<U, F>, Backend>,
65 backend: &'a Backend,
66}
67
68impl<'a, 'b, U, F, Backend> EvalcheckProver<'a, 'b, U, F, Backend>
69where
70 U: UnderlierType + PackScalar<F>,
71 PackedType<U, F>: PackedFieldIndexable,
72 F: TowerField,
73 Backend: ComputationBackend,
74{
75 pub fn new(
79 oracles: &'a mut MultilinearOracleSet<F>,
80 witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
81 backend: &'a Backend,
82 ) -> Self {
83 Self {
84 oracles,
85 witness_index,
86 committed_eval_claims: Vec::new(),
87 new_sumchecks_constraints: Vec::new(),
88 finalized_proofs: EvalPointOracleIdMap::new(),
89 claims_queue: Vec::new(),
90 claims_without_evals: Vec::new(),
91 claims_without_evals_dedup: EvalPointOracleIdMap::new(),
92 projected_bivariate_claims: Vec::new(),
93 memoized_queries: MemoizedQueries::new(),
94 backend,
95 incomplete_proof_claims: EvalPointOracleIdMap::new(),
96 }
97 }
98
99 pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
101 self.new_sumchecks_constraints
102 .iter_mut()
103 .map(|builder| std::mem::take(builder).build_one(self.oracles))
104 .filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
105 .rev()
106 .collect()
107 }
108
109 #[instrument(skip_all, name = "EvalcheckProver::prove", level = "debug")]
124 pub fn prove(
125 &mut self,
126 evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
127 ) -> Result<Vec<EvalcheckProof<F>>, Error> {
128 for claim in &evalcheck_claims {
129 self.claims_without_evals_dedup
130 .insert(claim.id, claim.eval_point.clone(), ());
131 }
132
133 self.claims_queue.extend(evalcheck_claims.clone());
135
136 while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
140 while !self.claims_queue.is_empty() {
142 std::mem::take(&mut self.claims_queue)
143 .into_iter()
144 .for_each(|claim| self.prove_multilinear(claim));
145 }
146
147 let mut deduplicated_claims_without_evals = Vec::new();
148
149 for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
150 if self.finalized_proofs.get(poly.id(), &eval_point).is_some() {
151 continue;
152 }
153 if self
154 .claims_without_evals_dedup
155 .get(poly.id(), &eval_point)
156 .is_some()
157 {
158 continue;
159 }
160
161 self.claims_without_evals_dedup
162 .insert(poly.id(), eval_point.clone(), ());
163
164 deduplicated_claims_without_evals.push((poly, eval_point.clone()))
165 }
166
167 let deduplicated_eval_points = deduplicated_claims_without_evals
168 .iter()
169 .map(|(_, eval_point)| eval_point.as_ref())
170 .collect::<Vec<_>>();
171
172 self.memoized_queries
173 .memoize_query_par(&deduplicated_eval_points, self.backend)?;
174
175 let subclaims = deduplicated_claims_without_evals
177 .into_par_iter()
178 .map(|(poly, eval_point)| {
179 Self::make_new_eval_claim(
180 poly.id(),
181 eval_point,
182 self.witness_index,
183 &self.memoized_queries,
184 )
185 })
186 .collect::<Result<Vec<_>, Error>>()?;
187
188 subclaims
189 .into_iter()
190 .for_each(|claim| self.prove_multilinear(claim));
191 }
192
193 let mut incomplete_proof_claims =
194 std::mem::take(&mut self.incomplete_proof_claims).flatten();
195
196 while !incomplete_proof_claims.is_empty() {
197 for claim in std::mem::take(&mut incomplete_proof_claims) {
198 if self.complete_proof(&claim) {
199 continue;
200 }
201 incomplete_proof_claims.push(claim);
202 }
203 }
204
205 evalcheck_claims
211 .iter()
212 .cloned()
213 .for_each(|claim| self.collect_projected_committed(claim));
214
215 let projected_bivariate_metas = self
218 .projected_bivariate_claims
219 .iter()
220 .map(|claim| Self::projected_bivariate_meta(self.oracles, claim))
221 .collect::<Result<Vec<_>, Error>>()?;
222
223 let projected_mles = calculate_projected_mles(
224 &projected_bivariate_metas,
225 &mut self.memoized_queries,
226 &self.projected_bivariate_claims,
227 self.witness_index,
228 self.backend,
229 )?;
230
231 fill_eq_witness_for_composites(
232 &projected_bivariate_metas,
233 &mut self.memoized_queries,
234 &self.projected_bivariate_claims,
235 self.witness_index,
236 self.backend,
237 )?;
238
239 for (claim, meta, projected) in izip!(
240 std::mem::take(&mut self.projected_bivariate_claims),
241 projected_bivariate_metas,
242 projected_mles
243 ) {
244 self.process_sumcheck(claim, meta, projected)?;
245 }
246
247 Ok(evalcheck_claims
250 .iter()
251 .map(|claim| {
252 self.finalized_proofs
253 .get(claim.id, &claim.eval_point)
254 .map(|(_, proof)| proof.clone())
255 .expect("finalized_proofs contains all the proofs")
256 })
257 .collect::<Vec<_>>())
258 }
259
260 #[instrument(
261 skip_all,
262 name = "EvalcheckProverState::prove_multilinear",
263 level = "debug"
264 )]
265 fn prove_multilinear(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
266 let multilinear_id = evalcheck_claim.id;
267
268 let eval_point = evalcheck_claim.eval_point.clone();
269
270 let eval = evalcheck_claim.eval;
271
272 if self
273 .finalized_proofs
274 .get(multilinear_id, &eval_point)
275 .is_some()
276 {
277 return;
278 }
279
280 if self
281 .incomplete_proof_claims
282 .get(multilinear_id, &eval_point)
283 .is_some()
284 {
285 return;
286 }
287
288 let multilinear = self.oracles.oracle(multilinear_id);
289
290 match multilinear.variant {
291 MultilinearPolyVariant::Transparent { .. } => {
292 self.finalized_proofs.insert(
293 multilinear_id,
294 eval_point,
295 (eval, EvalcheckProof::Transparent),
296 );
297 }
298
299 MultilinearPolyVariant::Committed => {
300 self.finalized_proofs.insert(
301 multilinear_id,
302 eval_point,
303 (eval, EvalcheckProof::Committed),
304 );
305 }
306
307 MultilinearPolyVariant::Repeating { id, .. } => {
308 let n_vars = self.oracles.n_vars(id);
309 let inner_eval_point = eval_point.slice(0..n_vars);
310 let subclaim = EvalcheckMultilinearClaim {
311 id,
312 eval_point: inner_eval_point,
313 eval,
314 };
315 self.incomplete_proof_claims
316 .insert(multilinear_id, eval_point, evalcheck_claim);
317 self.claims_queue.push(subclaim);
318 }
319
320 MultilinearPolyVariant::Shifted { .. } => {
321 self.finalized_proofs.insert(
322 multilinear_id,
323 eval_point,
324 (eval, EvalcheckProof::Shifted),
325 );
326 }
327
328 MultilinearPolyVariant::Packed { .. } => {
329 self.finalized_proofs.insert(
330 multilinear_id,
331 eval_point,
332 (eval, EvalcheckProof::Packed),
333 );
334 }
335
336 MultilinearPolyVariant::Composite(_) => {
337 self.finalized_proofs.insert(
338 multilinear_id,
339 eval_point,
340 (eval, EvalcheckProof::CompositeMLE),
341 );
342 }
343
344 MultilinearPolyVariant::Projected(projected) => {
345 let (id, values) = (projected.id(), projected.values());
346 let new_eval_point = match projected.projection_variant() {
347 ProjectionVariant::LastVars => {
348 let mut new_eval_point = eval_point.to_vec();
349 new_eval_point.extend(values);
350 new_eval_point
351 }
352 ProjectionVariant::FirstVars => {
353 values.iter().copied().chain(eval_point.to_vec()).collect()
354 }
355 };
356
357 let subclaim = EvalcheckMultilinearClaim {
358 id,
359 eval_point: new_eval_point.into(),
360 eval,
361 };
362 self.incomplete_proof_claims
363 .insert(multilinear_id, eval_point, evalcheck_claim);
364 self.claims_queue.push(subclaim);
365 }
366
367 MultilinearPolyVariant::LinearCombination(linear_combination) => {
368 let n_polys = linear_combination.n_polys();
369
370 match linear_combination
371 .polys()
372 .zip(linear_combination.coefficients())
373 .next()
374 {
375 Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
376 let eval = (eval - linear_combination.offset())
377 * coeff.invert().expect("not zero");
378 let subclaim = EvalcheckMultilinearClaim {
379 id: suboracle_id,
380 eval_point: eval_point.clone(),
381 eval,
382 };
383 self.claims_queue.push(subclaim);
384 }
385 _ => {
386 for suboracle_id in linear_combination.polys() {
387 self.claims_without_evals
388 .push((self.oracles.oracle(suboracle_id), eval_point.clone()));
389 }
390 }
391 };
392
393 self.incomplete_proof_claims
394 .insert(multilinear_id, eval_point, evalcheck_claim);
395 }
396
397 MultilinearPolyVariant::ZeroPadded(id) => {
398 let inner = self.oracles.oracle(id);
399 let inner_n_vars = inner.n_vars();
400 let inner_eval_point = eval_point.slice(0..inner_n_vars);
401 self.claims_without_evals.push((inner, inner_eval_point));
402 self.incomplete_proof_claims
403 .insert(multilinear_id, eval_point, evalcheck_claim);
404 }
405 };
406 }
407
408 fn complete_proof(&mut self, evalcheck_claim: &EvalcheckMultilinearClaim<F>) -> bool {
409 let id = &evalcheck_claim.id;
410 let eval_point = evalcheck_claim.eval_point.clone();
411 let eval = evalcheck_claim.eval;
412
413 let res = match self.oracles.oracle(*id).variant {
414 MultilinearPolyVariant::Repeating { id, .. } => {
415 let n_vars = self.oracles.n_vars(id);
416 let inner_eval_point = &evalcheck_claim.eval_point[..n_vars];
417 self.finalized_proofs
418 .get(id, inner_eval_point)
419 .map(|(_, subproof)| subproof.clone())
420 .map(move |subproof| {
421 let proof = EvalcheckProof::Repeating(Box::new(subproof));
422 self.finalized_proofs
423 .insert(evalcheck_claim.id, eval_point, (eval, proof));
424 })
425 }
426 MultilinearPolyVariant::Projected(projected) => {
427 let (id, values) = (projected.id(), projected.values());
428 let new_eval_point = match projected.projection_variant() {
429 ProjectionVariant::LastVars => {
430 let mut new_eval_point = eval_point.to_vec();
431 new_eval_point.extend(values);
432 new_eval_point
433 }
434 ProjectionVariant::FirstVars => values
435 .iter()
436 .copied()
437 .chain((*eval_point).to_vec())
438 .collect(),
439 };
440 self.finalized_proofs
441 .get(id, &new_eval_point)
442 .map(|(_, subproof)| subproof.clone())
443 .map(|subproof| {
444 self.finalized_proofs.insert(
445 evalcheck_claim.id,
446 eval_point,
447 (eval, subproof),
448 );
449 })
450 }
451
452 MultilinearPolyVariant::LinearCombination(linear_combination) => linear_combination
453 .polys()
454 .map(|suboracle_id| {
455 self.finalized_proofs
456 .get(suboracle_id, &evalcheck_claim.eval_point)
457 .map(|(eval, subproof)| (*eval, subproof.clone()))
458 })
459 .collect::<Option<Vec<_>>>()
460 .map(|subproofs| {
461 self.finalized_proofs.insert(
462 evalcheck_claim.id,
463 eval_point,
464 (eval, EvalcheckProof::LinearCombination { subproofs }),
465 );
466 }),
467
468 MultilinearPolyVariant::ZeroPadded(inner_id) => {
469 let inner_n_vars = self.oracles.n_vars(inner_id);
470 let inner_eval_point = &evalcheck_claim.eval_point[..inner_n_vars];
471 self.finalized_proofs
472 .get(inner_id, inner_eval_point)
473 .map(|(eval, subproof)| (*eval, subproof.clone()))
474 .map(|(internal_eval, subproof)| {
475 self.finalized_proofs.insert(
476 evalcheck_claim.id,
477 eval_point,
478 (eval, EvalcheckProof::ZeroPadded(internal_eval, Box::new(subproof))),
479 );
480 })
481 }
482
483 _ => unreachable!(),
484 };
485 res.is_some()
486 }
487
488 fn collect_projected_committed(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
489 let EvalcheckMultilinearClaim {
490 id,
491 eval_point,
492 eval,
493 } = evalcheck_claim.clone();
494
495 let multilinear = self.oracles.oracle(id);
496 match multilinear.variant {
497 MultilinearPolyVariant::Committed => {
498 let subclaim = EvalcheckMultilinearClaim {
499 id: multilinear.id,
500 eval_point,
501 eval,
502 };
503
504 self.committed_eval_claims.push(subclaim);
505 }
506 MultilinearPolyVariant::Repeating { id, .. } => {
507 let n_vars = self.oracles.n_vars(id);
508 let inner_eval_point = eval_point.slice(0..n_vars);
509 let subclaim = EvalcheckMultilinearClaim {
510 id,
511 eval_point: inner_eval_point,
512 eval,
513 };
514
515 self.collect_projected_committed(subclaim);
516 }
517 MultilinearPolyVariant::Projected(projected) => {
518 let (id, values) = (projected.id(), projected.values());
519 let new_eval_point = match projected.projection_variant() {
520 ProjectionVariant::LastVars => {
521 let mut new_eval_point = eval_point.to_vec();
522 new_eval_point.extend(values);
523 new_eval_point
524 }
525 ProjectionVariant::FirstVars => {
526 values.iter().copied().chain(eval_point.to_vec()).collect()
527 }
528 };
529
530 let subclaim = EvalcheckMultilinearClaim {
531 id,
532 eval_point: new_eval_point.into(),
533 eval,
534 };
535 self.collect_projected_committed(subclaim);
536 }
537 MultilinearPolyVariant::Shifted { .. }
538 | MultilinearPolyVariant::Packed { .. }
539 | MultilinearPolyVariant::Composite { .. } => {
540 self.projected_bivariate_claims.push(evalcheck_claim)
541 }
542 MultilinearPolyVariant::LinearCombination(linear_combination) => {
543 for id in linear_combination.polys() {
544 let (eval, _) = self
545 .finalized_proofs
546 .get(id, &eval_point)
547 .expect("finalized_proofs contains all the proofs");
548 let subclaim = EvalcheckMultilinearClaim {
549 id,
550 eval_point: eval_point.clone(),
551 eval: *eval,
552 };
553 self.collect_projected_committed(subclaim);
554 }
555 }
556 MultilinearPolyVariant::ZeroPadded(id) => {
557 let inner_n_vars = self.oracles.n_vars(id);
558 let inner_eval_point = eval_point.slice(0..inner_n_vars);
559
560 let (eval, _) = self
561 .finalized_proofs
562 .get(id, &inner_eval_point)
563 .expect("finalized_proofs contains all the proofs");
564
565 let subclaim = EvalcheckMultilinearClaim {
566 id,
567 eval_point,
568 eval: *eval,
569 };
570 self.collect_projected_committed(subclaim);
571 }
572 _ => {}
573 }
574 }
575
576 fn projected_bivariate_meta(
577 oracles: &mut MultilinearOracleSet<F>,
578 evalcheck_claim: &EvalcheckMultilinearClaim<F>,
579 ) -> Result<ProjectedBivariateMeta, Error> {
580 let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
581
582 match &oracles.oracle(*id).variant {
583 MultilinearPolyVariant::Shifted(shifted) => {
584 shifted_sumcheck_meta(oracles, shifted, eval_point)
585 }
586 MultilinearPolyVariant::Packed(packed) => {
587 packed_sumcheck_meta(oracles, packed, eval_point)
588 }
589 MultilinearPolyVariant::Composite(_) => composite_sumcheck_meta(oracles, eval_point),
590 _ => unreachable!(),
591 }
592 }
593
594 fn process_sumcheck(
595 &mut self,
596 evalcheck_claim: EvalcheckMultilinearClaim<F>,
597 meta: ProjectedBivariateMeta,
598 projected: Option<MultilinearExtension<PackedType<U, F>>>,
599 ) -> Result<(), Error> {
600 let EvalcheckMultilinearClaim {
601 id,
602 eval_point,
603 eval,
604 } = evalcheck_claim;
605
606 match self.oracles.oracle(id).variant {
607 MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
608 &shifted,
609 meta,
610 &eval_point,
611 eval,
612 self.witness_index,
613 &mut self.new_sumchecks_constraints,
614 projected.expect("projected is required by shifted oracle"),
615 ),
616
617 MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
618 self.oracles,
619 &packed,
620 meta,
621 &eval_point,
622 eval,
623 self.witness_index,
624 &mut self.new_sumchecks_constraints,
625 projected.expect("projected is required by packed oracle"),
626 ),
627
628 MultilinearPolyVariant::Composite(composite) => {
629 add_composite_sumcheck_to_constraints(
631 meta,
632 &mut self.new_sumchecks_constraints,
633 &composite,
634 eval,
635 );
636 Ok(())
637 }
638 _ => unreachable!(),
639 }
640 }
641
642 fn make_new_eval_claim(
643 oracle_id: OracleId,
644 eval_point: EvalPoint<F>,
645 witness_index: &MultilinearExtensionIndex<U, F>,
646 memoized_queries: &MemoizedQueries<PackedType<U, F>, Backend>,
647 ) -> Result<EvalcheckMultilinearClaim<F>, Error> {
648 let eval_query = memoized_queries
649 .full_query_readonly(&eval_point)
650 .ok_or(Error::MissingQuery)?;
651
652 let witness_poly = witness_index
653 .get_multilin_poly(oracle_id)
654 .map_err(Error::Witness)?;
655
656 let eval = witness_poly
657 .evaluate(eval_query.to_ref())
658 .map_err(Error::from)?;
659
660 Ok(EvalcheckMultilinearClaim {
661 id: oracle_id,
662 eval_point,
663 eval,
664 })
665 }
666}