1use binius_field::{PackedFieldIndexable, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_maybe_rayon::prelude::*;
7use getset::{Getters, MutGetters};
8use itertools::izip;
9use tracing::instrument;
10
11use super::{
12 error::Error,
13 evalcheck::{EvalcheckMultilinearClaim, EvalcheckProof},
14 subclaims::{
15 add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_sumcheck_meta,
16 fill_eq_witness_for_composites, MemoizedData, ProjectedBivariateMeta,
17 },
18 EvalPoint, EvalPointOracleIdMap,
19};
20use crate::{
21 oracle::{
22 ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
23 MultilinearPolyOracle, MultilinearPolyVariant, OracleId,
24 },
25 protocols::evalcheck::subclaims::{
26 packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
27 shifted_sumcheck_meta,
28 },
29 witness::MultilinearExtensionIndex,
30};
31
32#[derive(Getters, MutGetters)]
38pub struct EvalcheckProver<'a, 'b, F, P, Backend>
39where
40 P: PackedFieldIndexable<Scalar = F>,
41 F: TowerField,
42 Backend: ComputationBackend,
43{
44 pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
45 pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
46
47 #[getset(get = "pub", get_mut = "pub")]
48 committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
49
50 finalized_proofs: EvalPointOracleIdMap<(F, EvalcheckProof<F>), F>,
51
52 claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
53 incomplete_proof_claims: EvalPointOracleIdMap<EvalcheckMultilinearClaim<F>, F>,
54 claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
55 claims_without_evals_dedup: EvalPointOracleIdMap<(), F>,
56 projected_bivariate_claims: Vec<EvalcheckMultilinearClaim<F>>,
57
58 new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
59 pub memoized_data: MemoizedData<'b, P, Backend>,
60 backend: &'a Backend,
61}
62
63impl<'a, 'b, F, P, Backend> EvalcheckProver<'a, 'b, F, P, Backend>
64where
65 P: PackedFieldIndexable<Scalar = F>,
66 F: TowerField,
67 Backend: ComputationBackend,
68{
69 pub fn new(
73 oracles: &'a mut MultilinearOracleSet<F>,
74 witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
75 backend: &'a Backend,
76 ) -> Self {
77 Self {
78 oracles,
79 witness_index,
80 committed_eval_claims: Vec::new(),
81 new_sumchecks_constraints: Vec::new(),
82 finalized_proofs: EvalPointOracleIdMap::new(),
83 claims_queue: Vec::new(),
84 claims_without_evals: Vec::new(),
85 claims_without_evals_dedup: EvalPointOracleIdMap::new(),
86 projected_bivariate_claims: Vec::new(),
87 memoized_data: MemoizedData::new(),
88 backend,
89 incomplete_proof_claims: EvalPointOracleIdMap::new(),
90 }
91 }
92
93 pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
95 self.new_sumchecks_constraints
96 .iter_mut()
97 .map(|builder| std::mem::take(builder).build_one(self.oracles))
98 .filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
99 .rev()
100 .collect()
101 }
102
103 #[instrument(skip_all, name = "EvalcheckProver::prove", level = "debug")]
118 pub fn prove(
119 &mut self,
120 evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
121 ) -> Result<Vec<EvalcheckProof<F>>, Error> {
122 for claim in &evalcheck_claims {
123 self.claims_without_evals_dedup
124 .insert(claim.id, claim.eval_point.clone(), ());
125 }
126
127 self.claims_queue.extend(evalcheck_claims.clone());
129
130 while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
134 while !self.claims_queue.is_empty() {
136 std::mem::take(&mut self.claims_queue)
137 .into_iter()
138 .for_each(|claim| self.prove_multilinear(claim));
139 }
140
141 let mut deduplicated_claims_without_evals = Vec::new();
142
143 for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
144 if self.finalized_proofs.get(poly.id(), &eval_point).is_some() {
145 continue;
146 }
147 if self
148 .claims_without_evals_dedup
149 .get(poly.id(), &eval_point)
150 .is_some()
151 {
152 continue;
153 }
154
155 self.claims_without_evals_dedup
156 .insert(poly.id(), eval_point.clone(), ());
157
158 deduplicated_claims_without_evals.push((poly, eval_point.clone()))
159 }
160
161 let deduplicated_eval_points = deduplicated_claims_without_evals
162 .iter()
163 .map(|(_, eval_point)| eval_point.as_ref())
164 .collect::<Vec<_>>();
165
166 self.memoized_data
167 .memoize_query_par(&deduplicated_eval_points, self.backend)?;
168
169 let subclaims = deduplicated_claims_without_evals
171 .into_par_iter()
172 .map(|(poly, eval_point)| {
173 Self::make_new_eval_claim(
174 poly.id(),
175 eval_point,
176 self.witness_index,
177 &self.memoized_data,
178 )
179 })
180 .collect::<Result<Vec<_>, Error>>()?;
181
182 subclaims
183 .into_iter()
184 .for_each(|claim| self.prove_multilinear(claim));
185 }
186
187 let mut incomplete_proof_claims =
188 std::mem::take(&mut self.incomplete_proof_claims).flatten();
189
190 while !incomplete_proof_claims.is_empty() {
191 for claim in std::mem::take(&mut incomplete_proof_claims) {
192 if self.complete_proof(&claim) {
193 continue;
194 }
195 incomplete_proof_claims.push(claim);
196 }
197 }
198
199 evalcheck_claims
205 .iter()
206 .cloned()
207 .for_each(|claim| self.collect_projected_committed(claim));
208
209 let projected_bivariate_metas = self
212 .projected_bivariate_claims
213 .iter()
214 .map(|claim| Self::projected_bivariate_meta(self.oracles, claim))
215 .collect::<Result<Vec<_>, Error>>()?;
216
217 let projected_mles = calculate_projected_mles(
218 &projected_bivariate_metas,
219 &mut self.memoized_data,
220 &self.projected_bivariate_claims,
221 self.witness_index,
222 self.backend,
223 )?;
224
225 fill_eq_witness_for_composites(
226 &projected_bivariate_metas,
227 &mut self.memoized_data,
228 &self.projected_bivariate_claims,
229 self.witness_index,
230 self.backend,
231 )?;
232
233 for (claim, meta, projected) in izip!(
234 std::mem::take(&mut self.projected_bivariate_claims),
235 &projected_bivariate_metas,
236 projected_mles
237 ) {
238 self.process_sumcheck(claim, meta, projected)?;
239 }
240
241 self.memoized_data.memoize_partial_evals(
242 &projected_bivariate_metas,
243 &self.projected_bivariate_claims,
244 self.oracles,
245 self.witness_index,
246 );
247
248 Ok(evalcheck_claims
251 .iter()
252 .map(|claim| {
253 self.finalized_proofs
254 .get(claim.id, &claim.eval_point)
255 .map(|(_, proof)| proof.clone())
256 .expect("finalized_proofs contains all the proofs")
257 })
258 .collect::<Vec<_>>())
259 }
260
261 #[instrument(
262 skip_all,
263 name = "EvalcheckProverState::prove_multilinear",
264 level = "debug"
265 )]
266 fn prove_multilinear(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
267 let multilinear_id = evalcheck_claim.id;
268
269 let eval_point = evalcheck_claim.eval_point.clone();
270
271 let eval = evalcheck_claim.eval;
272
273 if self
274 .finalized_proofs
275 .get(multilinear_id, &eval_point)
276 .is_some()
277 {
278 return;
279 }
280
281 if self
282 .incomplete_proof_claims
283 .get(multilinear_id, &eval_point)
284 .is_some()
285 {
286 return;
287 }
288
289 let multilinear = self.oracles.oracle(multilinear_id);
290
291 match multilinear.variant {
292 MultilinearPolyVariant::Transparent { .. } => {
293 self.finalized_proofs.insert(
294 multilinear_id,
295 eval_point,
296 (eval, EvalcheckProof::Transparent),
297 );
298 }
299
300 MultilinearPolyVariant::Committed => {
301 self.finalized_proofs.insert(
302 multilinear_id,
303 eval_point,
304 (eval, EvalcheckProof::Committed),
305 );
306 }
307
308 MultilinearPolyVariant::Repeating { id, .. } => {
309 let n_vars = self.oracles.n_vars(id);
310 let inner_eval_point = eval_point.slice(0..n_vars);
311 let subclaim = EvalcheckMultilinearClaim {
312 id,
313 eval_point: inner_eval_point,
314 eval,
315 };
316 self.incomplete_proof_claims
317 .insert(multilinear_id, eval_point, evalcheck_claim);
318 self.claims_queue.push(subclaim);
319 }
320
321 MultilinearPolyVariant::Shifted { .. } => {
322 self.finalized_proofs.insert(
323 multilinear_id,
324 eval_point,
325 (eval, EvalcheckProof::Shifted),
326 );
327 }
328
329 MultilinearPolyVariant::Packed { .. } => {
330 self.finalized_proofs.insert(
331 multilinear_id,
332 eval_point,
333 (eval, EvalcheckProof::Packed),
334 );
335 }
336
337 MultilinearPolyVariant::Composite(_) => {
338 self.finalized_proofs.insert(
339 multilinear_id,
340 eval_point,
341 (eval, EvalcheckProof::CompositeMLE),
342 );
343 }
344
345 MultilinearPolyVariant::Projected(projected) => {
346 let (id, values) = (projected.id(), projected.values());
347 let new_eval_point = {
348 let idx = projected.start_index();
349 let mut new_eval_point = eval_point[0..idx].to_vec();
350 new_eval_point.extend(values.clone());
351 new_eval_point.extend(eval_point[idx..].to_vec());
352 new_eval_point
353 };
354
355 let subclaim = EvalcheckMultilinearClaim {
356 id,
357 eval_point: new_eval_point.into(),
358 eval,
359 };
360 self.incomplete_proof_claims
361 .insert(multilinear_id, eval_point, evalcheck_claim);
362 self.claims_queue.push(subclaim);
363 }
364
365 MultilinearPolyVariant::LinearCombination(linear_combination) => {
366 let n_polys = linear_combination.n_polys();
367
368 match linear_combination
369 .polys()
370 .zip(linear_combination.coefficients())
371 .next()
372 {
373 Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
374 let eval = (eval - linear_combination.offset())
375 * coeff.invert().expect("not zero");
376 let subclaim = EvalcheckMultilinearClaim {
377 id: suboracle_id,
378 eval_point: eval_point.clone(),
379 eval,
380 };
381 self.claims_queue.push(subclaim);
382 }
383 _ => {
384 for suboracle_id in linear_combination.polys() {
385 self.claims_without_evals
386 .push((self.oracles.oracle(suboracle_id), eval_point.clone()));
387 }
388 }
389 };
390
391 self.incomplete_proof_claims
392 .insert(multilinear_id, eval_point, evalcheck_claim);
393 }
394
395 MultilinearPolyVariant::ZeroPadded(id) => {
396 let inner = self.oracles.oracle(id);
397 let inner_n_vars = inner.n_vars();
398 let inner_eval_point = eval_point.slice(0..inner_n_vars);
399 self.claims_without_evals.push((inner, inner_eval_point));
400 self.incomplete_proof_claims
401 .insert(multilinear_id, eval_point, evalcheck_claim);
402 }
403 };
404 }
405
406 fn complete_proof(&mut self, evalcheck_claim: &EvalcheckMultilinearClaim<F>) -> bool {
407 let id = &evalcheck_claim.id;
408 let eval_point = evalcheck_claim.eval_point.clone();
409 let eval = evalcheck_claim.eval;
410
411 let res = match self.oracles.oracle(*id).variant {
412 MultilinearPolyVariant::Repeating { id, .. } => {
413 let n_vars = self.oracles.n_vars(id);
414 let inner_eval_point = &evalcheck_claim.eval_point[..n_vars];
415 self.finalized_proofs
416 .get(id, inner_eval_point)
417 .map(|(_, subproof)| subproof.clone())
418 .map(move |subproof| {
419 let proof = EvalcheckProof::Repeating(Box::new(subproof));
420 self.finalized_proofs
421 .insert(evalcheck_claim.id, eval_point, (eval, proof));
422 })
423 }
424 MultilinearPolyVariant::Projected(projected) => {
425 let (id, values) = (projected.id(), projected.values());
426 let new_eval_point = {
427 let idx = projected.start_index();
428 let mut new_eval_point = eval_point[0..idx].to_vec();
429 new_eval_point.extend(values.clone());
430 new_eval_point.extend(eval_point[idx..].to_vec());
431 new_eval_point
432 };
433
434 self.finalized_proofs
435 .get(id, &new_eval_point)
436 .map(|(_, subproof)| subproof.clone())
437 .map(|subproof| {
438 self.finalized_proofs.insert(
439 evalcheck_claim.id,
440 eval_point,
441 (eval, subproof),
442 );
443 })
444 }
445
446 MultilinearPolyVariant::LinearCombination(linear_combination) => linear_combination
447 .polys()
448 .map(|suboracle_id| {
449 self.finalized_proofs
450 .get(suboracle_id, &evalcheck_claim.eval_point)
451 .map(|(eval, subproof)| (*eval, subproof.clone()))
452 })
453 .collect::<Option<Vec<_>>>()
454 .map(|subproofs| {
455 self.finalized_proofs.insert(
456 evalcheck_claim.id,
457 eval_point,
458 (eval, EvalcheckProof::LinearCombination { subproofs }),
459 );
460 }),
461
462 MultilinearPolyVariant::ZeroPadded(inner_id) => {
463 let inner_n_vars = self.oracles.n_vars(inner_id);
464 let inner_eval_point = &evalcheck_claim.eval_point[..inner_n_vars];
465 self.finalized_proofs
466 .get(inner_id, inner_eval_point)
467 .map(|(eval, subproof)| (*eval, subproof.clone()))
468 .map(|(internal_eval, subproof)| {
469 self.finalized_proofs.insert(
470 evalcheck_claim.id,
471 eval_point,
472 (eval, EvalcheckProof::ZeroPadded(internal_eval, Box::new(subproof))),
473 );
474 })
475 }
476
477 _ => unreachable!(),
478 };
479 res.is_some()
480 }
481
482 fn collect_projected_committed(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
483 let EvalcheckMultilinearClaim {
484 id,
485 eval_point,
486 eval,
487 } = evalcheck_claim.clone();
488
489 let multilinear = self.oracles.oracle(id);
490 match multilinear.variant {
491 MultilinearPolyVariant::Committed => {
492 let subclaim = EvalcheckMultilinearClaim {
493 id: multilinear.id,
494 eval_point,
495 eval,
496 };
497
498 self.committed_eval_claims.push(subclaim);
499 }
500 MultilinearPolyVariant::Repeating { id, .. } => {
501 let n_vars = self.oracles.n_vars(id);
502 let inner_eval_point = eval_point.slice(0..n_vars);
503 let subclaim = EvalcheckMultilinearClaim {
504 id,
505 eval_point: inner_eval_point,
506 eval,
507 };
508
509 self.collect_projected_committed(subclaim);
510 }
511 MultilinearPolyVariant::Projected(projected) => {
512 let (id, values) = (projected.id(), projected.values());
513 let new_eval_point = {
514 let idx = projected.start_index();
515 let mut new_eval_point = eval_point[0..idx].to_vec();
516 new_eval_point.extend(values.clone());
517 new_eval_point.extend(eval_point[idx..].to_vec());
518 new_eval_point
519 };
520
521 let subclaim = EvalcheckMultilinearClaim {
522 id,
523 eval_point: new_eval_point.into(),
524 eval,
525 };
526 self.collect_projected_committed(subclaim);
527 }
528 MultilinearPolyVariant::Shifted { .. }
529 | MultilinearPolyVariant::Packed { .. }
530 | MultilinearPolyVariant::Composite { .. } => {
531 self.projected_bivariate_claims.push(evalcheck_claim)
532 }
533 MultilinearPolyVariant::LinearCombination(linear_combination) => {
534 for id in linear_combination.polys() {
535 let (eval, _) = self
536 .finalized_proofs
537 .get(id, &eval_point)
538 .expect("finalized_proofs contains all the proofs");
539 let subclaim = EvalcheckMultilinearClaim {
540 id,
541 eval_point: eval_point.clone(),
542 eval: *eval,
543 };
544 self.collect_projected_committed(subclaim);
545 }
546 }
547 MultilinearPolyVariant::ZeroPadded(id) => {
548 let inner_n_vars = self.oracles.n_vars(id);
549 let inner_eval_point = eval_point.slice(0..inner_n_vars);
550
551 let (eval, _) = self
552 .finalized_proofs
553 .get(id, &inner_eval_point)
554 .expect("finalized_proofs contains all the proofs");
555
556 let subclaim = EvalcheckMultilinearClaim {
557 id,
558 eval_point,
559 eval: *eval,
560 };
561 self.collect_projected_committed(subclaim);
562 }
563 _ => {}
564 }
565 }
566
567 fn projected_bivariate_meta(
568 oracles: &mut MultilinearOracleSet<F>,
569 evalcheck_claim: &EvalcheckMultilinearClaim<F>,
570 ) -> Result<ProjectedBivariateMeta, Error> {
571 let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
572
573 match &oracles.oracle(*id).variant {
574 MultilinearPolyVariant::Shifted(shifted) => {
575 shifted_sumcheck_meta(oracles, shifted, eval_point)
576 }
577 MultilinearPolyVariant::Packed(packed) => {
578 packed_sumcheck_meta(oracles, packed, eval_point)
579 }
580 MultilinearPolyVariant::Composite(_) => composite_sumcheck_meta(oracles, eval_point),
581 _ => unreachable!(),
582 }
583 }
584
585 fn process_sumcheck(
586 &mut self,
587 evalcheck_claim: EvalcheckMultilinearClaim<F>,
588 meta: &ProjectedBivariateMeta,
589 projected: Option<MultilinearExtension<P>>,
590 ) -> Result<(), Error> {
591 let EvalcheckMultilinearClaim {
592 id,
593 eval_point,
594 eval,
595 } = evalcheck_claim;
596
597 match self.oracles.oracle(id).variant {
598 MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
599 &shifted,
600 meta,
601 &eval_point,
602 eval,
603 self.witness_index,
604 &mut self.new_sumchecks_constraints,
605 projected,
606 ),
607
608 MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
609 self.oracles,
610 &packed,
611 meta,
612 &eval_point,
613 eval,
614 self.witness_index,
615 &mut self.new_sumchecks_constraints,
616 projected,
617 ),
618
619 MultilinearPolyVariant::Composite(composite) => {
620 add_composite_sumcheck_to_constraints(
622 meta,
623 &mut self.new_sumchecks_constraints,
624 &composite,
625 eval,
626 );
627 Ok(())
628 }
629 _ => unreachable!(),
630 }
631 }
632
633 fn make_new_eval_claim(
634 oracle_id: OracleId,
635 eval_point: EvalPoint<F>,
636 witness_index: &MultilinearExtensionIndex<P>,
637 memoized_queries: &MemoizedData<P, Backend>,
638 ) -> Result<EvalcheckMultilinearClaim<F>, Error> {
639 let eval_query = memoized_queries
640 .full_query_readonly(&eval_point)
641 .ok_or(Error::MissingQuery)?;
642
643 let witness_poly = witness_index
644 .get_multilin_poly(oracle_id)
645 .map_err(Error::Witness)?;
646
647 let eval = witness_poly
648 .evaluate(eval_query.to_ref())
649 .map_err(Error::from)?;
650
651 Ok(EvalcheckMultilinearClaim {
652 id: oracle_id,
653 eval_point,
654 eval,
655 })
656 }
657}