binius_core/protocols/evalcheck/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	as_packed_field::{PackScalar, PackedType},
5	underlier::UnderlierType,
6	PackedFieldIndexable, TowerField,
7};
8use binius_hal::ComputationBackend;
9use binius_math::MultilinearExtension;
10use binius_maybe_rayon::prelude::*;
11use getset::{Getters, MutGetters};
12use itertools::izip;
13use tracing::instrument;
14
15use super::{
16	error::Error,
17	evalcheck::{EvalcheckMultilinearClaim, EvalcheckProof},
18	subclaims::{
19		add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_sumcheck_meta,
20		fill_eq_witness_for_composites, MemoizedQueries, ProjectedBivariateMeta,
21	},
22	EvalPoint, EvalPointOracleIdMap,
23};
24use crate::{
25	oracle::{
26		ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
27		MultilinearPolyOracle, MultilinearPolyVariant, OracleId, ProjectionVariant,
28	},
29	protocols::evalcheck::subclaims::{
30		packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
31		shifted_sumcheck_meta,
32	},
33	witness::MultilinearExtensionIndex,
34};
35
36/// A mutable prover state.
37///
38/// Can be persisted across [`EvalcheckProver::prove`] invocations. Accumulates
39/// `new_sumchecks` bivariate sumcheck instances, as well as holds mutable references to
40/// the trace (to which new oracles & multilinears may be added during proving)
41#[derive(Getters, MutGetters)]
42pub struct EvalcheckProver<'a, 'b, U, F, Backend>
43where
44	U: UnderlierType + PackScalar<F>,
45	F: TowerField,
46	Backend: ComputationBackend,
47{
48	pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
49	pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
50
51	#[getset(get = "pub", get_mut = "pub")]
52	committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
53
54	finalized_proofs: EvalPointOracleIdMap<(F, EvalcheckProof<F>), F>,
55
56	claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
57	incomplete_proof_claims: EvalPointOracleIdMap<EvalcheckMultilinearClaim<F>, F>,
58	#[allow(clippy::type_complexity)]
59	claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
60	claims_without_evals_dedup: EvalPointOracleIdMap<(), F>,
61	projected_bivariate_claims: Vec<EvalcheckMultilinearClaim<F>>,
62
63	new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
64	memoized_queries: MemoizedQueries<PackedType<U, F>, Backend>,
65	backend: &'a Backend,
66}
67
68impl<'a, 'b, U, F, Backend> EvalcheckProver<'a, 'b, U, F, Backend>
69where
70	U: UnderlierType + PackScalar<F>,
71	PackedType<U, F>: PackedFieldIndexable,
72	F: TowerField,
73	Backend: ComputationBackend,
74{
75	/// Create a new prover state by tying together the mutable references to the oracle set and
76	/// witness index (they need to be mutable because `new_sumcheck` reduction may add new oracles & multilinears)
77	/// as well as committed eval claims accumulator.
78	pub fn new(
79		oracles: &'a mut MultilinearOracleSet<F>,
80		witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
81		backend: &'a Backend,
82	) -> Self {
83		Self {
84			oracles,
85			witness_index,
86			committed_eval_claims: Vec::new(),
87			new_sumchecks_constraints: Vec::new(),
88			finalized_proofs: EvalPointOracleIdMap::new(),
89			claims_queue: Vec::new(),
90			claims_without_evals: Vec::new(),
91			claims_without_evals_dedup: EvalPointOracleIdMap::new(),
92			projected_bivariate_claims: Vec::new(),
93			memoized_queries: MemoizedQueries::new(),
94			backend,
95			incomplete_proof_claims: EvalPointOracleIdMap::new(),
96		}
97	}
98
99	/// A helper method to move out sumcheck constraints
100	pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
101		self.new_sumchecks_constraints
102			.iter_mut()
103			.map(|builder| std::mem::take(builder).build_one(self.oracles))
104			.filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
105			.rev()
106			.collect()
107	}
108
109	/// Prove an evalcheck claim.
110	///
111	/// Given a prover state containing [`MultilinearOracleSet`] indexing into given
112	/// [`MultilinearExtensionIndex`], we prove an [`EvalcheckMultilinearClaim`] (stating that given composite
113	/// `poly` equals `eval` at `eval_point`) by recursively processing each of the multilinears.
114	/// This way the evalcheck claim gets transformed into an [`EvalcheckProof`]
115	/// and a new set of claims on:
116	///  * Committed polynomial evaluations
117	///  * New sumcheck constraints that need to be proven in subsequent rounds (those get appended to `new_sumchecks`)
118	///
119	/// All of the `new_sumchecks` constraints follow the same pattern:
120	///  * they are always a product of two multilins (composition polynomial is `BivariateProduct`)
121	///  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
122	///  * other multilin is a projection of one of the evalcheck claim multilins to its first variables
123	#[instrument(skip_all, name = "EvalcheckProver::prove", level = "debug")]
124	pub fn prove(
125		&mut self,
126		evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
127	) -> Result<Vec<EvalcheckProof<F>>, Error> {
128		for claim in &evalcheck_claims {
129			self.claims_without_evals_dedup
130				.insert(claim.id, claim.eval_point.clone(), ());
131		}
132
133		// Step 1: Collect proofs
134		self.claims_queue.extend(evalcheck_claims.clone());
135
136		// Use modified BFS approach with memoization to collect proofs.
137		// The `prove_multilinear` function saves a proof if it can be generated immediately; otherwise, the claim is added to `incomplete_proof_claims` and resolved after BFS.
138		// Claims requiring additional evaluation are stored in `claims_without_evals` and processed in parallel.
139		while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
140			// Prove all available claims
141			while !self.claims_queue.is_empty() {
142				std::mem::take(&mut self.claims_queue)
143					.into_iter()
144					.for_each(|claim| self.prove_multilinear(claim));
145			}
146
147			let mut deduplicated_claims_without_evals = Vec::new();
148
149			for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
150				if self.finalized_proofs.get(poly.id(), &eval_point).is_some() {
151					continue;
152				}
153				if self
154					.claims_without_evals_dedup
155					.get(poly.id(), &eval_point)
156					.is_some()
157				{
158					continue;
159				}
160
161				self.claims_without_evals_dedup
162					.insert(poly.id(), eval_point.clone(), ());
163
164				deduplicated_claims_without_evals.push((poly, eval_point.clone()))
165			}
166
167			let deduplicated_eval_points = deduplicated_claims_without_evals
168				.iter()
169				.map(|(_, eval_point)| eval_point.as_ref())
170				.collect::<Vec<_>>();
171
172			self.memoized_queries
173				.memoize_query_par(&deduplicated_eval_points, self.backend)?;
174
175			// Make new evaluation claims in parallel.
176			let subclaims = deduplicated_claims_without_evals
177				.into_par_iter()
178				.map(|(poly, eval_point)| {
179					Self::make_new_eval_claim(
180						poly.id(),
181						eval_point,
182						self.witness_index,
183						&self.memoized_queries,
184					)
185				})
186				.collect::<Result<Vec<_>, Error>>()?;
187
188			subclaims
189				.into_iter()
190				.for_each(|claim| self.prove_multilinear(claim));
191		}
192
193		let mut incomplete_proof_claims =
194			std::mem::take(&mut self.incomplete_proof_claims).flatten();
195
196		while !incomplete_proof_claims.is_empty() {
197			for claim in std::mem::take(&mut incomplete_proof_claims) {
198				if self.complete_proof(&claim) {
199					continue;
200				}
201				incomplete_proof_claims.push(claim);
202			}
203		}
204
205		// Step 2: Collect batch_committed_eval_claims and projected_bivariate_claims in right order
206
207		// Since we use BFS for collecting proofs and DFS for verifying them,
208		// it imposes restrictions on the correct order of collecting `batch_committed_eval_claims` and `projected_bivariate_claims`.
209		// Therefore, we run a DFS to handle this.
210		evalcheck_claims
211			.iter()
212			.cloned()
213			.for_each(|claim| self.collect_projected_committed(claim));
214
215		// Step 3: Process projected_bivariate_claims
216
217		let projected_bivariate_metas = self
218			.projected_bivariate_claims
219			.iter()
220			.map(|claim| Self::projected_bivariate_meta(self.oracles, claim))
221			.collect::<Result<Vec<_>, Error>>()?;
222
223		let projected_mles = calculate_projected_mles(
224			&projected_bivariate_metas,
225			&mut self.memoized_queries,
226			&self.projected_bivariate_claims,
227			self.witness_index,
228			self.backend,
229		)?;
230
231		fill_eq_witness_for_composites(
232			&projected_bivariate_metas,
233			&mut self.memoized_queries,
234			&self.projected_bivariate_claims,
235			self.witness_index,
236			self.backend,
237		)?;
238
239		for (claim, meta, projected) in izip!(
240			std::mem::take(&mut self.projected_bivariate_claims),
241			projected_bivariate_metas,
242			projected_mles
243		) {
244			self.process_sumcheck(claim, meta, projected)?;
245		}
246
247		// Step 4: Find and return the proofs of the original claims.
248
249		Ok(evalcheck_claims
250			.iter()
251			.map(|claim| {
252				self.finalized_proofs
253					.get(claim.id, &claim.eval_point)
254					.map(|(_, proof)| proof.clone())
255					.expect("finalized_proofs contains all the proofs")
256			})
257			.collect::<Vec<_>>())
258	}
259
260	#[instrument(
261		skip_all,
262		name = "EvalcheckProverState::prove_multilinear",
263		level = "debug"
264	)]
265	fn prove_multilinear(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
266		let multilinear_id = evalcheck_claim.id;
267
268		let eval_point = evalcheck_claim.eval_point.clone();
269
270		let eval = evalcheck_claim.eval;
271
272		if self
273			.finalized_proofs
274			.get(multilinear_id, &eval_point)
275			.is_some()
276		{
277			return;
278		}
279
280		if self
281			.incomplete_proof_claims
282			.get(multilinear_id, &eval_point)
283			.is_some()
284		{
285			return;
286		}
287
288		let multilinear = self.oracles.oracle(multilinear_id);
289
290		match multilinear.variant {
291			MultilinearPolyVariant::Transparent { .. } => {
292				self.finalized_proofs.insert(
293					multilinear_id,
294					eval_point,
295					(eval, EvalcheckProof::Transparent),
296				);
297			}
298
299			MultilinearPolyVariant::Committed => {
300				self.finalized_proofs.insert(
301					multilinear_id,
302					eval_point,
303					(eval, EvalcheckProof::Committed),
304				);
305			}
306
307			MultilinearPolyVariant::Repeating { id, .. } => {
308				let n_vars = self.oracles.n_vars(id);
309				let inner_eval_point = eval_point.slice(0..n_vars);
310				let subclaim = EvalcheckMultilinearClaim {
311					id,
312					eval_point: inner_eval_point,
313					eval,
314				};
315				self.incomplete_proof_claims
316					.insert(multilinear_id, eval_point, evalcheck_claim);
317				self.claims_queue.push(subclaim);
318			}
319
320			MultilinearPolyVariant::Shifted { .. } => {
321				self.finalized_proofs.insert(
322					multilinear_id,
323					eval_point,
324					(eval, EvalcheckProof::Shifted),
325				);
326			}
327
328			MultilinearPolyVariant::Packed { .. } => {
329				self.finalized_proofs.insert(
330					multilinear_id,
331					eval_point,
332					(eval, EvalcheckProof::Packed),
333				);
334			}
335
336			MultilinearPolyVariant::Composite(_) => {
337				self.finalized_proofs.insert(
338					multilinear_id,
339					eval_point,
340					(eval, EvalcheckProof::CompositeMLE),
341				);
342			}
343
344			MultilinearPolyVariant::Projected(projected) => {
345				let (id, values) = (projected.id(), projected.values());
346				let new_eval_point = match projected.projection_variant() {
347					ProjectionVariant::LastVars => {
348						let mut new_eval_point = eval_point.to_vec();
349						new_eval_point.extend(values);
350						new_eval_point
351					}
352					ProjectionVariant::FirstVars => {
353						values.iter().copied().chain(eval_point.to_vec()).collect()
354					}
355				};
356
357				let subclaim = EvalcheckMultilinearClaim {
358					id,
359					eval_point: new_eval_point.into(),
360					eval,
361				};
362				self.incomplete_proof_claims
363					.insert(multilinear_id, eval_point, evalcheck_claim);
364				self.claims_queue.push(subclaim);
365			}
366
367			MultilinearPolyVariant::LinearCombination(linear_combination) => {
368				let n_polys = linear_combination.n_polys();
369
370				match linear_combination
371					.polys()
372					.zip(linear_combination.coefficients())
373					.next()
374				{
375					Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
376						let eval = (eval - linear_combination.offset())
377							* coeff.invert().expect("not zero");
378						let subclaim = EvalcheckMultilinearClaim {
379							id: suboracle_id,
380							eval_point: eval_point.clone(),
381							eval,
382						};
383						self.claims_queue.push(subclaim);
384					}
385					_ => {
386						for suboracle_id in linear_combination.polys() {
387							self.claims_without_evals
388								.push((self.oracles.oracle(suboracle_id), eval_point.clone()));
389						}
390					}
391				};
392
393				self.incomplete_proof_claims
394					.insert(multilinear_id, eval_point, evalcheck_claim);
395			}
396
397			MultilinearPolyVariant::ZeroPadded(id) => {
398				let inner = self.oracles.oracle(id);
399				let inner_n_vars = inner.n_vars();
400				let inner_eval_point = eval_point.slice(0..inner_n_vars);
401				self.claims_without_evals.push((inner, inner_eval_point));
402				self.incomplete_proof_claims
403					.insert(multilinear_id, eval_point, evalcheck_claim);
404			}
405		};
406	}
407
408	fn complete_proof(&mut self, evalcheck_claim: &EvalcheckMultilinearClaim<F>) -> bool {
409		let id = &evalcheck_claim.id;
410		let eval_point = evalcheck_claim.eval_point.clone();
411		let eval = evalcheck_claim.eval;
412
413		let res = match self.oracles.oracle(*id).variant {
414			MultilinearPolyVariant::Repeating { id, .. } => {
415				let n_vars = self.oracles.n_vars(id);
416				let inner_eval_point = &evalcheck_claim.eval_point[..n_vars];
417				self.finalized_proofs
418					.get(id, inner_eval_point)
419					.map(|(_, subproof)| subproof.clone())
420					.map(move |subproof| {
421						let proof = EvalcheckProof::Repeating(Box::new(subproof));
422						self.finalized_proofs
423							.insert(evalcheck_claim.id, eval_point, (eval, proof));
424					})
425			}
426			MultilinearPolyVariant::Projected(projected) => {
427				let (id, values) = (projected.id(), projected.values());
428				let new_eval_point = match projected.projection_variant() {
429					ProjectionVariant::LastVars => {
430						let mut new_eval_point = eval_point.to_vec();
431						new_eval_point.extend(values);
432						new_eval_point
433					}
434					ProjectionVariant::FirstVars => values
435						.iter()
436						.copied()
437						.chain((*eval_point).to_vec())
438						.collect(),
439				};
440				self.finalized_proofs
441					.get(id, &new_eval_point)
442					.map(|(_, subproof)| subproof.clone())
443					.map(|subproof| {
444						self.finalized_proofs.insert(
445							evalcheck_claim.id,
446							eval_point,
447							(eval, subproof),
448						);
449					})
450			}
451
452			MultilinearPolyVariant::LinearCombination(linear_combination) => linear_combination
453				.polys()
454				.map(|suboracle_id| {
455					self.finalized_proofs
456						.get(suboracle_id, &evalcheck_claim.eval_point)
457						.map(|(eval, subproof)| (*eval, subproof.clone()))
458				})
459				.collect::<Option<Vec<_>>>()
460				.map(|subproofs| {
461					self.finalized_proofs.insert(
462						evalcheck_claim.id,
463						eval_point,
464						(eval, EvalcheckProof::LinearCombination { subproofs }),
465					);
466				}),
467
468			MultilinearPolyVariant::ZeroPadded(inner_id) => {
469				let inner_n_vars = self.oracles.n_vars(inner_id);
470				let inner_eval_point = &evalcheck_claim.eval_point[..inner_n_vars];
471				self.finalized_proofs
472					.get(inner_id, inner_eval_point)
473					.map(|(eval, subproof)| (*eval, subproof.clone()))
474					.map(|(internal_eval, subproof)| {
475						self.finalized_proofs.insert(
476							evalcheck_claim.id,
477							eval_point,
478							(eval, EvalcheckProof::ZeroPadded(internal_eval, Box::new(subproof))),
479						);
480					})
481			}
482
483			_ => unreachable!(),
484		};
485		res.is_some()
486	}
487
488	fn collect_projected_committed(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
489		let EvalcheckMultilinearClaim {
490			id,
491			eval_point,
492			eval,
493		} = evalcheck_claim.clone();
494
495		let multilinear = self.oracles.oracle(id);
496		match multilinear.variant {
497			MultilinearPolyVariant::Committed => {
498				let subclaim = EvalcheckMultilinearClaim {
499					id: multilinear.id,
500					eval_point,
501					eval,
502				};
503
504				self.committed_eval_claims.push(subclaim);
505			}
506			MultilinearPolyVariant::Repeating { id, .. } => {
507				let n_vars = self.oracles.n_vars(id);
508				let inner_eval_point = eval_point.slice(0..n_vars);
509				let subclaim = EvalcheckMultilinearClaim {
510					id,
511					eval_point: inner_eval_point,
512					eval,
513				};
514
515				self.collect_projected_committed(subclaim);
516			}
517			MultilinearPolyVariant::Projected(projected) => {
518				let (id, values) = (projected.id(), projected.values());
519				let new_eval_point = match projected.projection_variant() {
520					ProjectionVariant::LastVars => {
521						let mut new_eval_point = eval_point.to_vec();
522						new_eval_point.extend(values);
523						new_eval_point
524					}
525					ProjectionVariant::FirstVars => {
526						values.iter().copied().chain(eval_point.to_vec()).collect()
527					}
528				};
529
530				let subclaim = EvalcheckMultilinearClaim {
531					id,
532					eval_point: new_eval_point.into(),
533					eval,
534				};
535				self.collect_projected_committed(subclaim);
536			}
537			MultilinearPolyVariant::Shifted { .. }
538			| MultilinearPolyVariant::Packed { .. }
539			| MultilinearPolyVariant::Composite { .. } => {
540				self.projected_bivariate_claims.push(evalcheck_claim)
541			}
542			MultilinearPolyVariant::LinearCombination(linear_combination) => {
543				for id in linear_combination.polys() {
544					let (eval, _) = self
545						.finalized_proofs
546						.get(id, &eval_point)
547						.expect("finalized_proofs contains all the proofs");
548					let subclaim = EvalcheckMultilinearClaim {
549						id,
550						eval_point: eval_point.clone(),
551						eval: *eval,
552					};
553					self.collect_projected_committed(subclaim);
554				}
555			}
556			MultilinearPolyVariant::ZeroPadded(id) => {
557				let inner_n_vars = self.oracles.n_vars(id);
558				let inner_eval_point = eval_point.slice(0..inner_n_vars);
559
560				let (eval, _) = self
561					.finalized_proofs
562					.get(id, &inner_eval_point)
563					.expect("finalized_proofs contains all the proofs");
564
565				let subclaim = EvalcheckMultilinearClaim {
566					id,
567					eval_point,
568					eval: *eval,
569				};
570				self.collect_projected_committed(subclaim);
571			}
572			_ => {}
573		}
574	}
575
576	fn projected_bivariate_meta(
577		oracles: &mut MultilinearOracleSet<F>,
578		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
579	) -> Result<ProjectedBivariateMeta, Error> {
580		let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
581
582		match &oracles.oracle(*id).variant {
583			MultilinearPolyVariant::Shifted(shifted) => {
584				shifted_sumcheck_meta(oracles, shifted, eval_point)
585			}
586			MultilinearPolyVariant::Packed(packed) => {
587				packed_sumcheck_meta(oracles, packed, eval_point)
588			}
589			MultilinearPolyVariant::Composite(_) => composite_sumcheck_meta(oracles, eval_point),
590			_ => unreachable!(),
591		}
592	}
593
594	fn process_sumcheck(
595		&mut self,
596		evalcheck_claim: EvalcheckMultilinearClaim<F>,
597		meta: ProjectedBivariateMeta,
598		projected: Option<MultilinearExtension<PackedType<U, F>>>,
599	) -> Result<(), Error> {
600		let EvalcheckMultilinearClaim {
601			id,
602			eval_point,
603			eval,
604		} = evalcheck_claim;
605
606		match self.oracles.oracle(id).variant {
607			MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
608				&shifted,
609				meta,
610				&eval_point,
611				eval,
612				self.witness_index,
613				&mut self.new_sumchecks_constraints,
614				projected.expect("projected is required by shifted oracle"),
615			),
616
617			MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
618				self.oracles,
619				&packed,
620				meta,
621				&eval_point,
622				eval,
623				self.witness_index,
624				&mut self.new_sumchecks_constraints,
625				projected.expect("projected is required by packed oracle"),
626			),
627
628			MultilinearPolyVariant::Composite(composite) => {
629				// witness for eq MLE has been previously filled in `fill_eq_witness_for_composites`
630				add_composite_sumcheck_to_constraints(
631					meta,
632					&mut self.new_sumchecks_constraints,
633					&composite,
634					eval,
635				);
636				Ok(())
637			}
638			_ => unreachable!(),
639		}
640	}
641
642	fn make_new_eval_claim(
643		oracle_id: OracleId,
644		eval_point: EvalPoint<F>,
645		witness_index: &MultilinearExtensionIndex<U, F>,
646		memoized_queries: &MemoizedQueries<PackedType<U, F>, Backend>,
647	) -> Result<EvalcheckMultilinearClaim<F>, Error> {
648		let eval_query = memoized_queries
649			.full_query_readonly(&eval_point)
650			.ok_or(Error::MissingQuery)?;
651
652		let witness_poly = witness_index
653			.get_multilin_poly(oracle_id)
654			.map_err(Error::Witness)?;
655
656		let eval = witness_poly
657			.evaluate(eval_query.to_ref())
658			.map_err(Error::from)?;
659
660		Ok(EvalcheckMultilinearClaim {
661			id: oracle_id,
662			eval_point,
663			eval,
664		})
665	}
666}