binius_core/protocols/evalcheck/
subclaims.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! This module contains helpers to create bivariate sumcheck instances originating from:
4//!  * products with shift indicators (shifted virtual polynomials)
5//!  * products with tower basis (packed virtual polynomials)
6//!
7//! All of them have common traits:
8//!  * they are always a product of two multilins (composition polynomial is `BivariateProduct`)
9//!  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
10//!  * other multilin is a projection of one of the evalcheck claim multilins to its first variables
11
12use std::collections::HashSet;
13
14use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
15use binius_hal::ComputationBackend;
16use binius_math::{
17	ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
18	MultilinearExtension, MultilinearQuery,
19};
20use binius_maybe_rayon::prelude::*;
21use binius_utils::bail;
22use bytemuck::zeroed_vec;
23use itertools::izip;
24use tracing::instrument;
25
26use super::{EvalPoint, EvalPointOracleIdMap, error::Error, evalcheck::EvalcheckMultilinearClaim};
27use crate::{
28	fiat_shamir::Challenger,
29	oracle::{
30		CompositeMLE, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId,
31		Packed, Shifted, SizedConstraintSet,
32	},
33	polynomial::MultivariatePoly,
34	protocols::sumcheck::{
35		self, Error as SumcheckError,
36		prove::{
37			front_loaded,
38			oracles::{
39				MLECheckProverWithMeta, SumcheckProversWithMetas,
40				constraint_sets_mlecheck_prover_meta, constraint_sets_sumcheck_provers_metas,
41			},
42		},
43	},
44	transcript::ProverTranscript,
45	transparent::{shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis},
46	witness::{MultilinearExtensionIndex, MultilinearWitness},
47};
48
49/// Create oracles for the bivariate product of an inner oracle with shift indicator.
50///
51/// Projects to first `block_size()` vars.
52pub fn shifted_sumcheck_meta<F: TowerField>(
53	oracles: &mut MultilinearOracleSet<F>,
54	shifted: &Shifted,
55	eval_point: &[F],
56) -> Result<ProjectedBivariateMeta, Error> {
57	projected_bivariate_meta(
58		oracles,
59		shifted.id(),
60		shifted.block_size(),
61		eval_point,
62		|projected_eval_point| {
63			Ok(ShiftIndPartialEval::new(
64				shifted.block_size(),
65				shifted.shift_offset(),
66				shifted.shift_variant(),
67				projected_eval_point.to_vec(),
68			)?)
69		},
70	)
71}
72
73/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck
74/// constraint to the [`ConstraintSetBuilder`]
75#[allow(clippy::too_many_arguments)]
76pub fn process_shifted_sumcheck<F, P>(
77	shifted: &Shifted,
78	meta: &ProjectedBivariateMeta,
79	eval_point: &[F],
80	eval: F,
81	witness_index: &mut MultilinearExtensionIndex<P>,
82	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
83	partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
84) -> Result<(), Error>
85where
86	P: PackedField<Scalar = F>,
87	F: TowerField,
88{
89	process_projected_bivariate_witness(
90		witness_index,
91		meta,
92		eval_point,
93		|projected_eval_point| {
94			let shift_ind = ShiftIndPartialEval::new(
95				projected_eval_point.len(),
96				shifted.shift_offset(),
97				shifted.shift_variant(),
98				projected_eval_point.to_vec(),
99			)?;
100
101			let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
102			Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
103		},
104		partial_evals,
105	)?;
106	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
107
108	Ok(())
109}
110
111/// Create oracles for the bivariate product of an inner oracle with the tower basis.
112///
113/// Projects to first `log_degree()` vars.
114/// Returns metadata object with oracle identifiers.
115pub fn packed_sumcheck_meta<F: TowerField>(
116	oracles: &mut MultilinearOracleSet<F>,
117	packed: &Packed,
118	eval_point: &[F],
119) -> Result<ProjectedBivariateMeta, Error> {
120	let n_vars = oracles.n_vars(packed.id());
121	let log_degree = packed.log_degree();
122	let binary_tower_level = oracles[packed.id()].binary_tower_level();
123
124	if log_degree > n_vars {
125		bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
126	}
127
128	// NB. projected_n_vars = 0 because eval_point length is log_degree less than inner n_vars
129	projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
130		Ok(TowerBasis::new(log_degree, binary_tower_level)?)
131	})
132}
133
134pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
135	meta: &ProjectedBivariateMeta,
136	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
137	n_vars: usize,
138	eval: F,
139) {
140	if n_vars >= constraint_builders.len() {
141		constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
142	}
143	let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
144	constraint_builders[n_vars].add_sumcheck(meta.oracle_ids(), bivariate_product.into(), eval);
145}
146
147pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
148	position: usize,
149	eval_point: &EvalPoint<F>,
150	constraint_builders: &mut Vec<(EvalPoint<F>, ConstraintSetBuilder<F>)>,
151	comp: &CompositeMLE<F>,
152	eval: F,
153) {
154	let oracle_ids = comp.inner().clone();
155
156	if let Some((_, constraint_builder)) = constraint_builders.get_mut(position) {
157		constraint_builder.add_sumcheck(
158			oracle_ids,
159			<_ as CompositionPoly<F>>::expression(comp.c()),
160			eval,
161		);
162	} else {
163		let mut new_builder = ConstraintSetBuilder::new();
164		new_builder.add_sumcheck(oracle_ids, <_ as CompositionPoly<F>>::expression(comp.c()), eval);
165		constraint_builders.push((eval_point.clone(), new_builder));
166	}
167}
168
169/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck
170/// constraint to the [`ConstraintSetBuilder`]
171#[allow(clippy::too_many_arguments)]
172pub fn process_packed_sumcheck<F, P>(
173	oracles: &MultilinearOracleSet<F>,
174	packed: &Packed,
175	meta: &ProjectedBivariateMeta,
176	eval_point: &[F],
177	eval: F,
178	witness_index: &mut MultilinearExtensionIndex<P>,
179	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
180	partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
181) -> Result<(), Error>
182where
183	P: PackedField<Scalar = F>,
184	F: TowerField,
185{
186	let log_degree = packed.log_degree();
187	let binary_tower_level = oracles[packed.id()].binary_tower_level();
188
189	process_projected_bivariate_witness(
190		witness_index,
191		meta,
192		eval_point,
193		|_projected_eval_point| {
194			let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
195			let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
196			Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
197		},
198		partial_evals,
199	)?;
200
201	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
202	Ok(())
203}
204
205/// Metadata about a sumcheck over a bivariate product of two multilinears.
206#[derive(Clone, Copy)]
207pub struct ProjectedBivariateMeta {
208	inner_id: OracleId,
209	projected_id: Option<OracleId>,
210	multiplier_id: OracleId,
211	projected_n_vars: usize,
212}
213
214impl ProjectedBivariateMeta {
215	pub fn oracle_ids(&self) -> [OracleId; 2] {
216		[
217			self.projected_id.unwrap_or(self.inner_id),
218			self.multiplier_id,
219		]
220	}
221}
222
223fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
224	oracles: &mut MultilinearOracleSet<F>,
225	inner_id: OracleId,
226	projected_n_vars: usize,
227	eval_point: &[F],
228	multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
229) -> Result<ProjectedBivariateMeta, Error> {
230	let inner = &oracles[inner_id];
231
232	let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
233		let projected_id =
234			oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
235
236		(&eval_point[..projected_n_vars], Some(projected_id))
237	} else {
238		(eval_point, None)
239	};
240
241	let projected_n_vars = projected_eval_point.len();
242
243	let multiplier_id =
244		oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
245
246	let meta = ProjectedBivariateMeta {
247		inner_id,
248		projected_id,
249		multiplier_id,
250		projected_n_vars,
251	};
252
253	Ok(meta)
254}
255
256fn process_projected_bivariate_witness<'a, F, P>(
257	witness_index: &mut MultilinearExtensionIndex<'a, P>,
258	meta: &ProjectedBivariateMeta,
259	eval_point: &[F],
260	multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
261	partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
262) -> Result<(), Error>
263where
264	P: PackedField<Scalar = F>,
265	F: TowerField,
266{
267	let &ProjectedBivariateMeta {
268		projected_id,
269		multiplier_id,
270		projected_n_vars,
271		inner_id,
272	} = meta;
273
274	let projected_eval_point = if let Some(projected_id) = projected_id {
275		let (prefix, suffix) = eval_point.split_at(projected_n_vars);
276
277		let projected = partial_evals
278			.get(inner_id, suffix)
279			.expect("projected should exist if projected_id exist")
280			.clone();
281
282		witness_index.update_multilin_poly(vec![(
283			projected_id,
284			MLEDirectAdapter::from(projected).upcast_arc_dyn(),
285		)])?;
286		prefix
287	} else {
288		eval_point
289	};
290
291	let m = multiplier_witness_ctr(projected_eval_point)?;
292
293	if !witness_index.has(multiplier_id) {
294		witness_index.update_multilin_poly([(multiplier_id, m)])?;
295	}
296	Ok(())
297}
298
299pub struct OracleIdPartialEval<P: PackedField> {
300	pub id: OracleId,
301	pub suffix: EvalPoint<P::Scalar>,
302	pub partial_eval: MultilinearExtension<P>,
303}
304
305pub fn try_build_partial_eval<F: TowerField, P: PackedField<Scalar = F>>(
306	partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
307	oracles: &MultilinearOracleSet<F>,
308	id: OracleId,
309	suffix: &[F],
310	acc: &mut [P],
311	coeff: P,
312) -> bool {
313	match &oracles[id].variant {
314		crate::oracle::MultilinearPolyVariant::LinearCombination(lc) => {
315			for (poly_id, internal_coeff) in izip!(lc.polys(), lc.coefficients()) {
316				let new_coeff = coeff * P::broadcast(internal_coeff);
317
318				if !try_build_partial_eval(partial_evals, oracles, poly_id, suffix, acc, new_coeff)
319				{
320					return false;
321				}
322			}
323
324			if lc.offset() != F::zero() {
325				let offset = P::broadcast(lc.offset());
326				for acc in acc.iter_mut() {
327					*acc += offset;
328				}
329			}
330		}
331		_ => {
332			let mle = match partial_evals.get(id, suffix) {
333				Some(mle) => mle,
334				None => return false,
335			};
336			for (acc, eval) in acc.iter_mut().zip(mle.evals()) {
337				*acc += if coeff == P::one() {
338					*eval
339				} else {
340					*eval * coeff
341				};
342			}
343		}
344	};
345	true
346}
347
348/// shifted / packed oracle compute the projected MLE (i.e. the inner oracle evaluated on the
349/// projected eval_point)
350#[allow(clippy::type_complexity)]
351#[instrument(
352	skip_all,
353	name = "Evalcheck::calculate_projected_mles",
354	level = "debug"
355)]
356pub fn collect_projected_mles<F, P>(
357	metas: &[ProjectedBivariateMeta],
358	memoized_queries: &mut MemoizedData<P>,
359	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
360	oracles: &MultilinearOracleSet<F>,
361	witness_index: &MultilinearExtensionIndex<P>,
362	partial_evals: &mut EvalPointOracleIdMap<MultilinearExtension<P>, F>,
363) -> Result<(), Error>
364where
365	P: PackedField<Scalar = F>,
366	F: TowerField,
367{
368	let mut suffix_oracle_id = HashSet::new();
369
370	for (claim, meta) in projected_bivariate_claims.iter().zip(metas.iter()) {
371		if meta.projected_id.is_some() {
372			let suffix = &claim.eval_point[meta.projected_n_vars..];
373			suffix_oracle_id.insert((suffix, meta.inner_id));
374		}
375	}
376
377	let queries_to_memoize = suffix_oracle_id
378		.iter()
379		.copied()
380		.map(|(suffix, _)| suffix)
381		.collect::<Vec<_>>();
382
383	memoized_queries.memoize_query_par(queries_to_memoize)?;
384
385	let suffix_oracle_id = suffix_oracle_id.into_iter().collect::<Vec<_>>();
386
387	let new_partial_evals = suffix_oracle_id
388		.into_par_iter()
389		.map(|(suffix, inner_id)| {
390			let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
391
392			let query = memoized_queries
393				.full_query_readonly(suffix)
394				.ok_or(Error::MissingQuery)?;
395
396			if partial_evals.get(inner_id, suffix).is_some() {
397				return Ok(None);
398			}
399
400			let n_vars = inner_multilin.n_vars() - suffix.len();
401
402			let mut buffer = zeroed_vec(1 << n_vars.saturating_sub(P::LOG_WIDTH));
403
404			let is_built = try_build_partial_eval(
405				partial_evals,
406				oracles,
407				inner_id,
408				suffix,
409				&mut buffer,
410				P::one(),
411			);
412
413			let partial_eval = if is_built {
414				MultilinearExtension::new(n_vars, buffer).unwrap()
415			} else {
416				inner_multilin
417					.evaluate_partial_high(query.to_ref())
418					.map_err(Error::from)?
419			};
420
421			Ok(Some(OracleIdPartialEval {
422				id: inner_id,
423				suffix: suffix.into(),
424				partial_eval,
425			}))
426		})
427		.collect::<Result<Vec<Option<_>>, Error>>();
428
429	for OracleIdPartialEval {
430		id,
431		suffix,
432		partial_eval,
433	} in new_partial_evals?.into_iter().flatten()
434	{
435		partial_evals.insert(id, suffix, partial_eval)
436	}
437
438	Ok(())
439}
440
441/// Struct for memoizing tensor expansions of evaluation points and partial evaluations of
442/// multilinears
443#[allow(clippy::type_complexity)]
444pub struct MemoizedData<'a, P: PackedField> {
445	query: Vec<(Vec<P::Scalar>, MultilinearQuery<P>)>,
446	partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
447}
448
449impl<'a, P: PackedField> MemoizedData<'a, P> {
450	#[allow(clippy::new_without_default)]
451	pub fn new() -> Self {
452		Self {
453			query: Vec::new(),
454			partial_evals: EvalPointOracleIdMap::new(),
455		}
456	}
457
458	pub fn full_query(
459		&mut self,
460		eval_point: &[P::Scalar],
461	) -> Result<&MultilinearQuery<P>, binius_hal::Error> {
462		if let Some(index) = self
463			.query
464			.iter()
465			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
466		{
467			let (_, query) = &self.query[index];
468			return Ok(query);
469		}
470
471		let query = MultilinearQuery::expand(eval_point);
472		self.query.push((eval_point.to_vec(), query));
473
474		let (_, query) = self.query.last().expect("pushed query immediately above");
475		Ok(query)
476	}
477
478	/// Finds a `MultilinearQuery` corresponding to the given `eval_point`.
479	pub fn full_query_readonly(&self, eval_point: &[P::Scalar]) -> Option<&MultilinearQuery<P>> {
480		self.query
481			.iter()
482			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
483			.map(|index| {
484				let (_, query) = &self.query[index];
485				query
486			})
487	}
488
489	#[instrument(skip_all, name = "Evalcheck::memoize_query_par", level = "debug")]
490	pub fn memoize_query_par<'b>(
491		&mut self,
492		eval_points: impl IntoIterator<Item = &'b [P::Scalar]>,
493	) -> Result<(), binius_hal::Error> {
494		let deduplicated_eval_points = eval_points.into_iter().collect::<HashSet<_>>();
495
496		let new_queries = deduplicated_eval_points
497			.into_par_iter()
498			.filter(|ep| self.full_query_readonly(ep).is_none())
499			.map(|ep| {
500				let query = MultilinearQuery::<P>::expand(ep);
501				(ep.to_vec(), query)
502			})
503			.collect::<Vec<_>>();
504
505		self.query.extend(new_queries);
506
507		Ok(())
508	}
509
510	pub fn memoize_partial_evals(
511		&mut self,
512		metas: &[ProjectedBivariateMeta],
513		projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
514		oracles: &mut MultilinearOracleSet<P::Scalar>,
515		witness_index: &MultilinearExtensionIndex<'a, P>,
516	) where
517		P::Scalar: TowerField,
518	{
519		projected_bivariate_claims
520			.iter()
521			.zip(metas)
522			.for_each(|(claim, meta)| {
523				let inner_id = meta.inner_id;
524				if oracles[inner_id].variant.is_committed() && meta.projected_id.is_some() {
525					let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
526
527					let projected_id = meta.projected_id.expect("checked above");
528
529					let projected = witness_index
530						.get_multilin_poly(projected_id)
531						.expect("witness_index contains projected if projected_id exist");
532
533					self.partial_evals.insert(inner_id, eval_point, projected);
534				}
535			});
536	}
537
538	pub fn partial_eval(
539		&self,
540		id: OracleId,
541		eval_point: &[P::Scalar],
542	) -> Option<&MultilinearWitness<'a, P>> {
543		self.partial_evals.get(id, eval_point)
544	}
545}
546
547type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
548
549pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
550	witness: &MultilinearExtensionIndex<P>,
551	constraint_sets: Vec<SizedConstraintSet<F>>,
552	transcript: &mut ProverTranscript<Transcript>,
553	switchover_fn: impl Fn(usize) -> usize + 'static,
554	domain_factory: impl EvaluationDomainFactory<DomainField>,
555	backend: &Backend,
556) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
557where
558	P: PackedField<Scalar = F>
559		+ PackedExtension<F, PackedSubfield = P>
560		+ PackedExtension<DomainField>,
561	F: TowerField + ExtensionField<DomainField>,
562	DomainField: Field,
563	Transcript: Challenger,
564	Backend: ComputationBackend,
565{
566	let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
567		EvaluationOrder::HighToLow,
568		constraint_sets,
569		witness,
570		domain_factory,
571		&switchover_fn,
572		backend,
573	)?;
574
575	let batch_prover = front_loaded::BatchProver::new(provers, transcript)?;
576
577	let mut sumcheck_output = batch_prover.run(transcript)?;
578
579	// Reverse challenges since folding high-to-low
580	sumcheck_output.challenges.reverse();
581
582	let evalcheck_claims =
583		sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
584
585	Ok(evalcheck_claims)
586}
587
588#[allow(clippy::too_many_arguments)]
589pub fn prove_mlecheck_with_switchover<'a, F, P, DomainField, Transcript, Backend>(
590	witness: &MultilinearExtensionIndex<P>,
591	constraint_set: SizedConstraintSet<F>,
592	eq_ind_challenges: EvalPoint<F>,
593	memoized_data: &mut MemoizedData<'a, P>,
594	transcript: &mut ProverTranscript<Transcript>,
595	switchover_fn: impl Fn(usize) -> usize + 'static,
596	domain_factory: impl EvaluationDomainFactory<DomainField>,
597	backend: &Backend,
598) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
599where
600	P: PackedField<Scalar = F>
601		+ PackedExtension<F, PackedSubfield = P>
602		+ PackedExtension<DomainField>,
603	F: TowerField + ExtensionField<DomainField>,
604	DomainField: Field,
605	Transcript: Challenger,
606	Backend: ComputationBackend,
607{
608	let MLECheckProverWithMeta { prover, meta } = constraint_sets_mlecheck_prover_meta(
609		EvaluationOrder::HighToLow,
610		constraint_set,
611		eq_ind_challenges,
612		memoized_data,
613		witness,
614		domain_factory,
615		&switchover_fn,
616		backend,
617	)?;
618
619	let batch_prover = front_loaded::BatchProver::new(vec![prover], transcript)?;
620
621	let mut sumcheck_output = batch_prover.run(transcript)?;
622
623	// Reverse challenges since folding high-to-low
624	sumcheck_output.challenges.reverse();
625
626	// extract eq_ind_eval
627	sumcheck_output.multilinear_evals[0].pop();
628
629	let evalcheck_claims =
630		sumcheck::make_eval_claims(EvaluationOrder::HighToLow, vec![meta], sumcheck_output)?;
631
632	Ok(evalcheck_claims)
633}