binius_core/protocols/evalcheck/
subclaims.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! This module contains helpers to create bivariate sumcheck instances originating from:
4//!  * products with shift indicators (shifted virtual polynomials)
5//!  * products with tower basis (packed virtual polynomials)
6//!
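//! They all share these properties:
//!  * they are always a product of two multilins (the composition polynomial is the bivariate
//!    product `Var(0) * Var(1)`)
//!  * one multilin (the multiplier) is transparent (`shift_ind`, `eq_ind`, or tower basis)
//!  * the other multilin is a projection of one of the evalcheck claim multilins onto its first
//!    variables
//!
//! Composite MLE-checks are also handled here. They differ in that an arbitrary composition of
//! the composite's multilins is multiplied by an `eq_ind` multiplier.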
11
12use std::{
13	collections::{HashMap, HashSet},
14	iter,
15};
16
17use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
18use binius_hal::{ComputationBackend, ComputationBackendExt};
19use binius_math::{
20	ArithCircuit, ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder,
21	MLEDirectAdapter, MultilinearExtension, MultilinearQuery,
22};
23use binius_maybe_rayon::prelude::*;
24use binius_utils::bail;
25use tracing::instrument;
26
27use super::{error::Error, evalcheck::EvalcheckMultilinearClaim, EvalPointOracleIdMap};
28use crate::{
29	fiat_shamir::Challenger,
30	oracle::{
31		CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
32		MultilinearOracleSet, MultilinearPolyVariant, OracleId, Packed, Shifted,
33	},
34	polynomial::MultivariatePoly,
35	protocols::sumcheck::{
36		self,
37		prove::{
38			front_loaded,
39			oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
40		},
41		Error as SumcheckError,
42	},
43	transcript::ProverTranscript,
44	transparent::{
45		eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
46	},
47	witness::{MultilinearExtensionIndex, MultilinearWitness},
48};
49
50/// Create oracles for the bivariate product of an inner oracle with shift indicator.
51///
52/// Projects to first `block_size()` vars.
53pub fn shifted_sumcheck_meta<F: TowerField>(
54	oracles: &mut MultilinearOracleSet<F>,
55	shifted: &Shifted,
56	eval_point: &[F],
57) -> Result<ProjectedBivariateMeta, Error> {
58	projected_bivariate_meta(
59		oracles,
60		shifted.id(),
61		shifted.block_size(),
62		eval_point,
63		|projected_eval_point| {
64			Ok(ShiftIndPartialEval::new(
65				shifted.block_size(),
66				shifted.shift_offset(),
67				shifted.shift_variant(),
68				projected_eval_point.to_vec(),
69			)?)
70		},
71	)
72}
73
74/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
75#[allow(clippy::too_many_arguments)]
76pub fn process_shifted_sumcheck<F, P>(
77	shifted: &Shifted,
78	meta: &ProjectedBivariateMeta,
79	eval_point: &[F],
80	eval: F,
81	witness_index: &mut MultilinearExtensionIndex<P>,
82	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
83	projected: Option<MultilinearExtension<P>>,
84) -> Result<(), Error>
85where
86	P: PackedField<Scalar = F>,
87	F: TowerField,
88{
89	process_projected_bivariate_witness(
90		witness_index,
91		meta,
92		eval_point,
93		|projected_eval_point| {
94			let shift_ind = ShiftIndPartialEval::new(
95				projected_eval_point.len(),
96				shifted.shift_offset(),
97				shifted.shift_variant(),
98				projected_eval_point.to_vec(),
99			)?;
100
101			let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
102			Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
103		},
104		projected,
105	)?;
106	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
107
108	Ok(())
109}
110
111/// Create oracles for the bivariate product of an inner oracle with the tower basis.
112///
113/// Projects to first `log_degree()` vars.
114/// Returns metadata object with oracle identifiers.
115pub fn packed_sumcheck_meta<F: TowerField>(
116	oracles: &mut MultilinearOracleSet<F>,
117	packed: &Packed,
118	eval_point: &[F],
119) -> Result<ProjectedBivariateMeta, Error> {
120	let n_vars = oracles.n_vars(packed.id());
121	let log_degree = packed.log_degree();
122	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
123
124	if log_degree > n_vars {
125		bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
126	}
127
128	// NB. projected_n_vars = 0 because eval_point length is log_degree less than inner n_vars
129	projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
130		Ok(TowerBasis::new(log_degree, binary_tower_level)?)
131	})
132}
133
134pub fn composite_mlecheck_meta<F: TowerField>(
135	oracles: &mut MultilinearOracleSet<F>,
136	eval_point: &[F],
137) -> Result<CompositeMLECheckMeta, Error> {
138	let eq_ind_id = oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?;
139	Ok(CompositeMLECheckMeta { eq_ind_id })
140}
141
142pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
143	meta: &ProjectedBivariateMeta,
144	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
145	n_vars: usize,
146	eval: F,
147) {
148	if n_vars >= constraint_builders.len() {
149		constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
150	}
151	let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
152	constraint_builders[n_vars].add_sumcheck(meta.oracle_ids(), bivariate_product.into(), eval);
153}
154
155pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
156	meta: CompositeMLECheckMeta,
157	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
158	comp: &CompositeMLE<F>,
159	eval: F,
160) {
161	let n_vars = comp.n_vars();
162	let mut oracle_ids = comp.inner().clone();
163	oracle_ids.push(meta.eq_ind_id); // eq
164
165	// Var(comp.n_polys()) corresponds to the eq MLE
166	let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithCircuit::var(comp.n_polys());
167	if n_vars >= constraint_builders.len() {
168		constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
169	}
170	constraint_builders[n_vars].add_sumcheck(oracle_ids, expr, eval);
171}
172
173/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
174#[allow(clippy::too_many_arguments)]
175pub fn process_packed_sumcheck<F, P>(
176	oracles: &MultilinearOracleSet<F>,
177	packed: &Packed,
178	meta: &ProjectedBivariateMeta,
179	eval_point: &[F],
180	eval: F,
181	witness_index: &mut MultilinearExtensionIndex<P>,
182	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
183	projected: Option<MultilinearExtension<P>>,
184) -> Result<(), Error>
185where
186	P: PackedField<Scalar = F>,
187	F: TowerField,
188{
189	let log_degree = packed.log_degree();
190	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
191
192	process_projected_bivariate_witness(
193		witness_index,
194		meta,
195		eval_point,
196		|_projected_eval_point| {
197			let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
198			let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
199			Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
200		},
201		projected,
202	)?;
203
204	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
205	Ok(())
206}
207
208#[derive(Debug, Clone, Copy)]
209pub struct CompositeMLECheckMeta {
210	pub eq_ind_id: OracleId,
211}
212
213/// Metadata about a sumcheck over a bivariate product of two multilinears.
214#[derive(Clone, Copy)]
215pub struct ProjectedBivariateMeta {
216	inner_id: OracleId,
217	projected_id: Option<OracleId>,
218	multiplier_id: OracleId,
219	projected_n_vars: usize,
220}
221
222impl ProjectedBivariateMeta {
223	pub fn oracle_ids(&self) -> [OracleId; 2] {
224		[
225			self.projected_id.unwrap_or(self.inner_id),
226			self.multiplier_id,
227		]
228	}
229}
230
231fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
232	oracles: &mut MultilinearOracleSet<F>,
233	inner_id: OracleId,
234	projected_n_vars: usize,
235	eval_point: &[F],
236	multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
237) -> Result<ProjectedBivariateMeta, Error> {
238	let inner = oracles.oracle(inner_id);
239
240	let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
241		let projected_id =
242			oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
243
244		(&eval_point[..projected_n_vars], Some(projected_id))
245	} else {
246		(eval_point, None)
247	};
248
249	let projected_n_vars = projected_eval_point.len();
250
251	let multiplier_id =
252		oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
253
254	let meta = ProjectedBivariateMeta {
255		inner_id,
256		projected_id,
257		multiplier_id,
258		projected_n_vars,
259	};
260
261	Ok(meta)
262}
263
264fn process_projected_bivariate_witness<'a, F, P>(
265	witness_index: &mut MultilinearExtensionIndex<'a, P>,
266	meta: &ProjectedBivariateMeta,
267	eval_point: &[F],
268	multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
269	projected: Option<MultilinearExtension<P>>,
270) -> Result<(), Error>
271where
272	P: PackedField<Scalar = F>,
273	F: TowerField,
274{
275	let &ProjectedBivariateMeta {
276		projected_id,
277		multiplier_id,
278		projected_n_vars,
279		..
280	} = meta;
281
282	let projected_eval_point = if let Some(projected_id) = projected_id {
283		witness_index.update_multilin_poly(vec![(
284			projected_id,
285			MLEDirectAdapter::from(
286				projected.expect("projected should exist if projected_id exist"),
287			)
288			.upcast_arc_dyn(),
289		)])?;
290
291		&eval_point[..projected_n_vars]
292	} else {
293		eval_point
294	};
295
296	let m = multiplier_witness_ctr(projected_eval_point)?;
297
298	if !witness_index.has(multiplier_id) {
299		witness_index.update_multilin_poly([(multiplier_id, m)])?;
300	}
301	Ok(())
302}
303
304/// shifted / packed oracle -> compute the projected MLE (i.e. the inner oracle evaluated on the projected eval_point)
305/// composite oracle -> None
306#[allow(clippy::type_complexity)]
307#[instrument(
308	skip_all,
309	name = "Evalcheck::calculate_projected_mles",
310	level = "debug"
311)]
312pub fn calculate_projected_mles<F, P, Backend>(
313	metas: &[ProjectedBivariateMeta],
314	memoized_queries: &mut MemoizedData<P, Backend>,
315	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
316	witness_index: &MultilinearExtensionIndex<P>,
317	backend: &Backend,
318) -> Result<Vec<Option<MultilinearExtension<P>>>, Error>
319where
320	P: PackedField<Scalar = F>,
321	F: TowerField,
322	Backend: ComputationBackend,
323{
324	let mut queries_to_memoize = Vec::new();
325	for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
326		queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
327	}
328	memoized_queries.memoize_query_par(queries_to_memoize, backend)?;
329
330	projected_bivariate_claims
331		.par_iter()
332		.zip(metas)
333		.map(|(claim, meta)| match meta.projected_id {
334			Some(_) => {
335				let inner_multilin = witness_index.get_multilin_poly(meta.inner_id)?;
336				let eval_point = &claim.eval_point[meta.projected_n_vars..];
337				let query = memoized_queries
338					.full_query_readonly(eval_point)
339					.ok_or(Error::MissingQuery)?;
340				Ok(Some(
341					backend
342						.evaluate_partial_high(&inner_multilin, query.to_ref())
343						.map_err(Error::from)?,
344				))
345			}
346			_ => Ok(None),
347		})
348		.collect::<Result<Vec<Option<_>>, Error>>()
349}
350
351/// Each composite oracle induces a new eq oracle, for which we need to fill the witness
352pub fn fill_eq_witness_for_composites<F, P, Backend>(
353	metas: &[CompositeMLECheckMeta],
354	memoized_queries: &mut MemoizedData<P, Backend>,
355	composite_mle_claims: &[EvalcheckMultilinearClaim<F>],
356	witness_index: &mut MultilinearExtensionIndex<P>,
357	backend: &Backend,
358) -> Result<(), Error>
359where
360	P: PackedField<Scalar = F>,
361	F: TowerField,
362	Backend: ComputationBackend,
363{
364	let dedup_eval_points = composite_mle_claims
365		.iter()
366		.map(|claim| claim.eval_point.as_ref())
367		.collect::<HashSet<_>>();
368
369	memoized_queries.memoize_query_par(dedup_eval_points.iter().copied(), backend)?;
370
371	let eq_indicators = dedup_eval_points
372		.into_iter()
373		.map(|eval_point| {
374			let mle = MLEDirectAdapter::from(MultilinearExtension::new(
375				eval_point.len(),
376				memoized_queries
377					.full_query_readonly(eval_point)
378					.expect("computed above")
379					.expansion()
380					.to_vec(),
381			)?)
382			.upcast_arc_dyn();
383			Ok((eval_point, mle))
384		})
385		.collect::<Result<HashMap<_, _>, Error>>()?;
386
387	for (claim, meta) in iter::zip(composite_mle_claims, metas) {
388		let eq_ind = eq_indicators
389			.get(claim.eval_point.as_ref())
390			.expect("was added above");
391
392		witness_index.update_multilin_poly(vec![(meta.eq_ind_id, eq_ind.clone())])?;
393	}
394
395	Ok(())
396}
397
398/// Struct for memoizing tensor expansions of evaluation points and partial evaluations of multilinears
399#[allow(clippy::type_complexity)]
400pub struct MemoizedData<'a, P: PackedField, Backend: ComputationBackend> {
401	query: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
402	partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
403}
404
405impl<'a, P: PackedField, Backend: ComputationBackend> MemoizedData<'a, P, Backend> {
406	#[allow(clippy::new_without_default)]
407	pub fn new() -> Self {
408		Self {
409			query: Vec::new(),
410			partial_evals: EvalPointOracleIdMap::new(),
411		}
412	}
413
414	pub fn full_query(
415		&mut self,
416		eval_point: &[P::Scalar],
417		backend: &Backend,
418	) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
419		if let Some(index) = self
420			.query
421			.iter()
422			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
423		{
424			let (_, ref query) = &self.query[index];
425			return Ok(query);
426		}
427
428		let query = backend.multilinear_query(eval_point)?;
429		self.query.push((eval_point.to_vec(), query));
430
431		let (_, ref query) = self.query.last().expect("pushed query immediately above");
432		Ok(query)
433	}
434
435	/// Finds a `MultilinearQuery` corresponding to the given `eval_point`.
436	pub fn full_query_readonly(
437		&self,
438		eval_point: &[P::Scalar],
439	) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
440		self.query
441			.iter()
442			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
443			.map(|index| {
444				let (_, ref query) = &self.query[index];
445				query
446			})
447	}
448
449	#[instrument(skip_all, name = "Evalcheck::memoize_query_par", level = "debug")]
450	pub fn memoize_query_par<'b>(
451		&mut self,
452		eval_points: impl IntoIterator<Item = &'b [P::Scalar]>,
453		backend: &Backend,
454	) -> Result<(), binius_hal::Error> {
455		let deduplicated_eval_points = eval_points.into_iter().collect::<HashSet<_>>();
456
457		let new_queries = deduplicated_eval_points
458			.into_par_iter()
459			.filter(|ep| self.full_query_readonly(ep).is_none())
460			.map(|ep| {
461				backend
462					.multilinear_query::<P>(ep)
463					.map(|res| (ep.to_vec(), res))
464			})
465			.collect::<Result<Vec<_>, binius_hal::Error>>()?;
466
467		self.query.extend(new_queries);
468
469		Ok(())
470	}
471
472	pub fn memoize_partial_evals(
473		&mut self,
474		metas: &[ProjectedBivariateMeta],
475		projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
476		oracles: &mut MultilinearOracleSet<P::Scalar>,
477		witness_index: &MultilinearExtensionIndex<'a, P>,
478	) where
479		P::Scalar: TowerField,
480	{
481		projected_bivariate_claims
482			.iter()
483			.zip(metas)
484			.for_each(|(claim, meta)| {
485				let inner_id = meta.inner_id;
486				if matches!(oracles.oracle(inner_id).variant, MultilinearPolyVariant::Committed)
487					&& meta.projected_id.is_some()
488				{
489					let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
490
491					let projected_id = meta.projected_id.expect("checked above");
492
493					let projected = witness_index
494						.get_multilin_poly(projected_id)
495						.expect("witness_index contains projected if projected_id exist");
496
497					self.partial_evals.insert(inner_id, eval_point, projected);
498				}
499			});
500	}
501
502	pub fn partial_eval(
503		&self,
504		id: OracleId,
505		eval_point: &[P::Scalar],
506	) -> Option<&MultilinearWitness<'a, P>> {
507		self.partial_evals.get(id, eval_point)
508	}
509}
510
511type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
512
513pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
514	witness: &MultilinearExtensionIndex<P>,
515	constraint_sets: Vec<ConstraintSet<F>>,
516	transcript: &mut ProverTranscript<Transcript>,
517	switchover_fn: impl Fn(usize) -> usize + 'static,
518	domain_factory: impl EvaluationDomainFactory<DomainField>,
519	backend: &Backend,
520) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
521where
522	P: PackedField<Scalar = F>
523		+ PackedExtension<F, PackedSubfield = P>
524		+ PackedExtension<DomainField>,
525	F: TowerField + ExtensionField<DomainField>,
526	DomainField: Field,
527	Transcript: Challenger,
528	Backend: ComputationBackend,
529{
530	let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
531		EvaluationOrder::HighToLow,
532		constraint_sets,
533		witness,
534		domain_factory,
535		&switchover_fn,
536		backend,
537	)?;
538
539	let batch_prover = front_loaded::BatchProver::new(provers, transcript)?;
540
541	let mut sumcheck_output = batch_prover.run(transcript)?;
542
543	// Reverse challenges since folding high-to-low
544	sumcheck_output.challenges.reverse();
545
546	let evalcheck_claims =
547		sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
548
549	Ok(evalcheck_claims)
550}
551
552#[derive(Clone)]
553pub enum SumcheckClaims<F: Field> {
554	Projected(EvalcheckMultilinearClaim<F>),
555	Composite(EvalcheckMultilinearClaim<F>),
556}