1use std::collections::{HashMap, HashSet};
13
14use binius_field::{
15 as_packed_field::{PackScalar, PackedType},
16 underlier::UnderlierType,
17 ExtensionField, Field, PackedField, PackedFieldIndexable, TowerField,
18};
19use binius_hal::{ComputationBackend, ComputationBackendExt};
20use binius_math::{
21 ArithExpr, CompositionPoly, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension,
22 MultilinearQuery,
23};
24use binius_maybe_rayon::prelude::*;
25use binius_utils::bail;
26
27use super::{error::Error, evalcheck::EvalcheckMultilinearClaim};
28use crate::{
29 fiat_shamir::Challenger,
30 oracle::{
31 CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
32 MultilinearOracleSet, OracleId, Packed, ProjectionVariant, Shifted,
33 },
34 polynomial::MultivariatePoly,
35 protocols::sumcheck::{
36 self,
37 prove::oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
38 Error as SumcheckError,
39 },
40 transcript::ProverTranscript,
41 transparent::{
42 eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
43 },
44 witness::{MultilinearExtensionIndex, MultilinearWitness},
45};
46
47pub fn shifted_sumcheck_meta<F: TowerField>(
51 oracles: &mut MultilinearOracleSet<F>,
52 shifted: &Shifted,
53 eval_point: &[F],
54) -> Result<ProjectedBivariateMeta, Error> {
55 projected_bivariate_meta(
56 oracles,
57 shifted.id(),
58 shifted.block_size(),
59 eval_point,
60 |projected_eval_point| {
61 Ok(ShiftIndPartialEval::new(
62 shifted.block_size(),
63 shifted.shift_offset(),
64 shifted.shift_variant(),
65 projected_eval_point.to_vec(),
66 )?)
67 },
68 )
69}
70
71#[allow(clippy::too_many_arguments)]
73pub fn process_shifted_sumcheck<U, F>(
74 shifted: &Shifted,
75 meta: ProjectedBivariateMeta,
76 eval_point: &[F],
77 eval: F,
78 witness_index: &mut MultilinearExtensionIndex<U, F>,
79 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
80 projected: MultilinearExtension<PackedType<U, F>>,
81) -> Result<(), Error>
82where
83 PackedType<U, F>: PackedFieldIndexable,
84 U: UnderlierType + PackScalar<F>,
85 F: TowerField,
86{
87 process_projected_bivariate_witness(
88 witness_index,
89 meta,
90 eval_point,
91 |projected_eval_point| {
92 let shift_ind = ShiftIndPartialEval::new(
93 projected_eval_point.len(),
94 shifted.shift_offset(),
95 shifted.shift_variant(),
96 projected_eval_point.to_vec(),
97 )?;
98
99 let shift_ind_mle = shift_ind.multilinear_extension::<PackedType<U, F>>()?;
100 Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
101 },
102 projected,
103 )?;
104 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
105
106 Ok(())
107}
108
109pub fn packed_sumcheck_meta<F: TowerField>(
114 oracles: &mut MultilinearOracleSet<F>,
115 packed: &Packed,
116 eval_point: &[F],
117) -> Result<ProjectedBivariateMeta, Error> {
118 let n_vars = oracles.n_vars(packed.id());
119 let log_degree = packed.log_degree();
120 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
121
122 if log_degree > n_vars {
123 bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
124 }
125
126 projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
128 Ok(TowerBasis::new(log_degree, binary_tower_level)?)
129 })
130}
131
132pub fn composite_sumcheck_meta<F: TowerField>(
133 oracles: &mut MultilinearOracleSet<F>,
134 eval_point: &[F],
135) -> Result<ProjectedBivariateMeta, Error> {
136 Ok(ProjectedBivariateMeta {
137 multiplier_id: oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?,
138 inner_id: None,
139 projected_id: None,
140 projected_n_vars: eval_point.len(),
141 })
142}
143
144pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
145 meta: ProjectedBivariateMeta,
146 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
147 n_vars: usize,
148 eval: F,
149) {
150 if n_vars > constraint_builders.len() {
151 constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
152 }
153 let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
154 constraint_builders[n_vars - 1].add_sumcheck(meta.oracle_ids(), bivariate_product, eval);
155}
156
157pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
158 meta: ProjectedBivariateMeta,
159 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
160 comp: &CompositeMLE<F>,
161 eval: F,
162) {
163 let n_vars = comp.n_vars();
164 let mut oracle_ids = comp.inner().clone();
165 oracle_ids.push(meta.multiplier_id); let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithExpr::Var(comp.n_polys());
169 if n_vars > constraint_builders.len() {
170 constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
171 }
172 constraint_builders[n_vars - 1].add_sumcheck(oracle_ids, expr, eval);
173}
174
175#[allow(clippy::too_many_arguments)]
177pub fn process_packed_sumcheck<U, F>(
178 oracles: &MultilinearOracleSet<F>,
179 packed: &Packed,
180 meta: ProjectedBivariateMeta,
181 eval_point: &[F],
182 eval: F,
183 witness_index: &mut MultilinearExtensionIndex<U, F>,
184 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
185 projected: MultilinearExtension<PackedType<U, F>>,
186) -> Result<(), Error>
187where
188 U: UnderlierType + PackScalar<F>,
189 F: TowerField,
190{
191 let log_degree = packed.log_degree();
192 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
193
194 process_projected_bivariate_witness(
195 witness_index,
196 meta,
197 eval_point,
198 |_projected_eval_point| {
199 let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
200 let tower_basis_mle = tower_basis.multilinear_extension::<PackedType<U, F>>()?;
201 Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
202 },
203 projected,
204 )?;
205
206 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
207 Ok(())
208}
209
210#[derive(Clone, Copy)]
211pub struct ProjectedBivariateMeta {
212 inner_id: Option<OracleId>,
214 projected_id: Option<OracleId>,
215 multiplier_id: OracleId,
216 projected_n_vars: usize,
217}
218
219impl ProjectedBivariateMeta {
220 pub fn oracle_ids(&self) -> [OracleId; 2] {
221 [
222 self.projected_id.unwrap_or_else(|| {
223 self.inner_id
224 .expect("oracle_ids() is only defined for shifted / packed")
225 }),
226 self.multiplier_id,
227 ]
228 }
229}
230
231fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
232 oracles: &mut MultilinearOracleSet<F>,
233 inner_id: OracleId,
234 projected_n_vars: usize,
235 eval_point: &[F],
236 multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
237) -> Result<ProjectedBivariateMeta, Error> {
238 let inner = oracles.oracle(inner_id);
239
240 let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
241 let projected_id = oracles.add_projected(
242 inner_id,
243 eval_point[projected_n_vars..].to_vec(),
244 ProjectionVariant::LastVars,
245 )?;
246
247 (&eval_point[..projected_n_vars], Some(projected_id))
248 } else {
249 (eval_point, None)
250 };
251
252 let projected_n_vars = projected_eval_point.len();
253
254 let multiplier_id =
255 oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
256
257 let meta = ProjectedBivariateMeta {
258 inner_id: Some(inner_id),
259 projected_id,
260 multiplier_id,
261 projected_n_vars,
262 };
263
264 Ok(meta)
265}
266
267fn process_projected_bivariate_witness<'a, U, F>(
268 witness_index: &mut MultilinearExtensionIndex<'a, U, F>,
269 meta: ProjectedBivariateMeta,
270 eval_point: &[F],
271 multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, PackedType<U, F>>, Error>,
272 projected: MultilinearExtension<PackedType<U, F>>,
273) -> Result<(), Error>
274where
275 U: UnderlierType + PackScalar<F>,
276 F: TowerField,
277{
278 let ProjectedBivariateMeta {
279 projected_id,
280 multiplier_id,
281 projected_n_vars,
282 ..
283 } = meta;
284
285 let projected_eval_point = if let Some(projected_id) = projected_id {
286 witness_index.update_multilin_poly(vec![(
287 projected_id,
288 MLEDirectAdapter::from(projected).upcast_arc_dyn(),
289 )])?;
290
291 &eval_point[..projected_n_vars]
292 } else {
293 eval_point
294 };
295
296 let m = multiplier_witness_ctr(projected_eval_point)?;
297
298 if !witness_index.has(multiplier_id) {
299 witness_index.update_multilin_poly(vec![(multiplier_id, m)])?;
300 }
301 Ok(())
302}
303
304#[allow(clippy::type_complexity)]
307pub fn calculate_projected_mles<U, F, Backend>(
308 metas: &[ProjectedBivariateMeta],
309 memoized_queries: &mut MemoizedQueries<PackedType<U, F>, Backend>,
310 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
311 witness_index: &MultilinearExtensionIndex<U, F>,
312 backend: &Backend,
313) -> Result<Vec<Option<MultilinearExtension<PackedType<U, F>>>>, Error>
314where
315 U: UnderlierType + PackScalar<F>,
316 F: TowerField,
317 Backend: ComputationBackend,
318{
319 let mut queries_to_memoize = Vec::new();
320 for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
321 if meta.inner_id.is_some() {
322 queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
324 }
325 }
326 memoized_queries.memoize_query_par(&queries_to_memoize, backend)?;
327
328 projected_bivariate_claims
329 .par_iter()
330 .zip(metas)
331 .map(|(claim, meta)| {
332 match meta.inner_id {
333 Some(inner_id) => {
334 {
335 let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
337 let eval_point = &claim.eval_point[meta.projected_n_vars..];
338 let query = memoized_queries
339 .full_query_readonly(eval_point)
340 .ok_or(Error::MissingQuery)?;
341 Ok(Some(
342 backend
343 .evaluate_partial_high(&inner_multilin, query.to_ref())
344 .map_err(Error::from)?,
345 ))
346 }
347 }
348 None => Ok(None), }
350 })
351 .collect::<Result<Vec<Option<_>>, Error>>()
352}
353
354pub fn fill_eq_witness_for_composites<U, F, Backend>(
356 metas: &[ProjectedBivariateMeta],
357 memoized_queries: &mut MemoizedQueries<PackedType<U, F>, Backend>,
358 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
359 witness_index: &mut MultilinearExtensionIndex<U, F>,
360 backend: &Backend,
361) -> Result<(), Error>
362where
363 U: UnderlierType + PackScalar<F>,
364 F: TowerField,
365 Backend: ComputationBackend,
366{
367 let dedup_eval_points = metas
368 .iter()
369 .zip(projected_bivariate_claims)
370 .filter(|(meta, _)| meta.inner_id.is_none())
371 .map(|(_, claim)| claim.eval_point.as_ref())
372 .collect::<HashSet<_>>();
373
374 memoized_queries
375 .memoize_query_par(&dedup_eval_points.iter().copied().collect::<Vec<_>>(), backend)?;
376
377 let eq_indicators = dedup_eval_points
378 .into_iter()
379 .map(|eval_point| {
380 let mle = MLEDirectAdapter::from(MultilinearExtension::from_values(
381 memoized_queries
382 .full_query_readonly(eval_point)
383 .expect("computed above")
384 .expansion()
385 .to_vec(),
386 )?)
387 .upcast_arc_dyn();
388 Ok((eval_point, mle))
389 })
390 .collect::<Result<HashMap<_, _>, Error>>()?;
391
392 for (meta, claim) in metas
393 .iter()
394 .zip(projected_bivariate_claims)
395 .filter(|(meta, _)| meta.inner_id.is_none())
396 {
397 let eq_ind = eq_indicators
398 .get(claim.eval_point.as_ref())
399 .expect("was added above");
400
401 witness_index.update_multilin_poly(vec![(meta.multiplier_id, eq_ind.clone())])?;
402 }
403
404 Ok(())
405}
406
/// Cache of multilinear queries, keyed by the evaluation point each query was built from.
///
/// Entries are kept in a `Vec`, and lookups in the impl below scan it linearly, comparing
/// evaluation points for equality.
#[allow(clippy::type_complexity)]
pub struct MemoizedQueries<P: PackedField, Backend: ComputationBackend> {
    // `(evaluation point, query built for that point)` pairs.
    memo: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
}
411
412impl<P: PackedField, Backend: ComputationBackend> MemoizedQueries<P, Backend> {
413 #[allow(clippy::new_without_default)]
414 pub const fn new() -> Self {
415 Self { memo: Vec::new() }
416 }
417
418 #[allow(clippy::type_complexity)]
422 pub const fn new_from_known_queries(
423 data: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
424 ) -> Self {
425 Self { memo: data }
426 }
427
428 pub fn full_query(
429 &mut self,
430 eval_point: &[P::Scalar],
431 backend: &Backend,
432 ) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
433 if let Some(index) = self
434 .memo
435 .iter()
436 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
437 {
438 let (_, ref query) = &self.memo[index];
439 return Ok(query);
440 }
441
442 let query = backend.multilinear_query(eval_point)?;
443 self.memo.push((eval_point.to_vec(), query));
444
445 let (_, ref query) = self.memo.last().expect("pushed query immediately above");
446 Ok(query)
447 }
448
449 pub fn full_query_readonly(
451 &self,
452 eval_point: &[P::Scalar],
453 ) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
454 self.memo
455 .iter()
456 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
457 .map(|index| {
458 let (_, ref query) = &self.memo[index];
459 query
460 })
461 }
462
463 pub fn memoize_query_par(
464 &mut self,
465 eval_points: &[&[P::Scalar]],
466 backend: &Backend,
467 ) -> Result<(), Error> {
468 let deduplicated_eval_points = eval_points.iter().collect::<HashSet<_>>();
469
470 let new_queries = deduplicated_eval_points
471 .into_par_iter()
472 .filter(|ep| self.full_query_readonly(ep).is_none())
473 .map(|ep| {
474 backend
475 .multilinear_query::<P>(ep)
476 .map(|res| (ep.to_vec(), res))
477 .map_err(Error::from)
478 })
479 .collect::<Result<Vec<_>, Error>>()?;
480
481 self.memo.extend(new_queries);
482
483 Ok(())
484 }
485}
486
/// Evalcheck claims returned by [`prove_bivariate_sumchecks_with_switchover`], as produced by
/// `sumcheck::make_eval_claims` from the batched sumcheck output.
type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
488
489pub fn prove_bivariate_sumchecks_with_switchover<U, F, DomainField, Transcript, Backend>(
490 witness: &MultilinearExtensionIndex<U, F>,
491 constraint_sets: Vec<ConstraintSet<F>>,
492 transcript: &mut ProverTranscript<Transcript>,
493 switchover_fn: impl Fn(usize) -> usize + 'static,
494 domain_factory: impl EvaluationDomainFactory<DomainField>,
495 backend: &Backend,
496) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
497where
498 U: UnderlierType + PackScalar<F> + PackScalar<DomainField>,
499 F: TowerField + ExtensionField<DomainField>,
500 DomainField: Field,
501 Transcript: Challenger,
502 Backend: ComputationBackend,
503{
504 let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
505 constraint_sets,
506 witness,
507 domain_factory,
508 &switchover_fn,
509 backend,
510 )?;
511
512 let sumcheck_output = sumcheck::batch_prove(provers, transcript)?;
513
514 let evalcheck_claims = sumcheck::make_eval_claims(metas, sumcheck_output)?;
515
516 Ok(evalcheck_claims)
517}