binius_core/protocols/evalcheck/
subclaims.rs

1// Copyright 2024-2025 Irreducible Inc.
2
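//! This module contains helpers to create sumcheck instances originating from:
//!  * products with shift indicators (shifted virtual polynomials)
//!  * products with tower basis (packed virtual polynomials)
//!  * products of a composite MLE with an equality indicator (composite virtual polynomials)
//!
//! The shifted and packed instances have common traits:
//!  * they are always a product of two multilins (composition `Var(0) * Var(1)`)
//!  * one multilin (the multiplier) is transparent (`shift_ind` or tower basis)
//!  * the other multilin is a projection of one of the evalcheck claim multilins to its first
//!    variables (or that multilin itself, when no projection is needed)
//!
//! Composite instances multiply the composite's composition polynomial by a transparent
//! `eq_ind` multiplier over the full evaluation point, without any projection.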
11
12use std::collections::{HashMap, HashSet};
13
14use binius_field::{
15	as_packed_field::{PackScalar, PackedType},
16	underlier::UnderlierType,
17	ExtensionField, Field, PackedField, PackedFieldIndexable, TowerField,
18};
19use binius_hal::{ComputationBackend, ComputationBackendExt};
20use binius_math::{
21	ArithExpr, CompositionPoly, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension,
22	MultilinearQuery,
23};
24use binius_maybe_rayon::prelude::*;
25use binius_utils::bail;
26
27use super::{error::Error, evalcheck::EvalcheckMultilinearClaim};
28use crate::{
29	fiat_shamir::Challenger,
30	oracle::{
31		CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
32		MultilinearOracleSet, OracleId, Packed, ProjectionVariant, Shifted,
33	},
34	polynomial::MultivariatePoly,
35	protocols::sumcheck::{
36		self,
37		prove::oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
38		Error as SumcheckError,
39	},
40	transcript::ProverTranscript,
41	transparent::{
42		eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
43	},
44	witness::{MultilinearExtensionIndex, MultilinearWitness},
45};
46
47/// Create oracles for the bivariate product of an inner oracle with shift indicator.
48///
49/// Projects to first `block_size()` vars.
50pub fn shifted_sumcheck_meta<F: TowerField>(
51	oracles: &mut MultilinearOracleSet<F>,
52	shifted: &Shifted,
53	eval_point: &[F],
54) -> Result<ProjectedBivariateMeta, Error> {
55	projected_bivariate_meta(
56		oracles,
57		shifted.id(),
58		shifted.block_size(),
59		eval_point,
60		|projected_eval_point| {
61			Ok(ShiftIndPartialEval::new(
62				shifted.block_size(),
63				shifted.shift_offset(),
64				shifted.shift_variant(),
65				projected_eval_point.to_vec(),
66			)?)
67		},
68	)
69}
70
71/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
72#[allow(clippy::too_many_arguments)]
73pub fn process_shifted_sumcheck<U, F>(
74	shifted: &Shifted,
75	meta: ProjectedBivariateMeta,
76	eval_point: &[F],
77	eval: F,
78	witness_index: &mut MultilinearExtensionIndex<U, F>,
79	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
80	projected: MultilinearExtension<PackedType<U, F>>,
81) -> Result<(), Error>
82where
83	PackedType<U, F>: PackedFieldIndexable,
84	U: UnderlierType + PackScalar<F>,
85	F: TowerField,
86{
87	process_projected_bivariate_witness(
88		witness_index,
89		meta,
90		eval_point,
91		|projected_eval_point| {
92			let shift_ind = ShiftIndPartialEval::new(
93				projected_eval_point.len(),
94				shifted.shift_offset(),
95				shifted.shift_variant(),
96				projected_eval_point.to_vec(),
97			)?;
98
99			let shift_ind_mle = shift_ind.multilinear_extension::<PackedType<U, F>>()?;
100			Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
101		},
102		projected,
103	)?;
104	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
105
106	Ok(())
107}
108
109/// Create oracles for the bivariate product of an inner oracle with the tower basis.
110///
111/// Projects to first `log_degree()` vars.
112/// Returns metadata object with oracle identifiers.
113pub fn packed_sumcheck_meta<F: TowerField>(
114	oracles: &mut MultilinearOracleSet<F>,
115	packed: &Packed,
116	eval_point: &[F],
117) -> Result<ProjectedBivariateMeta, Error> {
118	let n_vars = oracles.n_vars(packed.id());
119	let log_degree = packed.log_degree();
120	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
121
122	if log_degree > n_vars {
123		bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
124	}
125
126	// NB. projected_n_vars = 0 because eval_point length is log_degree less than inner n_vars
127	projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
128		Ok(TowerBasis::new(log_degree, binary_tower_level)?)
129	})
130}
131
132pub fn composite_sumcheck_meta<F: TowerField>(
133	oracles: &mut MultilinearOracleSet<F>,
134	eval_point: &[F],
135) -> Result<ProjectedBivariateMeta, Error> {
136	Ok(ProjectedBivariateMeta {
137		multiplier_id: oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?,
138		inner_id: None,
139		projected_id: None,
140		projected_n_vars: eval_point.len(),
141	})
142}
143
144pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
145	meta: ProjectedBivariateMeta,
146	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
147	n_vars: usize,
148	eval: F,
149) {
150	if n_vars > constraint_builders.len() {
151		constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
152	}
153	let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
154	constraint_builders[n_vars - 1].add_sumcheck(meta.oracle_ids(), bivariate_product, eval);
155}
156
157pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
158	meta: ProjectedBivariateMeta,
159	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
160	comp: &CompositeMLE<F>,
161	eval: F,
162) {
163	let n_vars = comp.n_vars();
164	let mut oracle_ids = comp.inner().clone();
165	oracle_ids.push(meta.multiplier_id); // eq
166
167	// Var(comp.n_polys()) corresponds to the eq MLE (meta.multiplier_id)
168	let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithExpr::Var(comp.n_polys());
169	if n_vars > constraint_builders.len() {
170		constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
171	}
172	constraint_builders[n_vars - 1].add_sumcheck(oracle_ids, expr, eval);
173}
174
175/// Creates bivariate witness and adds them to the witness index, and add bivariate sumcheck constraint to the [`ConstraintSetBuilder`]
176#[allow(clippy::too_many_arguments)]
177pub fn process_packed_sumcheck<U, F>(
178	oracles: &MultilinearOracleSet<F>,
179	packed: &Packed,
180	meta: ProjectedBivariateMeta,
181	eval_point: &[F],
182	eval: F,
183	witness_index: &mut MultilinearExtensionIndex<U, F>,
184	constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
185	projected: MultilinearExtension<PackedType<U, F>>,
186) -> Result<(), Error>
187where
188	U: UnderlierType + PackScalar<F>,
189	F: TowerField,
190{
191	let log_degree = packed.log_degree();
192	let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
193
194	process_projected_bivariate_witness(
195		witness_index,
196		meta,
197		eval_point,
198		|_projected_eval_point| {
199			let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
200			let tower_basis_mle = tower_basis.multilinear_extension::<PackedType<U, F>>()?;
201			Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
202		},
203		projected,
204	)?;
205
206	add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
207	Ok(())
208}
209
210#[derive(Clone, Copy)]
211pub struct ProjectedBivariateMeta {
212	/// `Some` if shifted / packed, `None` if composite
213	inner_id: Option<OracleId>,
214	projected_id: Option<OracleId>,
215	multiplier_id: OracleId,
216	projected_n_vars: usize,
217}
218
219impl ProjectedBivariateMeta {
220	pub fn oracle_ids(&self) -> [OracleId; 2] {
221		[
222			self.projected_id.unwrap_or_else(|| {
223				self.inner_id
224					.expect("oracle_ids() is only defined for shifted / packed")
225			}),
226			self.multiplier_id,
227		]
228	}
229}
230
231fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
232	oracles: &mut MultilinearOracleSet<F>,
233	inner_id: OracleId,
234	projected_n_vars: usize,
235	eval_point: &[F],
236	multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
237) -> Result<ProjectedBivariateMeta, Error> {
238	let inner = oracles.oracle(inner_id);
239
240	let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
241		let projected_id = oracles.add_projected(
242			inner_id,
243			eval_point[projected_n_vars..].to_vec(),
244			ProjectionVariant::LastVars,
245		)?;
246
247		(&eval_point[..projected_n_vars], Some(projected_id))
248	} else {
249		(eval_point, None)
250	};
251
252	let projected_n_vars = projected_eval_point.len();
253
254	let multiplier_id =
255		oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
256
257	let meta = ProjectedBivariateMeta {
258		inner_id: Some(inner_id),
259		projected_id,
260		multiplier_id,
261		projected_n_vars,
262	};
263
264	Ok(meta)
265}
266
267fn process_projected_bivariate_witness<'a, U, F>(
268	witness_index: &mut MultilinearExtensionIndex<'a, U, F>,
269	meta: ProjectedBivariateMeta,
270	eval_point: &[F],
271	multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, PackedType<U, F>>, Error>,
272	projected: MultilinearExtension<PackedType<U, F>>,
273) -> Result<(), Error>
274where
275	U: UnderlierType + PackScalar<F>,
276	F: TowerField,
277{
278	let ProjectedBivariateMeta {
279		projected_id,
280		multiplier_id,
281		projected_n_vars,
282		..
283	} = meta;
284
285	let projected_eval_point = if let Some(projected_id) = projected_id {
286		witness_index.update_multilin_poly(vec![(
287			projected_id,
288			MLEDirectAdapter::from(projected).upcast_arc_dyn(),
289		)])?;
290
291		&eval_point[..projected_n_vars]
292	} else {
293		eval_point
294	};
295
296	let m = multiplier_witness_ctr(projected_eval_point)?;
297
298	if !witness_index.has(multiplier_id) {
299		witness_index.update_multilin_poly(vec![(multiplier_id, m)])?;
300	}
301	Ok(())
302}
303
304/// shifted / packed oracle -> compute the projected MLE (i.e. the inner oracle evaluated on the projected eval_point)
305/// composite oracle -> None
306#[allow(clippy::type_complexity)]
307pub fn calculate_projected_mles<U, F, Backend>(
308	metas: &[ProjectedBivariateMeta],
309	memoized_queries: &mut MemoizedQueries<PackedType<U, F>, Backend>,
310	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
311	witness_index: &MultilinearExtensionIndex<U, F>,
312	backend: &Backend,
313) -> Result<Vec<Option<MultilinearExtension<PackedType<U, F>>>>, Error>
314where
315	U: UnderlierType + PackScalar<F>,
316	F: TowerField,
317	Backend: ComputationBackend,
318{
319	let mut queries_to_memoize = Vec::new();
320	for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
321		if meta.inner_id.is_some() {
322			// packed / shifted
323			queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
324		}
325	}
326	memoized_queries.memoize_query_par(&queries_to_memoize, backend)?;
327
328	projected_bivariate_claims
329		.par_iter()
330		.zip(metas)
331		.map(|(claim, meta)| {
332			match meta.inner_id {
333				Some(inner_id) => {
334					{
335						// packed / shifted
336						let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
337						let eval_point = &claim.eval_point[meta.projected_n_vars..];
338						let query = memoized_queries
339							.full_query_readonly(eval_point)
340							.ok_or(Error::MissingQuery)?;
341						Ok(Some(
342							backend
343								.evaluate_partial_high(&inner_multilin, query.to_ref())
344								.map_err(Error::from)?,
345						))
346					}
347				}
348				None => Ok(None), // composite
349			}
350		})
351		.collect::<Result<Vec<Option<_>>, Error>>()
352}
353
354/// Each composite oracle induces a new eq oracle, for which we need to fill the witness
355pub fn fill_eq_witness_for_composites<U, F, Backend>(
356	metas: &[ProjectedBivariateMeta],
357	memoized_queries: &mut MemoizedQueries<PackedType<U, F>, Backend>,
358	projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
359	witness_index: &mut MultilinearExtensionIndex<U, F>,
360	backend: &Backend,
361) -> Result<(), Error>
362where
363	U: UnderlierType + PackScalar<F>,
364	F: TowerField,
365	Backend: ComputationBackend,
366{
367	let dedup_eval_points = metas
368		.iter()
369		.zip(projected_bivariate_claims)
370		.filter(|(meta, _)| meta.inner_id.is_none())
371		.map(|(_, claim)| claim.eval_point.as_ref())
372		.collect::<HashSet<_>>();
373
374	memoized_queries
375		.memoize_query_par(&dedup_eval_points.iter().copied().collect::<Vec<_>>(), backend)?;
376
377	let eq_indicators = dedup_eval_points
378		.into_iter()
379		.map(|eval_point| {
380			let mle = MLEDirectAdapter::from(MultilinearExtension::from_values(
381				memoized_queries
382					.full_query_readonly(eval_point)
383					.expect("computed above")
384					.expansion()
385					.to_vec(),
386			)?)
387			.upcast_arc_dyn();
388			Ok((eval_point, mle))
389		})
390		.collect::<Result<HashMap<_, _>, Error>>()?;
391
392	for (meta, claim) in metas
393		.iter()
394		.zip(projected_bivariate_claims)
395		.filter(|(meta, _)| meta.inner_id.is_none())
396	{
397		let eq_ind = eq_indicators
398			.get(claim.eval_point.as_ref())
399			.expect("was added above");
400
401		witness_index.update_multilin_poly(vec![(meta.multiplier_id, eq_ind.clone())])?;
402	}
403
404	Ok(())
405}
406
407#[allow(clippy::type_complexity)]
408pub struct MemoizedQueries<P: PackedField, Backend: ComputationBackend> {
409	memo: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
410}
411
412impl<P: PackedField, Backend: ComputationBackend> MemoizedQueries<P, Backend> {
413	#[allow(clippy::new_without_default)]
414	pub const fn new() -> Self {
415		Self { memo: Vec::new() }
416	}
417
418	/// Constructs `MemoizedQueries` from a list of eval_points and corresponding MultilinearQueries.
419	/// Assumes that each `eval_point` is given at most once.
420	/// Does not check that the input is valid.
421	#[allow(clippy::type_complexity)]
422	pub const fn new_from_known_queries(
423		data: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
424	) -> Self {
425		Self { memo: data }
426	}
427
428	pub fn full_query(
429		&mut self,
430		eval_point: &[P::Scalar],
431		backend: &Backend,
432	) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
433		if let Some(index) = self
434			.memo
435			.iter()
436			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
437		{
438			let (_, ref query) = &self.memo[index];
439			return Ok(query);
440		}
441
442		let query = backend.multilinear_query(eval_point)?;
443		self.memo.push((eval_point.to_vec(), query));
444
445		let (_, ref query) = self.memo.last().expect("pushed query immediately above");
446		Ok(query)
447	}
448
449	/// Finds a `MultilinearQuery` corresponding to the given `eval_point`.
450	pub fn full_query_readonly(
451		&self,
452		eval_point: &[P::Scalar],
453	) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
454		self.memo
455			.iter()
456			.position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
457			.map(|index| {
458				let (_, ref query) = &self.memo[index];
459				query
460			})
461	}
462
463	pub fn memoize_query_par(
464		&mut self,
465		eval_points: &[&[P::Scalar]],
466		backend: &Backend,
467	) -> Result<(), Error> {
468		let deduplicated_eval_points = eval_points.iter().collect::<HashSet<_>>();
469
470		let new_queries = deduplicated_eval_points
471			.into_par_iter()
472			.filter(|ep| self.full_query_readonly(ep).is_none())
473			.map(|ep| {
474				backend
475					.multilinear_query::<P>(ep)
476					.map(|res| (ep.to_vec(), res))
477					.map_err(Error::from)
478			})
479			.collect::<Result<Vec<_>, Error>>()?;
480
481		self.memo.extend(new_queries);
482
483		Ok(())
484	}
485}
486
487type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
488
489pub fn prove_bivariate_sumchecks_with_switchover<U, F, DomainField, Transcript, Backend>(
490	witness: &MultilinearExtensionIndex<U, F>,
491	constraint_sets: Vec<ConstraintSet<F>>,
492	transcript: &mut ProverTranscript<Transcript>,
493	switchover_fn: impl Fn(usize) -> usize + 'static,
494	domain_factory: impl EvaluationDomainFactory<DomainField>,
495	backend: &Backend,
496) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
497where
498	U: UnderlierType + PackScalar<F> + PackScalar<DomainField>,
499	F: TowerField + ExtensionField<DomainField>,
500	DomainField: Field,
501	Transcript: Challenger,
502	Backend: ComputationBackend,
503{
504	let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
505		constraint_sets,
506		witness,
507		domain_factory,
508		&switchover_fn,
509		backend,
510	)?;
511
512	let sumcheck_output = sumcheck::batch_prove(provers, transcript)?;
513
514	let evalcheck_claims = sumcheck::make_eval_claims(metas, sumcheck_output)?;
515
516	Ok(evalcheck_claims)
517}