binius_core/protocols/evalcheck/
subclaims.rs

1// Copyright 2024-2025 Irreducible Inc.
2
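//! This module contains helpers to create sumcheck instances originating from:
//!  * products with shift indicators (shifted virtual polynomials)
//!  * products with the tower basis (packed virtual polynomials)
//!  * products of a composition with an equality indicator (composite virtual polynomials)
//!
//! The shifted and packed instances have common traits:
//!  * they are always a product of two multilins (the composition is the bivariate product
//!    `Var(0) * Var(1)`)
//!  * one multilin (the multiplier) is transparent (`shift_ind` or tower basis)
//!  * the other multilin is a projection of one of the evalcheck claim multilins to its first
//!    variables (or that multilin itself, when no variables need to be fixed)
//!
//! Composite instances instead multiply the composition of their inner multilins by a
//! transparent `eq_ind` multiplier.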
11
12use std::collections::{HashMap, HashSet};
13
14use binius_field::{
15	ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
16};
17use binius_hal::{ComputationBackend, ComputationBackendExt};
18use binius_math::{
19	ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
20	MultilinearExtension, MultilinearQuery,
21};
22use binius_maybe_rayon::prelude::*;
23use binius_utils::bail;
24
25use super::{error::Error, evalcheck::EvalcheckMultilinearClaim, EvalPointOracleIdMap};
26use crate::{
27	fiat_shamir::Challenger,
28	oracle::{
29		CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
30		MultilinearOracleSet, MultilinearPolyVariant, OracleId, Packed, Shifted,
31	},
32	polynomial::MultivariatePoly,
33	protocols::sumcheck::{
34		self,
35		prove::oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
36		Error as SumcheckError,
37	},
38	transcript::ProverTranscript,
39	transparent::{
40		eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
41	},
42	witness::{MultilinearExtensionIndex, MultilinearWitness},
43};
44
45/// Create oracles for the bivariate product of an inner oracle with shift indicator.
46///
47/// Projects to first `block_size()` vars.
48pub fn shifted_sumcheck_meta<F: TowerField>(
49	oracles: &mut MultilinearOracleSet<F>,
50	shifted: &Shifted,
51	eval_point: &[F],
52) -> Result<ProjectedBivariateMeta, Error> {
53	projected_bivariate_meta(
54		oracles,
55		shifted.id(),
56		shifted.block_size(),
57		eval_point,
58		|projected_eval_point| {
59			Ok(ShiftIndPartialEval::new(
60				shifted.block_size(),
61				shifted.shift_offset(),
62				shifted.shift_variant(),
63				projected_eval_point.to_vec(),
64			)?)
65		},
66	)
67}
68
69/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
70#[allow(clippy::too_many_arguments)]
71pub fn process_shifted_sumcheck<F, P>(
72	shifted: &Shifted,
73	meta: &ProjectedBivariateMeta,
74	eval_point: &[F],
75	eval: F,
76	witness_index: &mut MultilinearExtensionIndex<P>,
77	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
78	projected: Option<MultilinearExtension<P>>,
79) -> Result<(), Error>
80where
81	P: PackedFieldIndexable<Scalar = F>,
82	F: TowerField,
83{
84	process_projected_bivariate_witness(
85		witness_index,
86		meta,
87		eval_point,
88		|projected_eval_point| {
89			let shift_ind = ShiftIndPartialEval::new(
90				projected_eval_point.len(),
91				shifted.shift_offset(),
92				shifted.shift_variant(),
93				projected_eval_point.to_vec(),
94			)?;
95
96			let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
97			Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
98		},
99		projected,
100	)?;
101	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
102
103	Ok(())
104}
105
106/// Create oracles for the bivariate product of an inner oracle with the tower basis.
107///
108/// Projects to first `log_degree()` vars.
109/// Returns metadata object with oracle identifiers.
110pub fn packed_sumcheck_meta<F: TowerField>(
111	oracles: &mut MultilinearOracleSet<F>,
112	packed: &Packed,
113	eval_point: &[F],
114) -> Result<ProjectedBivariateMeta, Error> {
115	let n_vars = oracles.n_vars(packed.id());
116	let log_degree = packed.log_degree();
117	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
118
119	if log_degree > n_vars {
120		bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
121	}
122
123	// NB. projected_n_vars = 0 because eval_point length is log_degree less than inner n_vars
124	projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
125		Ok(TowerBasis::new(log_degree, binary_tower_level)?)
126	})
127}
128
129pub fn composite_sumcheck_meta<F: TowerField>(
130	oracles: &mut MultilinearOracleSet<F>,
131	eval_point: &[F],
132) -> Result<ProjectedBivariateMeta, Error> {
133	Ok(ProjectedBivariateMeta {
134		multiplier_id: oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?,
135		inner_id: None,
136		projected_id: None,
137		// not used in case of composite
138		projected_n_vars: 0,
139	})
140}
141
142pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
143	meta: &ProjectedBivariateMeta,
144	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
145	n_vars: usize,
146	eval: F,
147) {
148	if n_vars > constraint_builders.len() {
149		constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
150	}
151	let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
152	constraint_builders[n_vars - 1].add_sumcheck(meta.oracle_ids(), bivariate_product, eval);
153}
154
155pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
156	meta: &ProjectedBivariateMeta,
157	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
158	comp: &CompositeMLE<F>,
159	eval: F,
160) {
161	let n_vars = comp.n_vars();
162	let mut oracle_ids = comp.inner().clone();
163	oracle_ids.push(meta.multiplier_id); // eq
164
165	// Var(comp.n_polys()) corresponds to the eq MLE (meta.multiplier_id)
166	let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithExpr::Var(comp.n_polys());
167	if n_vars > constraint_builders.len() {
168		constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
169	}
170	constraint_builders[n_vars - 1].add_sumcheck(oracle_ids, expr, eval);
171}
172
173/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
174#[allow(clippy::too_many_arguments)]
175pub fn process_packed_sumcheck<F, P>(
176	oracles: &MultilinearOracleSet<F>,
177	packed: &Packed,
178	meta: &ProjectedBivariateMeta,
179	eval_point: &[F],
180	eval: F,
181	witness_index: &mut MultilinearExtensionIndex<P>,
182	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
183	projected: Option<MultilinearExtension<P>>,
184) -> Result<(), Error>
185where
186	P: PackedField<Scalar = F>,
187	F: TowerField,
188{
189	let log_degree = packed.log_degree();
190	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
191
192	process_projected_bivariate_witness(
193		witness_index,
194		meta,
195		eval_point,
196		|_projected_eval_point| {
197			let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
198			let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
199			Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
200		},
201		projected,
202	)?;
203
204	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
205	Ok(())
206}
207
208#[derive(Clone, Copy)]
209pub struct ProjectedBivariateMeta {
210	/// `Some` if shifted / packed, `None` if composite
211	inner_id: Option<OracleId>,
212	projected_id: Option<OracleId>,
213	multiplier_id: OracleId,
214	projected_n_vars: usize,
215}
216
217impl ProjectedBivariateMeta {
218	pub fn oracle_ids(&self) -> [OracleId; 2] {
219		[
220			self.projected_id.unwrap_or_else(|| {
221				self.inner_id
222					.expect("oracle_ids() is only defined for shifted / packed")
223			}),
224			self.multiplier_id,
225		]
226	}
227}
228
229fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
230	oracles: &mut MultilinearOracleSet<F>,
231	inner_id: OracleId,
232	projected_n_vars: usize,
233	eval_point: &[F],
234	multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
235) -> Result<ProjectedBivariateMeta, Error> {
236	let inner = oracles.oracle(inner_id);
237
238	let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
239		let projected_id =
240			oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
241
242		(&eval_point[..projected_n_vars], Some(projected_id))
243	} else {
244		(eval_point, None)
245	};
246
247	let projected_n_vars = projected_eval_point.len();
248
249	let multiplier_id =
250		oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
251
252	let meta = ProjectedBivariateMeta {
253		inner_id: Some(inner_id),
254		projected_id,
255		multiplier_id,
256		projected_n_vars,
257	};
258
259	Ok(meta)
260}
261
262fn process_projected_bivariate_witness<'a, F, P>(
263	witness_index: &mut MultilinearExtensionIndex<'a, P>,
264	meta: &ProjectedBivariateMeta,
265	eval_point: &[F],
266	multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
267	projected: Option<MultilinearExtension<P>>,
268) -> Result<(), Error>
269where
270	P: PackedField<Scalar = F>,
271	F: TowerField,
272{
273	let &ProjectedBivariateMeta {
274		projected_id,
275		multiplier_id,
276		projected_n_vars,
277		..
278	} = meta;
279
280	let projected_eval_point = if let Some(projected_id) = projected_id {
281		witness_index.update_multilin_poly(vec![(
282			projected_id,
283			MLEDirectAdapter::from(
284				projected.expect("projected should exist if projected_id exist"),
285			)
286			.upcast_arc_dyn(),
287		)])?;
288
289		&eval_point[..projected_n_vars]
290	} else {
291		eval_point
292	};
293
294	let m = multiplier_witness_ctr(projected_eval_point)?;
295
296	if !witness_index.has(multiplier_id) {
297		witness_index.update_multilin_poly([(multiplier_id, m)])?;
298	}
299	Ok(())
300}
301
302/// shifted / packed oracle -> compute the projected MLE (i.e. the inner oracle evaluated on the projected eval_point)
303/// composite oracle -> None
304#[allow(clippy::type_complexity)]
305pub fn calculate_projected_mles<F, P, Backend>(
306	metas: &[ProjectedBivariateMeta],
307	memoized_queries: &mut MemoizedData<P, Backend>,
308	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
309	witness_index: &MultilinearExtensionIndex<P>,
310	backend: &Backend,
311) -> Result<Vec<Option<MultilinearExtension<P>>>, Error>
312where
313	P: PackedField<Scalar = F>,
314	F: TowerField,
315	Backend: ComputationBackend,
316{
317	let mut queries_to_memoize = Vec::new();
318	for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
319		if meta.inner_id.is_some() {
320			// packed / shifted
321			queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
322		}
323	}
324	memoized_queries.memoize_query_par(&queries_to_memoize, backend)?;
325
326	projected_bivariate_claims
327		.par_iter()
328		.zip(metas)
329		.map(|(claim, meta)| match (meta.inner_id, meta.projected_id) {
330			(Some(inner_id), Some(_)) => {
331				let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
332				let eval_point = &claim.eval_point[meta.projected_n_vars..];
333				let query = memoized_queries
334					.full_query_readonly(eval_point)
335					.ok_or(Error::MissingQuery)?;
336				Ok(Some(
337					backend
338						.evaluate_partial_high(&inner_multilin, query.to_ref())
339						.map_err(Error::from)?,
340				))
341			}
342			_ => Ok(None),
343		})
344		.collect::<Result<Vec<Option<_>>, Error>>()
345}
346
347/// Each composite oracle induces a new eq oracle, for which we need to fill the witness
348pub fn fill_eq_witness_for_composites<F, P, Backend>(
349	metas: &[ProjectedBivariateMeta],
350	memoized_queries: &mut MemoizedData<P, Backend>,
351	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
352	witness_index: &mut MultilinearExtensionIndex<P>,
353	backend: &Backend,
354) -> Result<(), Error>
355where
356	P: PackedField<Scalar = F>,
357	F: TowerField,
358	Backend: ComputationBackend,
359{
360	let dedup_eval_points = metas
361		.iter()
362		.zip(projected_bivariate_claims)
363		.filter(|(meta, _)| meta.inner_id.is_none())
364		.map(|(_, claim)| claim.eval_point.as_ref())
365		.collect::<HashSet<_>>();
366
367	memoized_queries
368		.memoize_query_par(&dedup_eval_points.iter().copied().collect::<Vec<_>>(), backend)?;
369
370	let eq_indicators = dedup_eval_points
371		.into_iter()
372		.map(|eval_point| {
373			let mle = MLEDirectAdapter::from(MultilinearExtension::from_values(
374				memoized_queries
375					.full_query_readonly(eval_point)
376					.expect("computed above")
377					.expansion()
378					.to_vec(),
379			)?)
380			.upcast_arc_dyn();
381			Ok((eval_point, mle))
382		})
383		.collect::<Result<HashMap<_, _>, Error>>()?;
384
385	for (meta, claim) in metas
386		.iter()
387		.zip(projected_bivariate_claims)
388		.filter(|(meta, _)| meta.inner_id.is_none())
389	{
390		let eq_ind = eq_indicators
391			.get(claim.eval_point.as_ref())
392			.expect("was added above");
393
394		witness_index.update_multilin_poly(vec![(meta.multiplier_id, eq_ind.clone())])?;
395	}
396
397	Ok(())
398}
399
400#[allow(clippy::type_complexity)]
401pub struct MemoizedData<'a, P: PackedField, Backend: ComputationBackend> {
402	query: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
403	partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
404}
405
406impl<'a, P: PackedField, Backend: ComputationBackend> MemoizedData<'a, P, Backend> {
407	#[allow(clippy::new_without_default)]
408	pub fn new() -> Self {
409		Self {
410			query: Vec::new(),
411			partial_evals: EvalPointOracleIdMap::new(),
412		}
413	}
414
415	pub fn full_query(
416		&mut self,
417		eval_point: &[P::Scalar],
418		backend: &Backend,
419	) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
420		if let Some(index) = self
421			.query
422			.iter()
423			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
424		{
425			let (_, ref query) = &self.query[index];
426			return Ok(query);
427		}
428
429		let query = backend.multilinear_query(eval_point)?;
430		self.query.push((eval_point.to_vec(), query));
431
432		let (_, ref query) = self.query.last().expect("pushed query immediately above");
433		Ok(query)
434	}
435
436	/// Finds a `MultilinearQuery` corresponding to the given `eval_point`.
437	pub fn full_query_readonly(
438		&self,
439		eval_point: &[P::Scalar],
440	) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
441		self.query
442			.iter()
443			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
444			.map(|index| {
445				let (_, ref query) = &self.query[index];
446				query
447			})
448	}
449
450	pub fn memoize_query_par(
451		&mut self,
452		eval_points: &[&[P::Scalar]],
453		backend: &Backend,
454	) -> Result<(), binius_hal::Error> {
455		let deduplicated_eval_points = eval_points.iter().collect::<HashSet<_>>();
456
457		let new_queries = deduplicated_eval_points
458			.into_par_iter()
459			.filter(|ep| self.full_query_readonly(ep).is_none())
460			.map(|ep| {
461				backend
462					.multilinear_query::<P>(ep)
463					.map(|res| (ep.to_vec(), res))
464			})
465			.collect::<Result<Vec<_>, binius_hal::Error>>()?;
466
467		self.query.extend(new_queries);
468
469		Ok(())
470	}
471
472	pub fn memoize_partial_evals(
473		&mut self,
474		metas: &[ProjectedBivariateMeta],
475		projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
476		oracles: &mut MultilinearOracleSet<P::Scalar>,
477		witness_index: &MultilinearExtensionIndex<'a, P>,
478	) where
479		P::Scalar: TowerField,
480	{
481		projected_bivariate_claims
482			.iter()
483			.zip(metas)
484			.filter(|(_, meta)| meta.inner_id.is_some())
485			.for_each(|(claim, meta)| {
486				let inner_id = meta.inner_id.expect("filtered by Some");
487				if matches!(oracles.oracle(inner_id).variant, MultilinearPolyVariant::Committed)
488					&& meta.projected_id.is_some()
489				{
490					let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
491
492					let projected_id = meta.projected_id.expect("checked above");
493
494					let projected = witness_index
495						.get_multilin_poly(projected_id)
496						.expect("witness_index contains projected if projected_id exist");
497
498					self.partial_evals.insert(inner_id, eval_point, projected);
499				}
500			});
501	}
502
503	pub fn partial_eval(
504		&self,
505		id: OracleId,
506		eval_point: &[P::Scalar],
507	) -> Option<&MultilinearWitness<'a, P>> {
508		self.partial_evals.get(id, eval_point)
509	}
510}
511
512type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
513
514pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
515	witness: &MultilinearExtensionIndex<P>,
516	constraint_sets: Vec<ConstraintSet<F>>,
517	transcript: &mut ProverTranscript<Transcript>,
518	switchover_fn: impl Fn(usize) -> usize + 'static,
519	domain_factory: impl EvaluationDomainFactory<DomainField>,
520	backend: &Backend,
521) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
522where
523	P: PackedField<Scalar = F>
524		+ PackedExtension<F, PackedSubfield = P>
525		+ PackedExtension<DomainField>,
526	F: TowerField + ExtensionField<DomainField>,
527	DomainField: Field,
528	Transcript: Challenger,
529	Backend: ComputationBackend,
530{
531	let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
532		EvaluationOrder::HighToLow,
533		constraint_sets,
534		witness,
535		domain_factory,
536		&switchover_fn,
537		backend,
538	)?;
539
540	let sumcheck_output = sumcheck::batch_prove(provers, transcript)?;
541
542	let evalcheck_claims =
543		sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
544
545	Ok(evalcheck_claims)
546}