1use std::collections::{HashMap, HashSet};
13
14use binius_field::{
15 ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
16};
17use binius_hal::{ComputationBackend, ComputationBackendExt};
18use binius_math::{
19 ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
20 MultilinearExtension, MultilinearQuery,
21};
22use binius_maybe_rayon::prelude::*;
23use binius_utils::bail;
24
25use super::{error::Error, evalcheck::EvalcheckMultilinearClaim, EvalPointOracleIdMap};
26use crate::{
27 fiat_shamir::Challenger,
28 oracle::{
29 CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
30 MultilinearOracleSet, MultilinearPolyVariant, OracleId, Packed, Shifted,
31 },
32 polynomial::MultivariatePoly,
33 protocols::sumcheck::{
34 self,
35 prove::oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
36 Error as SumcheckError,
37 },
38 transcript::ProverTranscript,
39 transparent::{
40 eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
41 },
42 witness::{MultilinearExtensionIndex, MultilinearWitness},
43};
44
45pub fn shifted_sumcheck_meta<F: TowerField>(
49 oracles: &mut MultilinearOracleSet<F>,
50 shifted: &Shifted,
51 eval_point: &[F],
52) -> Result<ProjectedBivariateMeta, Error> {
53 projected_bivariate_meta(
54 oracles,
55 shifted.id(),
56 shifted.block_size(),
57 eval_point,
58 |projected_eval_point| {
59 Ok(ShiftIndPartialEval::new(
60 shifted.block_size(),
61 shifted.shift_offset(),
62 shifted.shift_variant(),
63 projected_eval_point.to_vec(),
64 )?)
65 },
66 )
67}
68
69#[allow(clippy::too_many_arguments)]
71pub fn process_shifted_sumcheck<F, P>(
72 shifted: &Shifted,
73 meta: &ProjectedBivariateMeta,
74 eval_point: &[F],
75 eval: F,
76 witness_index: &mut MultilinearExtensionIndex<P>,
77 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
78 projected: Option<MultilinearExtension<P>>,
79) -> Result<(), Error>
80where
81 P: PackedFieldIndexable<Scalar = F>,
82 F: TowerField,
83{
84 process_projected_bivariate_witness(
85 witness_index,
86 meta,
87 eval_point,
88 |projected_eval_point| {
89 let shift_ind = ShiftIndPartialEval::new(
90 projected_eval_point.len(),
91 shifted.shift_offset(),
92 shifted.shift_variant(),
93 projected_eval_point.to_vec(),
94 )?;
95
96 let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
97 Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
98 },
99 projected,
100 )?;
101 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
102
103 Ok(())
104}
105
106pub fn packed_sumcheck_meta<F: TowerField>(
111 oracles: &mut MultilinearOracleSet<F>,
112 packed: &Packed,
113 eval_point: &[F],
114) -> Result<ProjectedBivariateMeta, Error> {
115 let n_vars = oracles.n_vars(packed.id());
116 let log_degree = packed.log_degree();
117 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
118
119 if log_degree > n_vars {
120 bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
121 }
122
123 projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
125 Ok(TowerBasis::new(log_degree, binary_tower_level)?)
126 })
127}
128
129pub fn composite_sumcheck_meta<F: TowerField>(
130 oracles: &mut MultilinearOracleSet<F>,
131 eval_point: &[F],
132) -> Result<ProjectedBivariateMeta, Error> {
133 Ok(ProjectedBivariateMeta {
134 multiplier_id: oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?,
135 inner_id: None,
136 projected_id: None,
137 projected_n_vars: 0,
139 })
140}
141
142pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
143 meta: &ProjectedBivariateMeta,
144 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
145 n_vars: usize,
146 eval: F,
147) {
148 if n_vars > constraint_builders.len() {
149 constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
150 }
151 let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
152 constraint_builders[n_vars - 1].add_sumcheck(meta.oracle_ids(), bivariate_product, eval);
153}
154
155pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
156 meta: &ProjectedBivariateMeta,
157 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
158 comp: &CompositeMLE<F>,
159 eval: F,
160) {
161 let n_vars = comp.n_vars();
162 let mut oracle_ids = comp.inner().clone();
163 oracle_ids.push(meta.multiplier_id); let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithExpr::Var(comp.n_polys());
167 if n_vars > constraint_builders.len() {
168 constraint_builders.resize_with(n_vars, || ConstraintSetBuilder::new());
169 }
170 constraint_builders[n_vars - 1].add_sumcheck(oracle_ids, expr, eval);
171}
172
173#[allow(clippy::too_many_arguments)]
175pub fn process_packed_sumcheck<F, P>(
176 oracles: &MultilinearOracleSet<F>,
177 packed: &Packed,
178 meta: &ProjectedBivariateMeta,
179 eval_point: &[F],
180 eval: F,
181 witness_index: &mut MultilinearExtensionIndex<P>,
182 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
183 projected: Option<MultilinearExtension<P>>,
184) -> Result<(), Error>
185where
186 P: PackedField<Scalar = F>,
187 F: TowerField,
188{
189 let log_degree = packed.log_degree();
190 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
191
192 process_projected_bivariate_witness(
193 witness_index,
194 meta,
195 eval_point,
196 |_projected_eval_point| {
197 let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
198 let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
199 Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
200 },
201 projected,
202 )?;
203
204 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
205 Ok(())
206}
207
208#[derive(Clone, Copy)]
209pub struct ProjectedBivariateMeta {
210 inner_id: Option<OracleId>,
212 projected_id: Option<OracleId>,
213 multiplier_id: OracleId,
214 projected_n_vars: usize,
215}
216
217impl ProjectedBivariateMeta {
218 pub fn oracle_ids(&self) -> [OracleId; 2] {
219 [
220 self.projected_id.unwrap_or_else(|| {
221 self.inner_id
222 .expect("oracle_ids() is only defined for shifted / packed")
223 }),
224 self.multiplier_id,
225 ]
226 }
227}
228
229fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
230 oracles: &mut MultilinearOracleSet<F>,
231 inner_id: OracleId,
232 projected_n_vars: usize,
233 eval_point: &[F],
234 multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
235) -> Result<ProjectedBivariateMeta, Error> {
236 let inner = oracles.oracle(inner_id);
237
238 let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
239 let projected_id =
240 oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
241
242 (&eval_point[..projected_n_vars], Some(projected_id))
243 } else {
244 (eval_point, None)
245 };
246
247 let projected_n_vars = projected_eval_point.len();
248
249 let multiplier_id =
250 oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
251
252 let meta = ProjectedBivariateMeta {
253 inner_id: Some(inner_id),
254 projected_id,
255 multiplier_id,
256 projected_n_vars,
257 };
258
259 Ok(meta)
260}
261
262fn process_projected_bivariate_witness<'a, F, P>(
263 witness_index: &mut MultilinearExtensionIndex<'a, P>,
264 meta: &ProjectedBivariateMeta,
265 eval_point: &[F],
266 multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
267 projected: Option<MultilinearExtension<P>>,
268) -> Result<(), Error>
269where
270 P: PackedField<Scalar = F>,
271 F: TowerField,
272{
273 let &ProjectedBivariateMeta {
274 projected_id,
275 multiplier_id,
276 projected_n_vars,
277 ..
278 } = meta;
279
280 let projected_eval_point = if let Some(projected_id) = projected_id {
281 witness_index.update_multilin_poly(vec![(
282 projected_id,
283 MLEDirectAdapter::from(
284 projected.expect("projected should exist if projected_id exist"),
285 )
286 .upcast_arc_dyn(),
287 )])?;
288
289 &eval_point[..projected_n_vars]
290 } else {
291 eval_point
292 };
293
294 let m = multiplier_witness_ctr(projected_eval_point)?;
295
296 if !witness_index.has(multiplier_id) {
297 witness_index.update_multilin_poly([(multiplier_id, m)])?;
298 }
299 Ok(())
300}
301
302#[allow(clippy::type_complexity)]
305pub fn calculate_projected_mles<F, P, Backend>(
306 metas: &[ProjectedBivariateMeta],
307 memoized_queries: &mut MemoizedData<P, Backend>,
308 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
309 witness_index: &MultilinearExtensionIndex<P>,
310 backend: &Backend,
311) -> Result<Vec<Option<MultilinearExtension<P>>>, Error>
312where
313 P: PackedField<Scalar = F>,
314 F: TowerField,
315 Backend: ComputationBackend,
316{
317 let mut queries_to_memoize = Vec::new();
318 for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
319 if meta.inner_id.is_some() {
320 queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
322 }
323 }
324 memoized_queries.memoize_query_par(&queries_to_memoize, backend)?;
325
326 projected_bivariate_claims
327 .par_iter()
328 .zip(metas)
329 .map(|(claim, meta)| match (meta.inner_id, meta.projected_id) {
330 (Some(inner_id), Some(_)) => {
331 let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
332 let eval_point = &claim.eval_point[meta.projected_n_vars..];
333 let query = memoized_queries
334 .full_query_readonly(eval_point)
335 .ok_or(Error::MissingQuery)?;
336 Ok(Some(
337 backend
338 .evaluate_partial_high(&inner_multilin, query.to_ref())
339 .map_err(Error::from)?,
340 ))
341 }
342 _ => Ok(None),
343 })
344 .collect::<Result<Vec<Option<_>>, Error>>()
345}
346
347pub fn fill_eq_witness_for_composites<F, P, Backend>(
349 metas: &[ProjectedBivariateMeta],
350 memoized_queries: &mut MemoizedData<P, Backend>,
351 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
352 witness_index: &mut MultilinearExtensionIndex<P>,
353 backend: &Backend,
354) -> Result<(), Error>
355where
356 P: PackedField<Scalar = F>,
357 F: TowerField,
358 Backend: ComputationBackend,
359{
360 let dedup_eval_points = metas
361 .iter()
362 .zip(projected_bivariate_claims)
363 .filter(|(meta, _)| meta.inner_id.is_none())
364 .map(|(_, claim)| claim.eval_point.as_ref())
365 .collect::<HashSet<_>>();
366
367 memoized_queries
368 .memoize_query_par(&dedup_eval_points.iter().copied().collect::<Vec<_>>(), backend)?;
369
370 let eq_indicators = dedup_eval_points
371 .into_iter()
372 .map(|eval_point| {
373 let mle = MLEDirectAdapter::from(MultilinearExtension::from_values(
374 memoized_queries
375 .full_query_readonly(eval_point)
376 .expect("computed above")
377 .expansion()
378 .to_vec(),
379 )?)
380 .upcast_arc_dyn();
381 Ok((eval_point, mle))
382 })
383 .collect::<Result<HashMap<_, _>, Error>>()?;
384
385 for (meta, claim) in metas
386 .iter()
387 .zip(projected_bivariate_claims)
388 .filter(|(meta, _)| meta.inner_id.is_none())
389 {
390 let eq_ind = eq_indicators
391 .get(claim.eval_point.as_ref())
392 .expect("was added above");
393
394 witness_index.update_multilin_poly(vec![(meta.multiplier_id, eq_ind.clone())])?;
395 }
396
397 Ok(())
398}
399
/// Cache shared across claims. It holds expanded multilinear queries keyed by evaluation point,
/// and projected witnesses of committed oracles keyed by oracle id and evaluation suffix.
// The `query` element type is long but only used here, hence the clippy allowance.
#[allow(clippy::type_complexity)]
pub struct MemoizedData<'a, P: PackedField, Backend: ComputationBackend> {
    /// Memoized `(evaluation point, query)` pairs; looked up by linear scan on the point.
    query: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
    /// Projected witnesses, inserted by `memoize_partial_evals` under `(inner oracle id, suffix)`.
    partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
}
405
406impl<'a, P: PackedField, Backend: ComputationBackend> MemoizedData<'a, P, Backend> {
407 #[allow(clippy::new_without_default)]
408 pub fn new() -> Self {
409 Self {
410 query: Vec::new(),
411 partial_evals: EvalPointOracleIdMap::new(),
412 }
413 }
414
415 pub fn full_query(
416 &mut self,
417 eval_point: &[P::Scalar],
418 backend: &Backend,
419 ) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
420 if let Some(index) = self
421 .query
422 .iter()
423 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
424 {
425 let (_, ref query) = &self.query[index];
426 return Ok(query);
427 }
428
429 let query = backend.multilinear_query(eval_point)?;
430 self.query.push((eval_point.to_vec(), query));
431
432 let (_, ref query) = self.query.last().expect("pushed query immediately above");
433 Ok(query)
434 }
435
436 pub fn full_query_readonly(
438 &self,
439 eval_point: &[P::Scalar],
440 ) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
441 self.query
442 .iter()
443 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
444 .map(|index| {
445 let (_, ref query) = &self.query[index];
446 query
447 })
448 }
449
450 pub fn memoize_query_par(
451 &mut self,
452 eval_points: &[&[P::Scalar]],
453 backend: &Backend,
454 ) -> Result<(), binius_hal::Error> {
455 let deduplicated_eval_points = eval_points.iter().collect::<HashSet<_>>();
456
457 let new_queries = deduplicated_eval_points
458 .into_par_iter()
459 .filter(|ep| self.full_query_readonly(ep).is_none())
460 .map(|ep| {
461 backend
462 .multilinear_query::<P>(ep)
463 .map(|res| (ep.to_vec(), res))
464 })
465 .collect::<Result<Vec<_>, binius_hal::Error>>()?;
466
467 self.query.extend(new_queries);
468
469 Ok(())
470 }
471
472 pub fn memoize_partial_evals(
473 &mut self,
474 metas: &[ProjectedBivariateMeta],
475 projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
476 oracles: &mut MultilinearOracleSet<P::Scalar>,
477 witness_index: &MultilinearExtensionIndex<'a, P>,
478 ) where
479 P::Scalar: TowerField,
480 {
481 projected_bivariate_claims
482 .iter()
483 .zip(metas)
484 .filter(|(_, meta)| meta.inner_id.is_some())
485 .for_each(|(claim, meta)| {
486 let inner_id = meta.inner_id.expect("filtered by Some");
487 if matches!(oracles.oracle(inner_id).variant, MultilinearPolyVariant::Committed)
488 && meta.projected_id.is_some()
489 {
490 let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
491
492 let projected_id = meta.projected_id.expect("checked above");
493
494 let projected = witness_index
495 .get_multilin_poly(projected_id)
496 .expect("witness_index contains projected if projected_id exist");
497
498 self.partial_evals.insert(inner_id, eval_point, projected);
499 }
500 });
501 }
502
503 pub fn partial_eval(
504 &self,
505 id: OracleId,
506 eval_point: &[P::Scalar],
507 ) -> Option<&MultilinearWitness<'a, P>> {
508 self.partial_evals.get(id, eval_point)
509 }
510}
511
/// Evalcheck claims returned by [`prove_bivariate_sumchecks_with_switchover`].
type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
513
514pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
515 witness: &MultilinearExtensionIndex<P>,
516 constraint_sets: Vec<ConstraintSet<F>>,
517 transcript: &mut ProverTranscript<Transcript>,
518 switchover_fn: impl Fn(usize) -> usize + 'static,
519 domain_factory: impl EvaluationDomainFactory<DomainField>,
520 backend: &Backend,
521) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
522where
523 P: PackedField<Scalar = F>
524 + PackedExtension<F, PackedSubfield = P>
525 + PackedExtension<DomainField>,
526 F: TowerField + ExtensionField<DomainField>,
527 DomainField: Field,
528 Transcript: Challenger,
529 Backend: ComputationBackend,
530{
531 let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
532 EvaluationOrder::HighToLow,
533 constraint_sets,
534 witness,
535 domain_factory,
536 &switchover_fn,
537 backend,
538 )?;
539
540 let sumcheck_output = sumcheck::batch_prove(provers, transcript)?;
541
542 let evalcheck_claims =
543 sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
544
545 Ok(evalcheck_claims)
546}