binius_core/protocols/evalcheck/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::collections::HashSet;
4
5use binius_field::{PackedField, TowerField};
6use binius_hal::ComputationBackend;
7use binius_math::MultilinearExtension;
8use binius_maybe_rayon::prelude::*;
9use getset::{Getters, MutGetters};
10use itertools::chain;
11use tracing::instrument;
12
13use super::{
14	error::Error,
15	evalcheck::{EvalcheckHint, EvalcheckMultilinearClaim},
16	serialize_evalcheck_proof,
17	subclaims::{
18		add_composite_sumcheck_to_constraints, calculate_projected_mles, composite_mlecheck_meta,
19		fill_eq_witness_for_composites, MemoizedData, ProjectedBivariateMeta, SumcheckClaims,
20	},
21	EvalPoint, EvalPointOracleIdMap,
22};
23use crate::{
24	fiat_shamir::Challenger,
25	oracle::{
26		ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
27		MultilinearPolyOracle, MultilinearPolyVariant, OracleId,
28	},
29	protocols::evalcheck::{
30		logging::MLEFoldHighDimensionsData,
31		subclaims::{
32			packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
33			shifted_sumcheck_meta, CompositeMLECheckMeta,
34		},
35	},
36	transcript::ProverTranscript,
37	witness::MultilinearExtensionIndex,
38};
39
40/// A mutable prover state.
41///
42/// Can be persisted across [`EvalcheckProver::prove`] invocations. Accumulates
43/// `new_sumchecks` bivariate sumcheck instances, as well as holds mutable references to
44/// the trace (to which new oracles & multilinears may be added during proving)
45#[derive(Getters, MutGetters)]
46pub struct EvalcheckProver<'a, 'b, F, P, Backend>
47where
48	P: PackedField<Scalar = F>,
49	F: TowerField,
50	Backend: ComputationBackend,
51{
52	/// Mutable reference to the oracle set which is modified to create new claims arising from sumchecks
53	pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
54	/// Mutable reference to the witness index which is is populated by the prover for new claims arising from sumchecks
55	pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
56
57	/// The committed evaluation claims arising in this round
58	#[getset(get = "pub", get_mut = "pub")]
59	committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
60
61	// Internally used to collect subclaims with evaluations to consume and further reduce.
62	claims_queue: Vec<EvalcheckMultilinearClaim<F>>,
63	// Internally used to collect subclaims without evaluations for future query and memoization
64	claims_without_evals: Vec<(MultilinearPolyOracle<F>, EvalPoint<F>)>,
65	// The list of claims that reduces to a bivariate sumcheck in a round.
66	sumcheck_claims: Vec<SumcheckClaims<P::Scalar>>,
67
68	// The new sumcheck constraints arising in this round
69	new_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
70	// Tensor expansion of evaluation points and partial evaluations of multilinears
71	pub memoized_data: MemoizedData<'b, P, Backend>,
72	backend: &'a Backend,
73
74	// The unique index of a claim in this round.
75	claim_to_index: EvalPointOracleIdMap<usize, F>,
76	// Claims that have been visited in this round, used to deduplicate claims when collecting subclaims in a BFS manner.
77	visited_claims: EvalPointOracleIdMap<(), F>,
78	// Memoization of evaluations of claims the prover sees in this round
79	evals_memoization: EvalPointOracleIdMap<F, F>,
80	// The index of the next claim to be verified
81	round_claim_index: usize,
82}
83
84impl<'a, 'b, F, P, Backend> EvalcheckProver<'a, 'b, F, P, Backend>
85where
86	P: PackedField<Scalar = F>,
87	F: TowerField,
88	Backend: ComputationBackend,
89{
90	/// Create a new prover state by tying together the mutable references to the oracle set and
91	/// witness index (they need to be mutable because `new_sumcheck` reduction may add new oracles & multilinears)
92	/// as well as committed eval claims accumulator.
93	pub fn new(
94		oracles: &'a mut MultilinearOracleSet<F>,
95		witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
96		backend: &'a Backend,
97	) -> Self {
98		Self {
99			oracles,
100			witness_index,
101			committed_eval_claims: Vec::new(),
102			new_sumchecks_constraints: Vec::new(),
103			claims_queue: Vec::new(),
104			claims_without_evals: Vec::new(),
105			sumcheck_claims: Vec::new(),
106			memoized_data: MemoizedData::new(),
107			backend,
108
109			claim_to_index: EvalPointOracleIdMap::new(),
110			visited_claims: EvalPointOracleIdMap::new(),
111			evals_memoization: EvalPointOracleIdMap::new(),
112			round_claim_index: 0,
113		}
114	}
115
116	/// A helper method to move out sumcheck constraints
117	pub fn take_new_sumchecks_constraints(&mut self) -> Result<Vec<ConstraintSet<F>>, OracleError> {
118		self.new_sumchecks_constraints
119			.iter_mut()
120			.map(|builder| std::mem::take(builder).build_one(self.oracles))
121			.filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
122			.collect()
123	}
124
125	/// Prove an evalcheck claim.
126	///
127	/// Given a prover state containing [`MultilinearOracleSet`] indexing into given
128	/// [`MultilinearExtensionIndex`], we prove an [`EvalcheckMultilinearClaim`] (stating that given composite
129	/// `poly` equals `eval` at `eval_point`) by recursively processing each of the multilinears.
130	/// This way the evalcheck claim gets transformed into an [`EvalcheckHint`]
131	/// and a new set of claims on:
132	///  * Committed polynomial evaluations
133	///  * New sumcheck constraints that need to be proven in subsequent rounds (those get appended to `new_sumchecks`)
134	///
135	/// All of the `new_sumchecks` constraints follow the same pattern:
136	///  * they are always a product of two multilins (composition polynomial is `BivariateProduct`)
137	///  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
138	///  * other multilin is a projection of one of the evalcheck claim multilins to its first variables
139	pub fn prove<Challenger_: Challenger>(
140		&mut self,
141		evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
142		transcript: &mut ProverTranscript<Challenger_>,
143	) -> Result<(), Error> {
144		// Reset the prover state for a new round.
145		self.round_claim_index = 0;
146		self.visited_claims.clear();
147		self.claim_to_index.clear();
148		self.evals_memoization.clear();
149
150		for claim in &evalcheck_claims {
151			if self
152				.evals_memoization
153				.get(claim.id, &claim.eval_point)
154				.is_some()
155			{
156				continue;
157			}
158
159			self.evals_memoization
160				.insert(claim.id, claim.eval_point.clone(), claim.eval);
161		}
162
163		self.claims_queue.extend(evalcheck_claims.clone());
164
165		// Step 1: Use modified BFS to memoize evaluations. For each claim, if there is a subclaim and we know the evaluation of the subclaim, we add the subclaim to the claims_queue
166		// Otherwise, we find the evaluation of the claim by querying the witness data from the oracle id and evaluation point
167		let mle_fold_full_span = tracing::debug_span!(
168			"[task] MLE Fold Full",
169			phase = "evalcheck",
170			perfetto_category = "task.main"
171		)
172		.entered();
173		while !self.claims_without_evals.is_empty() || !self.claims_queue.is_empty() {
174			while !self.claims_queue.is_empty() {
175				std::mem::take(&mut self.claims_queue)
176					.into_iter()
177					.for_each(|claim| self.collect_subclaims_for_memoization(claim));
178			}
179
180			let mut deduplicated_claims_without_evals = HashSet::new();
181
182			for (poly, eval_point) in std::mem::take(&mut self.claims_without_evals) {
183				if self.evals_memoization.get(poly.id(), &eval_point).is_some() {
184					continue;
185				}
186
187				deduplicated_claims_without_evals.insert((poly.id(), eval_point.clone()));
188			}
189
190			let deduplicated_eval_points = deduplicated_claims_without_evals
191				.iter()
192				.map(|(_, eval_point)| eval_point.as_ref())
193				.collect::<Vec<_>>();
194
195			// Tensor expansion of unique eval points.
196			self.memoized_data
197				.memoize_query_par(deduplicated_eval_points.iter().copied(), self.backend)?;
198
199			// Query and fill missing evaluations.
200			let subclaims = deduplicated_claims_without_evals
201				.into_par_iter()
202				.map(|(id, eval_point)| {
203					Self::make_new_eval_claim(
204						id,
205						eval_point,
206						self.witness_index,
207						&self.memoized_data,
208					)
209				})
210				.collect::<Result<Vec<_>, Error>>()?;
211
212			for subclaim in &subclaims {
213				self.evals_memoization.insert(
214					subclaim.id,
215					subclaim.eval_point.clone(),
216					subclaim.eval,
217				);
218			}
219
220			subclaims
221				.into_iter()
222				.for_each(|claim| self.collect_subclaims_for_memoization(claim));
223		}
224		drop(mle_fold_full_span);
225
226		// Step 2: Prove multilinears: For each claim, we prove the claim by recursively proving the subclaims by stepping through subclaims in a DFS manner
227		// and deduplicating claims.
228		for claim in evalcheck_claims {
229			self.prove_multilinear(claim, transcript)?;
230		}
231
232		// Step 3: Process projected_bivariate_claims
233		let mut projected_bivariate_metas = Vec::new();
234		let mut composite_mle_metas = Vec::new();
235		let mut projected_bivariate_claims = Vec::new();
236		let mut composite_mle_claims = Vec::new();
237
238		for claim in &self.sumcheck_claims {
239			match claim {
240				SumcheckClaims::Projected(claim) => {
241					let meta = Self::projected_bivariate_meta(self.oracles, claim)?;
242					projected_bivariate_metas.push(meta);
243					projected_bivariate_claims.push(claim.clone())
244				}
245				SumcheckClaims::Composite(claim) => {
246					let meta = composite_mlecheck_meta(self.oracles, &claim.eval_point)?;
247					composite_mle_metas.push(meta);
248					composite_mle_claims.push(claim.clone())
249				}
250			}
251		}
252		let dimensions_data = MLEFoldHighDimensionsData::new(projected_bivariate_claims.len());
253		let evalcheck_mle_fold_high_span = tracing::debug_span!(
254			"[task] (Evalcheck) MLE Fold High",
255			phase = "evalcheck",
256			perfetto_category = "task.main",
257			dimensions_data = ?dimensions_data,
258		)
259		.entered();
260
261		let projected_mles = calculate_projected_mles(
262			&projected_bivariate_metas,
263			&mut self.memoized_data,
264			&projected_bivariate_claims,
265			self.witness_index,
266			self.backend,
267		)?;
268		drop(evalcheck_mle_fold_high_span);
269
270		fill_eq_witness_for_composites(
271			&composite_mle_metas,
272			&mut self.memoized_data,
273			&composite_mle_claims,
274			self.witness_index,
275			self.backend,
276		)?;
277
278		let mut projected_index = 0;
279		let mut composite_index = 0;
280
281		for claim in std::mem::take(&mut self.sumcheck_claims) {
282			match claim {
283				SumcheckClaims::Projected(claim) => {
284					let meta = &projected_bivariate_metas[projected_index];
285					let projected = projected_mles[projected_index].clone();
286					self.process_bivariate_sumcheck(&claim, meta, projected)?;
287					projected_index += 1;
288				}
289				SumcheckClaims::Composite(claim) => {
290					let meta = composite_mle_metas[composite_index];
291					self.process_composite_mlecheck(&claim, meta)?;
292					composite_index += 1;
293				}
294			}
295		}
296
297		self.memoized_data.memoize_partial_evals(
298			&projected_bivariate_metas,
299			&projected_bivariate_claims,
300			self.oracles,
301			self.witness_index,
302		);
303
304		Ok(())
305	}
306
307	#[instrument(
308		skip_all,
309		name = "EvalcheckProverState::collect_subclaims_for_precompute",
310		level = "debug"
311	)]
312	fn collect_subclaims_for_memoization(&mut self, evalcheck_claim: EvalcheckMultilinearClaim<F>) {
313		let multilinear_id = evalcheck_claim.id;
314
315		let eval_point = evalcheck_claim.eval_point;
316
317		let eval = evalcheck_claim.eval;
318
319		if self
320			.visited_claims
321			.get(multilinear_id, &eval_point)
322			.is_some()
323		{
324			return;
325		}
326
327		self.visited_claims
328			.insert(multilinear_id, eval_point.clone(), ());
329
330		let multilinear = self.oracles.oracle(multilinear_id);
331
332		match multilinear.variant {
333			MultilinearPolyVariant::Repeating { id, .. } => {
334				let n_vars = self.oracles.n_vars(id);
335				let inner_eval_point = eval_point.slice(0..n_vars);
336				let subclaim = EvalcheckMultilinearClaim {
337					id,
338					eval_point: inner_eval_point,
339					eval,
340				};
341				self.claims_queue.push(subclaim);
342			}
343
344			MultilinearPolyVariant::Projected(projected) => {
345				let (id, values) = (projected.id(), projected.values());
346				let new_eval_point = {
347					let idx = projected.start_index();
348					let mut new_eval_point = eval_point[0..idx].to_vec();
349					new_eval_point.extend(values.clone());
350					new_eval_point.extend(eval_point[idx..].to_vec());
351					new_eval_point
352				};
353
354				let subclaim = EvalcheckMultilinearClaim {
355					id,
356					eval_point: new_eval_point.into(),
357					eval,
358				};
359				self.claims_queue.push(subclaim);
360			}
361
362			MultilinearPolyVariant::LinearCombination(linear_combination) => {
363				let n_polys = linear_combination.n_polys();
364
365				match linear_combination
366					.polys()
367					.zip(linear_combination.coefficients())
368					.next()
369				{
370					Some((suboracle_id, coeff)) if n_polys == 1 && !coeff.is_zero() => {
371						let eval = if let Some(eval) =
372							self.evals_memoization.get(suboracle_id, &eval_point)
373						{
374							*eval
375						} else {
376							let eval = (eval - linear_combination.offset())
377								* coeff.invert().expect("not zero");
378							self.evals_memoization
379								.insert(suboracle_id, eval_point.clone(), eval);
380							eval
381						};
382
383						let subclaim = EvalcheckMultilinearClaim {
384							id: suboracle_id,
385							eval_point,
386							eval,
387						};
388						self.claims_queue.push(subclaim);
389					}
390					_ => {
391						for suboracle_id in linear_combination.polys() {
392							self.claims_without_evals
393								.push((self.oracles.oracle(suboracle_id), eval_point.clone()));
394						}
395					}
396				};
397			}
398
399			MultilinearPolyVariant::ZeroPadded(padded) => {
400				let id = padded.id();
401				let inner = self.oracles.oracle(id);
402				let inner_eval_point = chain!(
403					&eval_point[..padded.start_index()],
404					&eval_point[padded.start_index() + padded.n_pad_vars()..],
405				)
406				.copied()
407				.collect::<Vec<_>>();
408				self.claims_without_evals
409					.push((inner, inner_eval_point.into()));
410			}
411			_ => return,
412		};
413	}
414
415	#[instrument(
416		skip_all,
417		name = "EvalcheckProverState::prove_multilinear",
418		level = "debug"
419	)]
420	fn prove_multilinear<Challenger_: Challenger>(
421		&mut self,
422		evalcheck_claim: EvalcheckMultilinearClaim<F>,
423		transcript: &mut ProverTranscript<Challenger_>,
424	) -> Result<(), Error> {
425		let EvalcheckMultilinearClaim { id, eval_point, .. } = &evalcheck_claim;
426
427		let claim_id = self.claim_to_index.get(*id, eval_point);
428
429		if let Some(claim_id) = claim_id {
430			serialize_evalcheck_proof(
431				&mut transcript.message(),
432				&EvalcheckHint::DuplicateClaim(*claim_id as u32),
433			);
434			return Ok(());
435		}
436		serialize_evalcheck_proof(&mut transcript.message(), &EvalcheckHint::NewClaim);
437
438		self.prove_multilinear_skip_duplicate_check(evalcheck_claim, transcript)
439	}
440
441	fn prove_multilinear_skip_duplicate_check<Challenger_: Challenger>(
442		&mut self,
443		evalcheck_claim: EvalcheckMultilinearClaim<F>,
444		transcript: &mut ProverTranscript<Challenger_>,
445	) -> Result<(), Error> {
446		let EvalcheckMultilinearClaim {
447			id,
448			eval_point,
449			eval,
450		} = evalcheck_claim;
451
452		self.claim_to_index
453			.insert(id, eval_point.clone(), self.round_claim_index);
454
455		self.round_claim_index += 1;
456
457		let multilinear = self.oracles.oracle(id);
458
459		match multilinear.variant {
460			MultilinearPolyVariant::Transparent { .. } => {}
461			MultilinearPolyVariant::Committed => {
462				self.committed_eval_claims.push(EvalcheckMultilinearClaim {
463					id: multilinear.id,
464					eval_point,
465					eval,
466				});
467			}
468			MultilinearPolyVariant::Repeating {
469				id: inner_id,
470				log_count,
471			} => {
472				let n_vars = eval_point.len() - log_count;
473				self.prove_multilinear(
474					EvalcheckMultilinearClaim {
475						id: inner_id,
476						eval_point: eval_point.slice(0..n_vars),
477						eval,
478					},
479					transcript,
480				)?;
481			}
482			MultilinearPolyVariant::Projected(projected) => {
483				let new_eval_point = {
484					let (lo, hi) = eval_point.split_at(projected.start_index());
485					chain!(lo, projected.values(), hi)
486						.copied()
487						.collect::<Vec<_>>()
488				};
489
490				self.prove_multilinear(
491					EvalcheckMultilinearClaim {
492						id: projected.id(),
493						eval_point: new_eval_point.into(),
494						eval,
495					},
496					transcript,
497				)?;
498			}
499			MultilinearPolyVariant::Shifted { .. } | MultilinearPolyVariant::Packed { .. } => {
500				let claim = EvalcheckMultilinearClaim {
501					id,
502					eval_point,
503					eval,
504				};
505
506				self.sumcheck_claims.push(SumcheckClaims::Projected(claim));
507			}
508			MultilinearPolyVariant::Composite { .. } => {
509				let claim = EvalcheckMultilinearClaim {
510					id,
511					eval_point,
512					eval,
513				};
514
515				self.sumcheck_claims.push(SumcheckClaims::Composite(claim));
516			}
517			MultilinearPolyVariant::LinearCombination(linear_combination) => {
518				for suboracle_id in linear_combination.polys() {
519					if let Some(claim_index) = self.claim_to_index.get(suboracle_id, &eval_point) {
520						serialize_evalcheck_proof(
521							&mut transcript.message(),
522							&EvalcheckHint::DuplicateClaim(*claim_index as u32),
523						);
524					} else {
525						serialize_evalcheck_proof(
526							&mut transcript.message(),
527							&EvalcheckHint::NewClaim,
528						);
529
530						let eval = *self
531							.evals_memoization
532							.get(suboracle_id, &eval_point)
533							.expect("precomputed above");
534
535						transcript.message().write_scalar(eval);
536
537						self.prove_multilinear_skip_duplicate_check(
538							EvalcheckMultilinearClaim {
539								id: suboracle_id,
540								eval_point: eval_point.clone(),
541								eval,
542							},
543							transcript,
544						)?;
545					}
546				}
547			}
548			MultilinearPolyVariant::ZeroPadded(padded) => {
549				let inner_eval_point = chain!(
550					&eval_point[..padded.start_index()],
551					&eval_point[padded.start_index() + padded.n_pad_vars()..],
552				)
553				.copied()
554				.collect::<Vec<_>>();
555
556				let inner_eval = *self
557					.evals_memoization
558					.get(padded.id(), &inner_eval_point)
559					.expect("precomputed above");
560
561				self.prove_multilinear(
562					EvalcheckMultilinearClaim {
563						id: padded.id(),
564						eval_point: inner_eval_point.into(),
565						eval: inner_eval,
566					},
567					transcript,
568				)?;
569			}
570		}
571		Ok(())
572	}
573
574	fn projected_bivariate_meta(
575		oracles: &mut MultilinearOracleSet<F>,
576		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
577	) -> Result<ProjectedBivariateMeta, Error> {
578		let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
579
580		match &oracles.oracle(*id).variant {
581			MultilinearPolyVariant::Shifted(shifted) => {
582				shifted_sumcheck_meta(oracles, shifted, eval_point)
583			}
584			MultilinearPolyVariant::Packed(packed) => {
585				packed_sumcheck_meta(oracles, packed, eval_point)
586			}
587			_ => unreachable!(),
588		}
589	}
590
591	fn process_bivariate_sumcheck(
592		&mut self,
593		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
594		meta: &ProjectedBivariateMeta,
595		projected: Option<MultilinearExtension<P>>,
596	) -> Result<(), Error> {
597		let EvalcheckMultilinearClaim {
598			id,
599			eval_point,
600			eval,
601		} = evalcheck_claim;
602
603		match self.oracles.oracle(*id).variant {
604			MultilinearPolyVariant::Shifted(shifted) => process_shifted_sumcheck(
605				&shifted,
606				meta,
607				eval_point,
608				*eval,
609				self.witness_index,
610				&mut self.new_sumchecks_constraints,
611				projected,
612			),
613
614			MultilinearPolyVariant::Packed(packed) => process_packed_sumcheck(
615				self.oracles,
616				&packed,
617				meta,
618				eval_point,
619				*eval,
620				self.witness_index,
621				&mut self.new_sumchecks_constraints,
622				projected,
623			),
624
625			_ => unreachable!(),
626		}
627	}
628
629	fn process_composite_mlecheck(
630		&mut self,
631		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
632		meta: CompositeMLECheckMeta,
633	) -> Result<(), Error> {
634		let EvalcheckMultilinearClaim {
635			id,
636			eval_point: _,
637			eval,
638		} = evalcheck_claim;
639
640		match self.oracles.oracle(*id).variant {
641			MultilinearPolyVariant::Composite(composite) => {
642				// witness for eq MLE has been previously filled in `fill_eq_witness_for_composites`
643				add_composite_sumcheck_to_constraints(
644					meta,
645					&mut self.new_sumchecks_constraints,
646					&composite,
647					*eval,
648				);
649				Ok(())
650			}
651			_ => unreachable!(),
652		}
653	}
654
655	/// Function that queries the witness data from the oracle id and evaluation point to find the evaluation of the multilinear
656	#[instrument(
657		skip_all,
658		name = "EvalcheckProverState::make_new_eval_claim",
659		level = "debug"
660	)]
661	fn make_new_eval_claim(
662		oracle_id: OracleId,
663		eval_point: EvalPoint<F>,
664		witness_index: &MultilinearExtensionIndex<P>,
665		memoized_queries: &MemoizedData<P, Backend>,
666	) -> Result<EvalcheckMultilinearClaim<F>, Error> {
667		let eval_query = memoized_queries
668			.full_query_readonly(&eval_point)
669			.ok_or(Error::MissingQuery)?;
670
671		let witness_poly = witness_index
672			.get_multilin_poly(oracle_id)
673			.map_err(Error::Witness)?;
674
675		let eval = witness_poly
676			.evaluate(eval_query.to_ref())
677			.map_err(Error::from)?;
678
679		Ok(EvalcheckMultilinearClaim {
680			id: oracle_id,
681			eval_point,
682			eval,
683		})
684	}
685}