binius_core/protocols/evalcheck/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::collections::HashSet;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_math::MultilinearExtension;
7use binius_maybe_rayon::prelude::*;
8use getset::{Getters, MutGetters};
9use itertools::{chain, izip};
10use tracing::instrument;
11
12use super::{
13	EvalPoint, EvalPointOracleIdMap,
14	error::Error,
15	evalcheck::{EvalcheckHint, EvalcheckMultilinearClaim},
16	serialize_evalcheck_proof,
17	subclaims::{
18		MemoizedData, OracleIdPartialEval, ProjectedBivariateMeta,
19		add_composite_sumcheck_to_constraints, collect_projected_mles,
20	},
21};
22use crate::{
23	fiat_shamir::Challenger,
24	oracle::{
25		ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, MultilinearPolyVariant,
26		OracleId, SizedConstraintSet,
27	},
28	polynomial::MultivariatePoly,
29	protocols::evalcheck::{
30		logging::MLEFoldHighDimensionsData,
31		subclaims::{
32			packed_sumcheck_meta, process_packed_sumcheck, process_shifted_sumcheck,
33			shifted_sumcheck_meta,
34		},
35	},
36	transcript::ProverTranscript,
37	transparent::select_row::SelectRow,
38	witness::MultilinearExtensionIndex,
39};
40
/// A mutable prover state.
///
/// Can be persisted across [`EvalcheckProver::prove`] invocations. Accumulates the new
/// bivariate sumcheck and MLE-check constraint builders produced while proving, and holds
/// mutable references to the trace (to which new oracles & multilinears may be added during
/// proving).
#[derive(Getters, MutGetters)]
pub struct EvalcheckProver<'a, 'b, F, P>
where
	P: PackedField<Scalar = F>,
	F: TowerField,
{
	/// Mutable reference to the oracle set which is modified to create new claims arising from
	/// sumchecks
	pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
	/// Mutable reference to the witness index which is populated by the prover for new claims
	/// arising from sumchecks
	pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, P>,

	/// The committed evaluation claims arising in this round
	#[getset(get = "pub", get_mut = "pub")]
	committed_eval_claims: Vec<EvalcheckMultilinearClaim<F>>,

	// Claims whose evaluation still has to be computed from the witness; drained by `prove`,
	// which evaluates each of them via `make_new_eval_claim`.
	claims_to_be_evaluated: HashSet<(OracleId, EvalPoint<F>)>,

	// Claims that come without an evaluation; drained by `prove`, which derives their values
	// from the memoized evaluations of their sub-oracles via `collect_evals`.
	claims_without_evals: HashSet<(OracleId, EvalPoint<F>)>,

	// The list of claims that reduce to a bivariate sumcheck in a round (shifted and packed
	// oracles).
	projected_bivariate_claims: Vec<EvalcheckMultilinearClaim<F>>,

	// The new bivariate sumcheck constraints arising in this round
	new_bivariate_sumchecks_constraints: Vec<ConstraintSetBuilder<F>>,
	// The new MLE-check constraints arising in this round, each paired with the evaluation
	// point it was created for
	new_mlechecks_constraints: Vec<(EvalPoint<F>, ConstraintSetBuilder<F>)>,
	// Tensor expansion of evaluation points and partial evaluations of multilinears
	pub memoized_data: MemoizedData<'b, P>,

	// The unique index of a claim in this round.
	claim_to_index: EvalPointOracleIdMap<usize, F>,
	// Claims that have been visited in this round, used to deduplicate claims when collecting
	// subclaims.
	visited_claims: EvalPointOracleIdMap<(), F>,
	// Memoization of evaluations of claims the prover sees in this round
	evals_memoization: EvalPointOracleIdMap<F, F>,
	// The index that will be assigned to the next new claim in this round
	round_claim_index: usize,

	// Partial evaluations after `evaluate_partial_high` in suffixes, keyed by oracle id and
	// suffix
	partial_evals: EvalPointOracleIdMap<MultilinearExtension<P>, F>,

	// Common suffixes of evaluation points; not cleared by `prove`, so they carry over between
	// invocations
	suffixes: HashSet<EvalPoint<F>>,
}
95
96impl<'a, 'b, F, P> EvalcheckProver<'a, 'b, F, P>
97where
98	P: PackedField<Scalar = F>,
99	F: TowerField,
100{
101	/// Create a new prover state by tying together the mutable references to the oracle set and
102	/// witness index (they need to be mutable because `new_sumcheck` reduction may add new oracles
103	/// & multilinears) as well as committed eval claims accumulator.
104	pub fn new(
105		oracles: &'a mut MultilinearOracleSet<F>,
106		witness_index: &'a mut MultilinearExtensionIndex<'b, P>,
107	) -> Self {
108		Self {
109			oracles,
110			witness_index,
111			committed_eval_claims: Vec::new(),
112			new_bivariate_sumchecks_constraints: Vec::new(),
113			new_mlechecks_constraints: Vec::new(),
114			claims_without_evals: HashSet::new(),
115			claims_to_be_evaluated: HashSet::new(),
116			projected_bivariate_claims: Vec::new(),
117			memoized_data: MemoizedData::new(),
118
119			claim_to_index: EvalPointOracleIdMap::new(),
120			visited_claims: EvalPointOracleIdMap::new(),
121			evals_memoization: EvalPointOracleIdMap::new(),
122			round_claim_index: 0,
123
124			partial_evals: EvalPointOracleIdMap::new(),
125			suffixes: HashSet::new(),
126		}
127	}
128
129	/// A helper method to move out bivariate sumcheck constraints
130	pub fn take_new_bivariate_sumchecks_constraints(
131		&mut self,
132	) -> Result<Vec<SizedConstraintSet<F>>, OracleError> {
133		self.new_bivariate_sumchecks_constraints
134			.iter_mut()
135			.map(|builder| std::mem::take(builder).build_one(self.oracles))
136			.filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
137			.collect()
138	}
139
140	/// A helper method to move out mlechecks constraints
141	pub fn take_new_mlechecks_constraints(
142		&mut self,
143	) -> Result<Vec<ConstraintSetEqIndPoint<F>>, OracleError> {
144		std::mem::take(&mut self.new_mlechecks_constraints)
145			.into_iter()
146			.map(|(ep, builder)| {
147				builder
148					.build_one(self.oracles)
149					.map(|constraint| ConstraintSetEqIndPoint {
150						eq_ind_challenges: ep.clone(),
151						constraint_set: constraint,
152					})
153			})
154			.collect()
155	}
156
	/// Prove an evalcheck claim.
	///
	/// Given a prover state containing [`MultilinearOracleSet`] indexing into given
	/// [`MultilinearExtensionIndex`], we prove an [`EvalcheckMultilinearClaim`] (stating that given
	/// composite `poly` equals `eval` at `eval_point`) by recursively processing each of the
	/// multilinears. This way the evalcheck claim gets transformed into an [`EvalcheckHint`]
	/// and a new set of claims on:
	///  * Committed polynomial evaluations
	///  * New sumcheck constraints that need to be proven in subsequent rounds (those get appended
	///    to `new_sumchecks`)
	///
	/// All of the `new_sumchecks` constraints follow the same pattern:
	///  * they are always a product of two multilins (composition polynomial is `BivariateProduct`)
	///  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
	///  * other multilin is a projection of one of the evalcheck claim multilins to its first
	///    variables
	///
	/// The method proceeds in three steps:
	///  1. walk all claims to find the subclaims whose evaluations are needed, memoize the
	///     queries for their evaluation points and compute those evaluations;
	///  2. write the proof for every claim into `transcript`, deduplicating repeated claims;
	///  3. process the claims that reduce to bivariate sumchecks (shifted and packed oracles).
	///
	/// # Errors
	///
	/// Propagates any error returned by query memoization, witness evaluation, the recursive
	/// proving of subclaims, or the processing of the projected bivariate claims.
	pub fn prove<Challenger_: Challenger>(
		&mut self,
		evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
		transcript: &mut ProverTranscript<Challenger_>,
	) -> Result<(), Error> {
		// Reset the prover state for a new round.
		self.round_claim_index = 0;
		self.visited_claims.clear();
		self.claim_to_index.clear();
		self.evals_memoization.clear();

		let mle_fold_full_span = tracing::debug_span!(
			"[task] MLE Fold Full",
			phase = "evalcheck",
			perfetto_category = "task.main"
		)
		.entered();

		// Step 1: Precompute claims that require additional evaluations.
		for claim in &evalcheck_claims {
			self.collect_subclaims_for_memoization(
				claim.id,
				claim.eval_point.clone(),
				Some(claim.eval),
			);
		}

		// Distinct evaluation points of all claims that still need an evaluation.
		let mut eval_points = self
			.claims_to_be_evaluated
			.iter()
			.map(|(_, eval_point)| eval_point.clone())
			.collect::<HashSet<_>>();

		// Local working copy of the suffixes known so far.
		let mut suffixes = self.suffixes.iter().cloned().collect::<HashSet<_>>();

		let mut prefixes = HashSet::new();

		// Points that end in a known suffix are dropped from `eval_points`; the matching prefix
		// is recorded instead.
		let mut to_remove = Vec::new();
		for suffix in &suffixes {
			for eval_point in &eval_points {
				if let Some(prefix) = eval_point.try_get_prefix(suffix) {
					prefixes.insert(prefix);
					to_remove.push(eval_point.clone());
				}
			}
		}
		for ep in to_remove {
			eval_points.remove(&ep);
		}

		// We don't split points whose tensor product, when halved, would be smaller than
		// PackedField::WIDTH.
		let (long, short): (Vec<_>, Vec<_>) = eval_points
			.into_iter()
			.partition(|ep| ep.len().saturating_sub(1) > P::LOG_WIDTH);

		// Long points are split in the middle: the high half becomes a new suffix (also stored
		// in `self.suffixes`) and the low half a prefix.
		for eval_point in long {
			let ep = eval_point.to_vec();
			let mid = ep.len() / 2;
			let (low, high) = ep.split_at(mid);
			let suffix = EvalPoint::from(high);
			let prefix = EvalPoint::from(low);
			suffixes.insert(suffix.clone());
			self.suffixes.insert(suffix);
			prefixes.insert(prefix);
		}

		// Memoize a query for every short point, suffix and prefix.
		let eval_points = chain!(&short, &suffixes, &prefixes)
			.map(|p| p.as_ref())
			.collect::<Vec<_>>();

		self.memoized_data.memoize_query_par(eval_points)?;

		// Drain the pending claims and evaluate each of them (via `into_par_iter`), collecting
		// any partial evaluation that had to be freshly computed along the way.
		let subclaims_partial_evals = std::mem::take(&mut self.claims_to_be_evaluated)
			.into_par_iter()
			.map(|(id, eval_point)| {
				Self::make_new_eval_claim(
					id,
					eval_point,
					self.witness_index,
					&self.memoized_data,
					&self.partial_evals,
					&self.suffixes,
				)
			})
			.collect::<Result<Vec<_>, Error>>()?;

		let (subclaims, partial_evals): (Vec<_>, Vec<_>) =
			subclaims_partial_evals.into_iter().unzip();

		// Store the newly computed partial evaluations for reuse.
		for OracleIdPartialEval {
			id,
			suffix,
			partial_eval,
		} in partial_evals.into_iter().flatten()
		{
			self.partial_evals.insert(id, suffix, partial_eval);
		}

		for subclaim in &subclaims {
			self.evals_memoization
				.insert(subclaim.id, subclaim.eval_point.clone(), subclaim.eval);
		}

		// Drain the claims that came without an evaluation and derive their values from the
		// memoized evaluations, visiting them in order of oracle id.
		let mut claims_without_evals = std::mem::take(&mut self.claims_without_evals)
			.into_iter()
			.collect::<Vec<_>>();

		claims_without_evals.sort_unstable_by_key(|(id, _)| *id);

		for (id, eval_point) in claims_without_evals {
			self.collect_evals(id, &eval_point);
		}

		drop(mle_fold_full_span);

		// Step 2: Prove multilinears: For each claim, we prove the claim by recursively proving the
		// subclaims by stepping through subclaims in a DFS manner and deduplicating claims.
		for claim in evalcheck_claims {
			self.prove_multilinear(claim, transcript)?;
		}

		// Step 3: Process projected_bivariate_claims
		let dimensions_data = MLEFoldHighDimensionsData::new(self.projected_bivariate_claims.len());
		let evalcheck_mle_fold_high_span = tracing::debug_span!(
			"[task] (Evalcheck) MLE Fold High",
			phase = "evalcheck",
			perfetto_category = "task.main",
			?dimensions_data,
		)
		.entered();

		let projected_bivariate_metas = self
			.projected_bivariate_claims
			.iter()
			.map(|claim| Self::projected_bivariate_meta(self.oracles, claim))
			.collect::<Result<Vec<_>, Error>>()?;

		// Move the claims out so that `self` can be borrowed mutably below.
		let projected_bivariate_claims = std::mem::take(&mut self.projected_bivariate_claims);

		collect_projected_mles(
			&projected_bivariate_metas,
			&mut self.memoized_data,
			&projected_bivariate_claims,
			self.oracles,
			self.witness_index,
			&mut self.partial_evals,
		)?;

		drop(evalcheck_mle_fold_high_span);

		// memoize eq_ind_partial_evals for HighToLow case
		// (the query is taken over each MLE-check point with its last coordinate dropped)
		self.memoized_data
			.memoize_query_par(self.new_mlechecks_constraints.iter().map(|(ep, _)| {
				let ep = ep.as_ref();
				&ep[0..ep.len().saturating_sub(1)]
			}))?;

		for (claim, meta) in izip!(&projected_bivariate_claims, &projected_bivariate_metas) {
			self.process_bivariate_sumcheck(claim, meta)?;
		}

		self.memoized_data.memoize_partial_evals(
			&projected_bivariate_metas,
			&projected_bivariate_claims,
			self.oracles,
			self.witness_index,
		);

		Ok(())
	}
344
345	#[instrument(
346		skip_all,
347		name = "EvalcheckProverState::collect_subclaims_for_memoization",
348		level = "debug"
349	)]
350	fn collect_subclaims_for_memoization(
351		&mut self,
352		multilinear_id: OracleId,
353		eval_point: EvalPoint<F>,
354		eval: Option<F>,
355	) {
356		if self.visited_claims.contains(multilinear_id, &eval_point) {
357			return;
358		}
359
360		self.visited_claims
361			.insert(multilinear_id, eval_point.clone(), ());
362
363		if let Some(eval) = eval {
364			self.evals_memoization
365				.insert(multilinear_id, eval_point.clone(), eval);
366		}
367
368		let multilinear = &self.oracles[multilinear_id];
369
370		match multilinear.variant {
371			MultilinearPolyVariant::Shifted(_) => {
372				if !self.evals_memoization.contains(multilinear_id, &eval_point) {
373					self.collect_suffixes(multilinear_id, eval_point.clone());
374
375					self.claims_to_be_evaluated
376						.insert((multilinear_id, eval_point));
377				}
378			}
379
380			MultilinearPolyVariant::Repeating { id, .. } => {
381				let n_vars = self.oracles.n_vars(id);
382				let inner_eval_point = eval_point.slice(0..n_vars);
383				self.collect_subclaims_for_memoization(id, inner_eval_point, eval);
384			}
385
386			MultilinearPolyVariant::Projected(ref projected) => {
387				let (id, values) = (projected.id(), projected.values());
388
389				let new_eval_point = {
390					let idx = projected.start_index();
391					let mut new_eval_point = eval_point[0..idx].to_vec();
392					new_eval_point.extend(values.clone());
393					new_eval_point.extend(eval_point[idx..].to_vec());
394					new_eval_point
395				};
396
397				self.collect_subclaims_for_memoization(id, new_eval_point.into(), eval);
398			}
399
400			MultilinearPolyVariant::LinearCombination(ref linear_combination) => {
401				let n_polys = linear_combination.n_polys();
402				let next =
403					izip!(linear_combination.polys(), linear_combination.coefficients()).next();
404				match (next, eval) {
405					(Some((suboracle_id, coeff)), Some(eval))
406						if n_polys == 1 && !coeff.is_zero() =>
407					{
408						let eval = if let Some(eval) =
409							self.evals_memoization.get(suboracle_id, &eval_point)
410						{
411							*eval
412						} else {
413							let eval = (eval - linear_combination.offset())
414								* coeff.invert().expect("not zero");
415							self.evals_memoization
416								.insert(suboracle_id, eval_point.clone(), eval);
417							eval
418						};
419
420						self.collect_subclaims_for_memoization(
421							suboracle_id,
422							eval_point,
423							Some(eval),
424						);
425					}
426					_ => {
427						// We have to collect here to make the borrowck happy. This does not seem
428						// to be a big problem, but in case it turns out to be problematic, consider
429						// using smallvec.
430						let lincom_suboracles = linear_combination.polys().collect::<Vec<_>>();
431						for suboracle_id in lincom_suboracles {
432							self.claims_without_evals
433								.insert((suboracle_id, eval_point.clone()));
434
435							self.collect_subclaims_for_memoization(
436								suboracle_id,
437								eval_point.clone(),
438								None,
439							);
440						}
441					}
442				};
443			}
444
445			MultilinearPolyVariant::ZeroPadded(ref padded) => {
446				let id = padded.id();
447				let inner_eval_point = chain!(
448					&eval_point[..padded.start_index()],
449					&eval_point[padded.start_index() + padded.n_pad_vars()..],
450				)
451				.copied()
452				.collect::<Vec<_>>();
453				let inner_eval_point = EvalPoint::from(inner_eval_point);
454
455				self.claims_without_evals
456					.insert((id, inner_eval_point.clone()));
457
458				self.collect_subclaims_for_memoization(id, inner_eval_point, None);
459			}
460			_ => {
461				if !self.evals_memoization.contains(multilinear_id, &eval_point) {
462					self.claims_to_be_evaluated
463						.insert((multilinear_id, eval_point));
464				}
465			}
466		};
467	}
468
469	#[instrument(
470		skip_all,
471		name = "EvalcheckProverState::prove_multilinear",
472		level = "debug"
473	)]
474	fn prove_multilinear<Challenger_: Challenger>(
475		&mut self,
476		evalcheck_claim: EvalcheckMultilinearClaim<F>,
477		transcript: &mut ProverTranscript<Challenger_>,
478	) -> Result<(), Error> {
479		let EvalcheckMultilinearClaim { id, eval_point, .. } = &evalcheck_claim;
480
481		let claim_id = self.claim_to_index.get(*id, eval_point);
482
483		if let Some(claim_id) = claim_id {
484			serialize_evalcheck_proof(
485				&mut transcript.message(),
486				&EvalcheckHint::DuplicateClaim(*claim_id as u32),
487			);
488			return Ok(());
489		}
490		serialize_evalcheck_proof(&mut transcript.message(), &EvalcheckHint::NewClaim);
491
492		self.prove_multilinear_skip_duplicate_check(evalcheck_claim, transcript)
493	}
494
495	fn prove_multilinear_skip_duplicate_check<Challenger_: Challenger>(
496		&mut self,
497		evalcheck_claim: EvalcheckMultilinearClaim<F>,
498		transcript: &mut ProverTranscript<Challenger_>,
499	) -> Result<(), Error> {
500		let EvalcheckMultilinearClaim {
501			id,
502			eval_point,
503			eval,
504		} = evalcheck_claim;
505
506		self.claim_to_index
507			.insert(id, eval_point.clone(), self.round_claim_index);
508
509		self.round_claim_index += 1;
510
511		let multilinear = &self.oracles[id];
512
513		match multilinear.variant {
514			MultilinearPolyVariant::Transparent { .. } | MultilinearPolyVariant::Structured(_) => {}
515			MultilinearPolyVariant::Committed => {
516				self.committed_eval_claims.push(EvalcheckMultilinearClaim {
517					id: multilinear.id,
518					eval_point,
519					eval,
520				});
521			}
522			MultilinearPolyVariant::Repeating {
523				id: inner_id,
524				log_count,
525			} => {
526				let n_vars = eval_point.len() - log_count;
527				self.prove_multilinear(
528					EvalcheckMultilinearClaim {
529						id: inner_id,
530						eval_point: eval_point.slice(0..n_vars),
531						eval,
532					},
533					transcript,
534				)?;
535			}
536			MultilinearPolyVariant::Projected(ref projected) => {
537				let new_eval_point = {
538					let (lo, hi) = eval_point.split_at(projected.start_index());
539					chain!(lo, projected.values(), hi)
540						.copied()
541						.collect::<Vec<_>>()
542				};
543
544				self.prove_multilinear(
545					EvalcheckMultilinearClaim {
546						id: projected.id(),
547						eval_point: new_eval_point.into(),
548						eval,
549					},
550					transcript,
551				)?;
552			}
553			MultilinearPolyVariant::Shifted { .. } | MultilinearPolyVariant::Packed { .. } => {
554				let claim = EvalcheckMultilinearClaim {
555					id,
556					eval_point,
557					eval,
558				};
559				self.projected_bivariate_claims.push(claim);
560			}
561			MultilinearPolyVariant::Composite(ref composite) => {
562				let position = self
563					.new_mlechecks_constraints
564					.iter()
565					.position(|(ep, _)| *ep == eval_point)
566					.unwrap_or(self.new_mlechecks_constraints.len());
567
568				transcript.message().write(&(position as u32));
569
570				add_composite_sumcheck_to_constraints(
571					position,
572					&eval_point,
573					&mut self.new_mlechecks_constraints,
574					composite,
575					eval,
576				);
577			}
578			MultilinearPolyVariant::LinearCombination(ref linear_combination) => {
579				let lincom_suboracles = linear_combination.polys().collect::<Vec<_>>();
580				for suboracle_id in lincom_suboracles {
581					if let Some(claim_index) = self.claim_to_index.get(suboracle_id, &eval_point) {
582						serialize_evalcheck_proof(
583							&mut transcript.message(),
584							&EvalcheckHint::DuplicateClaim(*claim_index as u32),
585						);
586					} else {
587						serialize_evalcheck_proof(
588							&mut transcript.message(),
589							&EvalcheckHint::NewClaim,
590						);
591
592						let eval = *self
593							.evals_memoization
594							.get(suboracle_id, &eval_point)
595							.expect("precomputed above");
596
597						transcript.message().write_scalar(eval);
598
599						self.prove_multilinear_skip_duplicate_check(
600							EvalcheckMultilinearClaim {
601								id: suboracle_id,
602								eval_point: eval_point.clone(),
603								eval,
604							},
605							transcript,
606						)?;
607					}
608				}
609			}
610			MultilinearPolyVariant::ZeroPadded(ref padded) => {
611				let inner_eval_point = chain!(
612					&eval_point[..padded.start_index()],
613					&eval_point[padded.start_index() + padded.n_pad_vars()..],
614				)
615				.copied()
616				.collect::<Vec<_>>();
617
618				let zs =
619					&eval_point[padded.start_index()..padded.start_index() + padded.n_pad_vars()];
620				let select_row = SelectRow::new(zs.len(), padded.nonzero_index())?;
621				let select_row_term = select_row
622					.evaluate(zs)
623					.expect("select_row is constructor with zs.len() variables");
624
625				if eval.is_zero() && select_row_term.is_zero() {
626					return Ok(());
627				}
628
629				let inner_eval = *self
630					.evals_memoization
631					.get(padded.id(), &inner_eval_point)
632					.expect("precomputed above");
633
634				self.prove_multilinear(
635					EvalcheckMultilinearClaim {
636						id: padded.id(),
637						eval_point: inner_eval_point.into(),
638						eval: inner_eval,
639					},
640					transcript,
641				)?;
642			}
643		}
644		Ok(())
645	}
646
647	pub fn collect_evals(&mut self, oracle_id: OracleId, eval_point: &EvalPoint<F>) -> F {
648		if let Some(eval) = self.evals_memoization.get(oracle_id, eval_point) {
649			return *eval;
650		}
651
652		let eval = match &self.oracles[oracle_id].variant {
653			MultilinearPolyVariant::Repeating { id, log_count } => {
654				let n_vars = eval_point.len() - log_count;
655				self.collect_evals(*id, &eval_point.slice(0..n_vars))
656			}
657			MultilinearPolyVariant::Projected(projected) => {
658				let new_eval_point = {
659					let (lo, hi) = eval_point.split_at(projected.start_index());
660					chain!(lo, projected.values(), hi)
661						.copied()
662						.collect::<Vec<_>>()
663				};
664				self.collect_evals(projected.id(), &new_eval_point.into())
665			}
666			MultilinearPolyVariant::LinearCombination(linear_combination) => {
667				let ids = linear_combination.polys().collect::<Vec<_>>();
668
669				let coeffs = linear_combination.coefficients().collect::<Vec<_>>();
670				let offset = linear_combination.offset();
671
672				let mut evals = Vec::with_capacity(ids.len());
673
674				for id in &ids {
675					evals.push(self.collect_evals(*id, eval_point));
676				}
677
678				izip!(evals, coeffs).fold(offset, |acc, (eval, coeff)| {
679					if coeff.is_zero() {
680						return acc;
681					}
682
683					acc + eval * coeff
684				})
685			}
686			MultilinearPolyVariant::ZeroPadded(padded) => {
687				let subclaim_eval_point = chain!(
688					&eval_point[..padded.start_index()],
689					&eval_point[padded.start_index() + padded.n_pad_vars()..],
690				)
691				.copied()
692				.collect::<Vec<_>>();
693
694				let zs =
695					&eval_point[padded.start_index()..padded.start_index() + padded.n_pad_vars()];
696				let select_row = SelectRow::new(zs.len(), padded.nonzero_index())
697					.expect("SelectRow receives the correct parameters");
698				let select_row_term = select_row
699					.evaluate(zs)
700					.expect("select_row is constructor with zs.len() variables");
701
702				let eval = self.collect_evals(padded.id(), &subclaim_eval_point.into());
703
704				eval * select_row_term
705			}
706			_ => unreachable!(),
707		};
708
709		self.evals_memoization
710			.insert(oracle_id, eval_point.clone(), eval);
711		eval
712	}
713
714	fn projected_bivariate_meta(
715		oracles: &mut MultilinearOracleSet<F>,
716		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
717	) -> Result<ProjectedBivariateMeta, Error> {
718		let EvalcheckMultilinearClaim { id, eval_point, .. } = evalcheck_claim;
719
720		match &oracles[*id].variant {
721			MultilinearPolyVariant::Shifted(shifted) => {
722				let shifted = shifted.clone();
723				shifted_sumcheck_meta(oracles, &shifted, eval_point)
724			}
725			MultilinearPolyVariant::Packed(packed) => {
726				let packed = packed.clone();
727				packed_sumcheck_meta(oracles, &packed, eval_point)
728			}
729			_ => unreachable!(),
730		}
731	}
732
733	/// Collect common suffixes to reuse `partial_evals` with different prefixes.
734	pub fn collect_suffixes(&mut self, oracle_id: OracleId, suffix: EvalPoint<F>) {
735		let multilinear = &self.oracles[oracle_id];
736
737		match &multilinear.variant {
738			MultilinearPolyVariant::Projected(projected) => {
739				let new_eval_point_len = multilinear.n_vars + projected.values().len();
740
741				let range = new_eval_point_len - projected.start_index().max(suffix.len())
742					..new_eval_point_len;
743
744				self.collect_suffixes(projected.id(), suffix.slice(range));
745			}
746			MultilinearPolyVariant::Shifted(shifted) => {
747				let suffix_len = multilinear.n_vars - shifted.block_size();
748				if suffix_len <= suffix.len() {
749					self.collect_suffixes(
750						shifted.id(),
751						suffix.slice(suffix.len() - suffix_len..suffix.len()),
752					);
753				}
754			}
755			MultilinearPolyVariant::LinearCombination(linear_combination) => {
756				let ids = linear_combination.polys().collect::<Vec<_>>();
757
758				for id in ids {
759					self.collect_suffixes(id, suffix.clone());
760				}
761			}
762			_ => {
763				self.suffixes.insert(suffix);
764			}
765		}
766	}
767
768	fn process_bivariate_sumcheck(
769		&mut self,
770		evalcheck_claim: &EvalcheckMultilinearClaim<F>,
771		meta: &ProjectedBivariateMeta,
772	) -> Result<(), Error> {
773		let EvalcheckMultilinearClaim {
774			id,
775			eval_point,
776			eval,
777		} = evalcheck_claim;
778
779		match self.oracles[*id].variant {
780			MultilinearPolyVariant::Shifted(ref shifted) => process_shifted_sumcheck(
781				shifted,
782				meta,
783				eval_point,
784				*eval,
785				self.witness_index,
786				&mut self.new_bivariate_sumchecks_constraints,
787				&self.partial_evals,
788			),
789
790			MultilinearPolyVariant::Packed(ref packed) => process_packed_sumcheck(
791				self.oracles,
792				packed,
793				meta,
794				eval_point,
795				*eval,
796				self.witness_index,
797				&mut self.new_bivariate_sumchecks_constraints,
798				&self.partial_evals,
799			),
800
801			_ => unreachable!(),
802		}
803	}
804
805	/// Function that queries the witness data from the oracle id and evaluation point to find
806	/// the evaluation of the multilinear
807	#[instrument(
808		skip_all,
809		name = "EvalcheckProverState::make_new_eval_claim",
810		level = "debug"
811	)]
812	fn make_new_eval_claim(
813		oracle_id: OracleId,
814		eval_point: EvalPoint<F>,
815		witness_index: &MultilinearExtensionIndex<P>,
816		memoized_queries: &MemoizedData<P>,
817		partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
818		suffixes: &HashSet<EvalPoint<F>>,
819	) -> Result<(EvalcheckMultilinearClaim<F>, Option<OracleIdPartialEval<P>>), Error> {
820		let witness_poly = witness_index
821			.get_multilin_poly(oracle_id)
822			.map_err(Error::Witness)?;
823
824		let mut eval = None;
825		let mut new_partial_eval = None;
826
827		for suffix in suffixes {
828			if let Some(prefix) = eval_point.try_get_prefix(suffix) {
829				let partial_eval = match partial_evals.get(oracle_id, suffix) {
830					Some(partial_eval) => partial_eval,
831					None => {
832						let suffix_query = memoized_queries
833							.full_query_readonly(suffix)
834							.ok_or(Error::MissingQuery)?;
835
836						let partial_eval = witness_poly
837							.evaluate_partial_high(suffix_query.to_ref())
838							.map_err(Error::from)?;
839
840						new_partial_eval = Some(OracleIdPartialEval {
841							id: oracle_id,
842							suffix: suffix.clone(),
843							partial_eval,
844						});
845						&new_partial_eval
846							.as_ref()
847							.expect("new_partial_eval added above")
848							.partial_eval
849					}
850				};
851
852				let prefix_query = memoized_queries
853					.full_query_readonly(&prefix)
854					.ok_or(Error::MissingQuery)?;
855
856				eval = Some(partial_eval.evaluate(prefix_query).map_err(Error::from)?);
857				break;
858			}
859		}
860
861		let eval = match eval {
862			Some(value) => value,
863			None => {
864				let query = memoized_queries
865					.full_query_readonly(&eval_point)
866					.ok_or(Error::MissingQuery)?;
867
868				witness_poly.evaluate(query.to_ref()).map_err(Error::from)?
869			}
870		};
871
872		let claim = EvalcheckMultilinearClaim {
873			id: oracle_id,
874			eval_point,
875			eval,
876		};
877
878		Ok((claim, new_partial_eval))
879	}
880}
881
/// A built MLE-check constraint set together with the evaluation point it was created for.
///
/// Produced by [`EvalcheckProver::take_new_mlechecks_constraints`].
pub struct ConstraintSetEqIndPoint<F: Field> {
	/// The evaluation point under which the constraint builder was accumulated.
	pub eq_ind_challenges: EvalPoint<F>,
	/// The constraint set built from that builder.
	pub constraint_set: SizedConstraintSet<F>,
}