binius_core/protocols/evalcheck/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{PackedFieldIndexable, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_maybe_rayon::prelude::*;
7use getset::{Getters, MutGetters};
8use itertools::izip;
9use tracing::instrument;
10
11use super::{
12	error::Error,
13	evalcheck::{EvalcheckMultilinearClaim, EvalcheckProof},
14	subclaims::{
15		add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_sumcheck_meta,
16		fill_eq_witness_for_composites, MemoizedData, ProjectedBivariateMeta,
17	},
18	EvalPoint, EvalPointOracleIdMap,
19};
20use crate::{
21	oracle::{
22		ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
23		MultilinearPolyOracle, MultilinearPolyVariant, OracleId,
24	},
25	protocols::evalcheck::subclaims::{
26		packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
27		shifted_sumcheck_meta,
28	},
29	witness::MultilinearExtensionIndex,
30};
31
32/// A mutable prover state.
33///
34/// Can be persisted across [`EvalcheckProver::prove`] invocations. Accumulates
35/// `new_sumchecks` bivariate sumcheck instances, as well as holds mutable references to
36/// the trace (to which new oracles & multilinears may be added during proving)
37#[derive(Getters, MutGetters)]
38pub struct EvalcheckProver<'a, 'b, F, P, Backend>
39where
40	P: PackedFieldIndexable<Scalar = F>,
41	F: TowerField,
42	Backend: ComputationBackend,
43{
44	pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
45	pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
46
47	#[getset(get = "pub", get_mut = "pub")]
48	committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
49
50	finalized_proofs: EvalPointOracleIdMap<(F, EvalcheckProof<F>), F>,
51
52	claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
53	incomplete_proof_claims: EvalPointOracleIdMap<EvalcheckMultilinearClaim<F>, F>,
54	claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
55	claims_without_evals_dedup: EvalPointOracleIdMap<(), F>,
56	projected_bivariate_claims: Vec<EvalcheckMultilinearClaim<F>>,
57
58	new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
59	pub memoized_data: MemoizedData<'b, P, Backend>,
60	backend: &'a Backend,
61}
62
63impl<'a, 'b, F, P, Backend> EvalcheckProver<'a, 'b, F, P, Backend>
64where
65	P: PackedFieldIndexable<Scalar = F>,
66	F: TowerField,
67	Backend: ComputationBackend,
68{
69	/// Create a new prover state by tying together the mutable references to the oracle set and
70	/// witness index (they need to be mutable because `new_sumcheck` reduction may add new oracles & multilinears)
71	/// as well as committed eval claims accumulator.
72	pub fn new(
73		oracles: &'a mut MultilinearOracleSet<F>,
74		witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
75		backend: &'a Backend,
76	) -> Self {
77		Self {
78			oracles,
79			witness_index,
80			committed_eval_claims: Vec::new(),
81			new_sumchecks_constraints: Vec::new(),
82			finalized_proofs: EvalPointOracleIdMap::new(),
83			claims_queue: Vec::new(),
84			claims_without_evals: Vec::new(),
85			claims_without_evals_dedup: EvalPointOracleIdMap::new(),
86			projected_bivariate_claims: Vec::new(),
87			memoized_data: MemoizedData::new(),
88			backend,
89			incomplete_proof_claims: EvalPointOracleIdMap::new(),
90		}
91	}
92
93	/// A helper method to move out sumcheck constraints
94	pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
95		self.new_sumchecks_constraints
96			.iter_mut()
97			.map(|builder| std::mem::take(builder).build_one(self.oracles))
98			.filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
99			.rev()
100			.collect()
101	}
102
103	/// Prove an evalcheck claim.
104	///
105	/// Given a prover state containing [`MultilinearOracleSet`] indexing into given
106	/// [`MultilinearExtensionIndex`], we prove an [`EvalcheckMultilinearClaim`] (stating that given composite
107	/// `poly` equals `eval` at `eval_point`) by recursively processing each of the multilinears.
108	/// This way the evalcheck claim gets transformed into an [`EvalcheckProof`]
109	/// and a new set of claims on:
110	///  * Committed polynomial evaluations
111	///  * New sumcheck constraints that need to be proven in subsequent rounds (those get appended to `new_sumchecks`)
112	///
113	/// All of the `new_sumchecks` constraints follow the same pattern:
114	///  * they are always a product of two multilins (composition polynomial is `BivariateProduct`)
115	///  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
116	///  * other multilin is a projection of one of the evalcheck claim multilins to its first variables
117	#[instrument(skip_all, name = "EvalcheckProver::prove", level = "debug")]
118	pub fn prove(
119		&mut self,
120		evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
121	) -> Result<Vec<EvalcheckProof<F>>, Error> {
122		for claim in &evalcheck_claims {
123			self.claims_without_evals_dedup
124				.insert(claim.id, claim.eval_point.clone(), ());
125		}
126
127		// Step 1: Collect proofs
128		self.claims_queue.extend(evalcheck_claims.clone());
129
130		// Use modified BFS approach with memoization to collect proofs.
131		// The `prove_multilinear` function saves a proof if it can be generated immediately; otherwise, the claim is added to `incomplete_proof_claims` and resolved after BFS.
132		// Claims requiring additional evaluation are stored in `claims_without_evals` and processed in parallel.
133		while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
134			// Prove all available claims
135			while !self.claims_queue.is_empty() {
136				std::mem::take(&mut self.claims_queue)
137					.into_iter()
138					.for_each(|claim| self.prove_multilinear(claim));
139			}
140
141			let mut deduplicated_claims_without_evals = Vec::new();
142
143			for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
144				if self.finalized_proofs.get(poly.id(), &eval_point).is_some() {
145					continue;
146				}
147				if self
148					.claims_without_evals_dedup
149					.get(poly.id(), &eval_point)
150					.is_some()
151				{
152					continue;
153				}
154
155				self.claims_without_evals_dedup
156					.insert(poly.id(), eval_point.clone(), ());
157
158				deduplicated_claims_without_evals.push((poly, eval_point.clone()))
159			}
160
161			let deduplicated_eval_points = deduplicated_claims_without_evals
162				.iter()
163				.map(|(_, eval_point)| eval_point.as_ref())
164				.collect::<Vec<_>>();
165
166			self.memoized_data
167				.memoize_query_par(&deduplicated_eval_points, self.backend)?;
168
169			// Make new evaluation claims in parallel.
170			let subclaims = deduplicated_claims_without_evals
171				.into_par_iter()
172				.map(|(poly, eval_point)| {
173					Self::make_new_eval_claim(
174						poly.id(),
175						eval_point,
176						self.witness_index,
177						&self.memoized_data,
178					)
179				})
180				.collect::<Result<Vec<_>, Error>>()?;
181
182			subclaims
183				.into_iter()
184				.for_each(|claim| self.prove_multilinear(claim));
185		}
186
187		let mut incomplete_proof_claims =
188			std::mem::take(&mut self.incomplete_proof_claims).flatten();
189
190		while !incomplete_proof_claims.is_empty() {
191			for claim in std::mem::take(&mut incomplete_proof_claims) {
192				if self.complete_proof(&claim) {
193					continue;
194				}
195				incomplete_proof_claims.push(claim);
196			}
197		}
198
199		// Step 2: Collect batch_committed_eval_claims and projected_bivariate_claims in right order
200
201		// Since we use BFS for collecting proofs and DFS for verifying them,
202		// it imposes restrictions on the correct order of collecting `batch_committed_eval_claims` and `projected_bivariate_claims`.
203		// Therefore, we run a DFS to handle this.
204		evalcheck_claims
205			.iter()
206			.cloned()
207			.for_each(|claim| self.collect_projected_committed(claim));
208
209		// Step 3: Process projected_bivariate_claims
210
211		let projected_bivariate_metas = self
212			.projected_bivariate_claims
213			.iter()
214			.map(|claim| Self::projected_bivariate_meta(self.oracles, claim))
215			.collect::<Result<Vec<_>, Error>>()?;
216
217		let projected_mles = calculate_projected_mles(
218			&projected_bivariate_metas,
219			&mut self.memoized_data,
220			&self.projected_bivariate_claims,
221			self.witness_index,
222			self.backend,
223		)?;
224
225		fill_eq_witness_for_composites(
226			&projected_bivariate_metas,
227			&mut self.memoized_data,
228			&self.projected_bivariate_claims,
229			self.witness_index,
230			self.backend,
231		)?;
232
233		for (claim, meta, projected) in izip!(
234			std::mem::take(&mut self.projected_bivariate_claims),
235			&projected_bivariate_metas,
236			projected_mles
237		) {
238			self.process_sumcheck(claim, meta, projected)?;
239		}
240
241		self.memoized_data.memoize_partial_evals(
242			&projected_bivariate_metas,
243			&self.projected_bivariate_claims,
244			self.oracles,
245			self.witness_index,
246		);
247
248		// Step 4: Find and return the proofs of the original claims.
249
250		Ok(evalcheck_claims
251			.iter()
252			.map(|claim| {
253				self.finalized_proofs
254					.get(claim.id, &claim.eval_point)
255					.map(|(_, proof)| proof.clone())
256					.expect("finalized_proofs contains all the proofs")
257			})
258			.collect::<Vec<_>>())
259	}
260
261	#[instrument(
262		skip_all,
263		name = "EvalcheckProverState::prove_multilinear",
264		level = "debug"
265	)]
266	fn prove_multilinear(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
267		let multilinear_id = evalcheck_claim.id;
268
269		let eval_point = evalcheck_claim.eval_point.clone();
270
271		let eval = evalcheck_claim.eval;
272
273		if self
274			.finalized_proofs
275			.get(multilinear_id, &eval_point)
276			.is_some()
277		{
278			return;
279		}
280
281		if self
282			.incomplete_proof_claims
283			.get(multilinear_id, &eval_point)
284			.is_some()
285		{
286			return;
287		}
288
289		let multilinear = self.oracles.oracle(multilinear_id);
290
291		match multilinear.variant {
292			MultilinearPolyVariant::Transparent { .. } => {
293				self.finalized_proofs.insert(
294					multilinear_id,
295					eval_point,
296					(eval, EvalcheckProof::Transparent),
297				);
298			}
299
300			MultilinearPolyVariant::Committed => {
301				self.finalized_proofs.insert(
302					multilinear_id,
303					eval_point,
304					(eval, EvalcheckProof::Committed),
305				);
306			}
307
308			MultilinearPolyVariant::Repeating { id, .. } => {
309				let n_vars = self.oracles.n_vars(id);
310				let inner_eval_point = eval_point.slice(0..n_vars);
311				let subclaim = EvalcheckMultilinearClaim {
312					id,
313					eval_point: inner_eval_point,
314					eval,
315				};
316				self.incomplete_proof_claims
317					.insert(multilinear_id, eval_point, evalcheck_claim);
318				self.claims_queue.push(subclaim);
319			}
320
321			MultilinearPolyVariant::Shifted { .. } => {
322				self.finalized_proofs.insert(
323					multilinear_id,
324					eval_point,
325					(eval, EvalcheckProof::Shifted),
326				);
327			}
328
329			MultilinearPolyVariant::Packed { .. } => {
330				self.finalized_proofs.insert(
331					multilinear_id,
332					eval_point,
333					(eval, EvalcheckProof::Packed),
334				);
335			}
336
337			MultilinearPolyVariant::Composite(_) => {
338				self.finalized_proofs.insert(
339					multilinear_id,
340					eval_point,
341					(eval, EvalcheckProof::CompositeMLE),
342				);
343			}
344
345			MultilinearPolyVariant::Projected(projected) => {
346				let (id, values) = (projected.id(), projected.values());
347				let new_eval_point = {
348					let idx = projected.start_index();
349					let mut new_eval_point = eval_point[0..idx].to_vec();
350					new_eval_point.extend(values.clone());
351					new_eval_point.extend(eval_point[idx..].to_vec());
352					new_eval_point
353				};
354
355				let subclaim = EvalcheckMultilinearClaim {
356					id,
357					eval_point: new_eval_point.into(),
358					eval,
359				};
360				self.incomplete_proof_claims
361					.insert(multilinear_id, eval_point, evalcheck_claim);
362				self.claims_queue.push(subclaim);
363			}
364
365			MultilinearPolyVariant::LinearCombination(linear_combination) => {
366				let n_polys = linear_combination.n_polys();
367
368				match linear_combination
369					.polys()
370					.zip(linear_combination.coefficients())
371					.next()
372				{
373					Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
374						let eval = (eval - linear_combination.offset())
375							* coeff.invert().expect("not zero");
376						let subclaim = EvalcheckMultilinearClaim {
377							id: suboracle_id,
378							eval_point: eval_point.clone(),
379							eval,
380						};
381						self.claims_queue.push(subclaim);
382					}
383					_ => {
384						for suboracle_id in linear_combination.polys() {
385							self.claims_without_evals
386								.push((self.oracles.oracle(suboracle_id), eval_point.clone()));
387						}
388					}
389				};
390
391				self.incomplete_proof_claims
392					.insert(multilinear_id, eval_point, evalcheck_claim);
393			}
394
395			MultilinearPolyVariant::ZeroPadded(id) => {
396				let inner = self.oracles.oracle(id);
397				let inner_n_vars = inner.n_vars();
398				let inner_eval_point = eval_point.slice(0..inner_n_vars);
399				self.claims_without_evals.push((inner, inner_eval_point));
400				self.incomplete_proof_claims
401					.insert(multilinear_id, eval_point, evalcheck_claim);
402			}
403		};
404	}
405
406	fn complete_proof(&mut self, evalcheck_claim: &EvalcheckMultilinearClaim<F>) -> bool {
407		let id = &evalcheck_claim.id;
408		let eval_point = evalcheck_claim.eval_point.clone();
409		let eval = evalcheck_claim.eval;
410
411		let res = match self.oracles.oracle(*id).variant {
412			MultilinearPolyVariant::Repeating { id, .. } => {
413				let n_vars = self.oracles.n_vars(id);
414				let inner_eval_point = &evalcheck_claim.eval_point[..n_vars];
415				self.finalized_proofs
416					.get(id, inner_eval_point)
417					.map(|(_, subproof)| subproof.clone())
418					.map(move |subproof| {
419						let proof = EvalcheckProof::Repeating(Box::new(subproof));
420						self.finalized_proofs
421							.insert(evalcheck_claim.id, eval_point, (eval, proof));
422					})
423			}
424			MultilinearPolyVariant::Projected(projected) => {
425				let (id, values) = (projected.id(), projected.values());
426				let new_eval_point = {
427					let idx = projected.start_index();
428					let mut new_eval_point = eval_point[0..idx].to_vec();
429					new_eval_point.extend(values.clone());
430					new_eval_point.extend(eval_point[idx..].to_vec());
431					new_eval_point
432				};
433
434				self.finalized_proofs
435					.get(id, &new_eval_point)
436					.map(|(_, subproof)| subproof.clone())
437					.map(|subproof| {
438						self.finalized_proofs.insert(
439							evalcheck_claim.id,
440							eval_point,
441							(eval, subproof),
442						);
443					})
444			}
445
446			MultilinearPolyVariant::LinearCombination(linear_combination) => linear_combination
447				.polys()
448				.map(|suboracle_id| {
449					self.finalized_proofs
450						.get(suboracle_id, &evalcheck_claim.eval_point)
451						.map(|(eval, subproof)| (*eval, subproof.clone()))
452				})
453				.collect::<Option<Vec<_>>>()
454				.map(|subproofs| {
455					self.finalized_proofs.insert(
456						evalcheck_claim.id,
457						eval_point,
458						(eval, EvalcheckProof::LinearCombination { subproofs }),
459					);
460				}),
461
462			MultilinearPolyVariant::ZeroPadded(inner_id) => {
463				let inner_n_vars = self.oracles.n_vars(inner_id);
464				let inner_eval_point = &evalcheck_claim.eval_point[..inner_n_vars];
465				self.finalized_proofs
466					.get(inner_id, inner_eval_point)
467					.map(|(eval, subproof)| (*eval, subproof.clone()))
468					.map(|(internal_eval, subproof)| {
469						self.finalized_proofs.insert(
470							evalcheck_claim.id,
471							eval_point,
472							(eval, EvalcheckProof::ZeroPadded(internal_eval, Box::new(subproof))),
473						);
474					})
475			}
476
477			_ => unreachable!(),
478		};
479		res.is_some()
480	}
481
482	fn collect_projected_committed(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
483		let EvalcheckMultilinearClaim {
484			id,
485			eval_point,
486			eval,
487		} = evalcheck_claim.clone();
488
489		let multilinear = self.oracles.oracle(id);
490		match multilinear.variant {
491			MultilinearPolyVariant::Committed => {
492				let subclaim = EvalcheckMultilinearClaim {
493					id: multilinear.id,
494					eval_point,
495					eval,
496				};
497
498				self.committed_eval_claims.push(subclaim);
499			}
500			MultilinearPolyVariant::Repeating { id, .. } => {
501				let n_vars = self.oracles.n_vars(id);
502				let inner_eval_point = eval_point.slice(0..n_vars);
503				let subclaim = EvalcheckMultilinearClaim {
504					id,
505					eval_point: inner_eval_point,
506					eval,
507				};
508
509				self.collect_projected_committed(subclaim);
510			}
511			MultilinearPolyVariant::Projected(projected) => {
512				let (id, values) = (projected.id(), projected.values());
513				let new_eval_point = {
514					let idx = projected.start_index();
515					let mut new_eval_point = eval_point[0..idx].to_vec();
516					new_eval_point.extend(values.clone());
517					new_eval_point.extend(eval_point[idx..].to_vec());
518					new_eval_point
519				};
520
521				let subclaim = EvalcheckMultilinearClaim {
522					id,
523					eval_point: new_eval_point.into(),
524					eval,
525				};
526				self.collect_projected_committed(subclaim);
527			}
528			MultilinearPolyVariant::Shifted { .. }
529			| MultilinearPolyVariant::Packed { .. }
530			| MultilinearPolyVariant::Composite { .. } => {
531				self.projected_bivariate_claims.push(evalcheck_claim)
532			}
533			MultilinearPolyVariant::LinearCombination(linear_combination) => {
534				for id in linear_combination.polys() {
535					let (eval, _) = self
536						.finalized_proofs
537						.get(id, &eval_point)
538						.expect("finalized_proofs contains all the proofs");
539					let subclaim = EvalcheckMultilinearClaim {
540						id,
541						eval_point: eval_point.clone(),
542						eval: *eval,
543					};
544					self.collect_projected_committed(subclaim);
545				}
546			}
547			MultilinearPolyVariant::ZeroPadded(id) => {
548				let inner_n_vars = self.oracles.n_vars(id);
549				let inner_eval_point = eval_point.slice(0..inner_n_vars);
550
551				let (eval, _) = self
552					.finalized_proofs
553					.get(id, &inner_eval_point)
554					.expect("finalized_proofs contains all the proofs");
555
556				let subclaim = EvalcheckMultilinearClaim {
557					id,
558					eval_point,
559					eval: *eval,
560				};
561				self.collect_projected_committed(subclaim);
562			}
563			_ => {}
564		}
565	}
566
567	fn projected_bivariate_meta(
568		oracles: &mut MultilinearOracleSet<F>,
569		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
570	) -> Result<ProjectedBivariateMeta, Error> {
571		let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
572
573		match &oracles.oracle(*id).variant {
574			MultilinearPolyVariant::Shifted(shifted) => {
575				shifted_sumcheck_meta(oracles, shifted, eval_point)
576			}
577			MultilinearPolyVariant::Packed(packed) => {
578				packed_sumcheck_meta(oracles, packed, eval_point)
579			}
580			MultilinearPolyVariant::Composite(_) => composite_sumcheck_meta(oracles, eval_point),
581			_ => unreachable!(),
582		}
583	}
584
585	fn process_sumcheck(
586		&mut self,
587		evalcheck_claim: EvalcheckMultilinearClaim<F>,
588		meta: &ProjectedBivariateMeta,
589		projected: Option<MultilinearExtension<P>>,
590	) -> Result<(), Error> {
591		let EvalcheckMultilinearClaim {
592			id,
593			eval_point,
594			eval,
595		} = evalcheck_claim;
596
597		match self.oracles.oracle(id).variant {
598			MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
599				&shifted,
600				meta,
601				&eval_point,
602				eval,
603				self.witness_index,
604				&mut self.new_sumchecks_constraints,
605				projected,
606			),
607
608			MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
609				self.oracles,
610				&packed,
611				meta,
612				&eval_point,
613				eval,
614				self.witness_index,
615				&mut self.new_sumchecks_constraints,
616				projected,
617			),
618
619			MultilinearPolyVariant::Composite(composite) => {
620				// witness for eq MLE has been previously filled in `fill_eq_witness_for_composites`
621				add_composite_sumcheck_to_constraints(
622					meta,
623					&mut self.new_sumchecks_constraints,
624					&composite,
625					eval,
626				);
627				Ok(())
628			}
629			_ => unreachable!(),
630		}
631	}
632
633	fn make_new_eval_claim(
634		oracle_id: OracleId,
635		eval_point: EvalPoint<F>,
636		witness_index: &MultilinearExtensionIndex<P>,
637		memoized_queries: &MemoizedData<P, Backend>,
638	) -> Result<EvalcheckMultilinearClaim<F>, Error> {
639		let eval_query = memoized_queries
640			.full_query_readonly(&eval_point)
641			.ok_or(Error::MissingQuery)?;
642
643		let witness_poly = witness_index
644			.get_multilin_poly(oracle_id)
645			.map_err(Error::Witness)?;
646
647		let eval = witness_poly
648			.evaluate(eval_query.to_ref())
649			.map_err(Error::from)?;
650
651		Ok(EvalcheckMultilinearClaim {
652			id: oracle_id,
653			eval_point,
654			eval,
655		})
656	}
657}