1use std::collections::HashSet;
13
14use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
15use binius_hal::ComputationBackend;
16use binius_math::{
17 ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
18 MultilinearExtension, MultilinearQuery,
19};
20use binius_maybe_rayon::prelude::*;
21use binius_utils::bail;
22use bytemuck::zeroed_vec;
23use itertools::izip;
24use tracing::instrument;
25
26use super::{EvalPoint, EvalPointOracleIdMap, error::Error, evalcheck::EvalcheckMultilinearClaim};
27use crate::{
28 fiat_shamir::Challenger,
29 oracle::{
30 CompositeMLE, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId,
31 Packed, Shifted, SizedConstraintSet,
32 },
33 polynomial::MultivariatePoly,
34 protocols::sumcheck::{
35 self, Error as SumcheckError,
36 prove::{
37 front_loaded,
38 oracles::{
39 MLECheckProverWithMeta, SumcheckProversWithMetas,
40 constraint_sets_mlecheck_prover_meta, constraint_sets_sumcheck_provers_metas,
41 },
42 },
43 },
44 transcript::ProverTranscript,
45 transparent::{shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis},
46 witness::{MultilinearExtensionIndex, MultilinearWitness},
47};
48
49pub fn shifted_sumcheck_meta<F: TowerField>(
53 oracles: &mut MultilinearOracleSet<F>,
54 shifted: &Shifted,
55 eval_point: &[F],
56) -> Result<ProjectedBivariateMeta, Error> {
57 projected_bivariate_meta(
58 oracles,
59 shifted.id(),
60 shifted.block_size(),
61 eval_point,
62 |projected_eval_point| {
63 Ok(ShiftIndPartialEval::new(
64 shifted.block_size(),
65 shifted.shift_offset(),
66 shifted.shift_variant(),
67 projected_eval_point.to_vec(),
68 )?)
69 },
70 )
71}
72
73#[allow(clippy::too_many_arguments)]
76pub fn process_shifted_sumcheck<F, P>(
77 shifted: &Shifted,
78 meta: &ProjectedBivariateMeta,
79 eval_point: &[F],
80 eval: F,
81 witness_index: &mut MultilinearExtensionIndex<P>,
82 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
83 partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
84) -> Result<(), Error>
85where
86 P: PackedField<Scalar = F>,
87 F: TowerField,
88{
89 process_projected_bivariate_witness(
90 witness_index,
91 meta,
92 eval_point,
93 |projected_eval_point| {
94 let shift_ind = ShiftIndPartialEval::new(
95 projected_eval_point.len(),
96 shifted.shift_offset(),
97 shifted.shift_variant(),
98 projected_eval_point.to_vec(),
99 )?;
100
101 let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
102 Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
103 },
104 partial_evals,
105 )?;
106 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
107
108 Ok(())
109}
110
111pub fn packed_sumcheck_meta<F: TowerField>(
116 oracles: &mut MultilinearOracleSet<F>,
117 packed: &Packed,
118 eval_point: &[F],
119) -> Result<ProjectedBivariateMeta, Error> {
120 let n_vars = oracles.n_vars(packed.id());
121 let log_degree = packed.log_degree();
122 let binary_tower_level = oracles[packed.id()].binary_tower_level();
123
124 if log_degree > n_vars {
125 bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
126 }
127
128 projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
130 Ok(TowerBasis::new(log_degree, binary_tower_level)?)
131 })
132}
133
134pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
135 meta: &ProjectedBivariateMeta,
136 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
137 n_vars: usize,
138 eval: F,
139) {
140 if n_vars >= constraint_builders.len() {
141 constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
142 }
143 let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
144 constraint_builders[n_vars].add_sumcheck(meta.oracle_ids(), bivariate_product.into(), eval);
145}
146
147pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
148 position: usize,
149 eval_point: &EvalPoint<F>,
150 constraint_builders: &mut Vec<(EvalPoint<F>, ConstraintSetBuilder<F>)>,
151 comp: &CompositeMLE<F>,
152 eval: F,
153) {
154 let oracle_ids = comp.inner().clone();
155
156 if let Some((_, constraint_builder)) = constraint_builders.get_mut(position) {
157 constraint_builder.add_sumcheck(
158 oracle_ids,
159 <_ as CompositionPoly<F>>::expression(comp.c()),
160 eval,
161 );
162 } else {
163 let mut new_builder = ConstraintSetBuilder::new();
164 new_builder.add_sumcheck(oracle_ids, <_ as CompositionPoly<F>>::expression(comp.c()), eval);
165 constraint_builders.push((eval_point.clone(), new_builder));
166 }
167}
168
169#[allow(clippy::too_many_arguments)]
172pub fn process_packed_sumcheck<F, P>(
173 oracles: &MultilinearOracleSet<F>,
174 packed: &Packed,
175 meta: &ProjectedBivariateMeta,
176 eval_point: &[F],
177 eval: F,
178 witness_index: &mut MultilinearExtensionIndex<P>,
179 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
180 partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
181) -> Result<(), Error>
182where
183 P: PackedField<Scalar = F>,
184 F: TowerField,
185{
186 let log_degree = packed.log_degree();
187 let binary_tower_level = oracles[packed.id()].binary_tower_level();
188
189 process_projected_bivariate_witness(
190 witness_index,
191 meta,
192 eval_point,
193 |_projected_eval_point| {
194 let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
195 let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
196 Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
197 },
198 partial_evals,
199 )?;
200
201 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
202 Ok(())
203}
204
205#[derive(Clone, Copy)]
207pub struct ProjectedBivariateMeta {
208 inner_id: OracleId,
209 projected_id: Option<OracleId>,
210 multiplier_id: OracleId,
211 projected_n_vars: usize,
212}
213
214impl ProjectedBivariateMeta {
215 pub fn oracle_ids(&self) -> [OracleId; 2] {
216 [
217 self.projected_id.unwrap_or(self.inner_id),
218 self.multiplier_id,
219 ]
220 }
221}
222
223fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
224 oracles: &mut MultilinearOracleSet<F>,
225 inner_id: OracleId,
226 projected_n_vars: usize,
227 eval_point: &[F],
228 multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
229) -> Result<ProjectedBivariateMeta, Error> {
230 let inner = &oracles[inner_id];
231
232 let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
233 let projected_id =
234 oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
235
236 (&eval_point[..projected_n_vars], Some(projected_id))
237 } else {
238 (eval_point, None)
239 };
240
241 let projected_n_vars = projected_eval_point.len();
242
243 let multiplier_id =
244 oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
245
246 let meta = ProjectedBivariateMeta {
247 inner_id,
248 projected_id,
249 multiplier_id,
250 projected_n_vars,
251 };
252
253 Ok(meta)
254}
255
256fn process_projected_bivariate_witness<'a, F, P>(
257 witness_index: &mut MultilinearExtensionIndex<'a, P>,
258 meta: &ProjectedBivariateMeta,
259 eval_point: &[F],
260 multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
261 partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
262) -> Result<(), Error>
263where
264 P: PackedField<Scalar = F>,
265 F: TowerField,
266{
267 let &ProjectedBivariateMeta {
268 projected_id,
269 multiplier_id,
270 projected_n_vars,
271 inner_id,
272 } = meta;
273
274 let projected_eval_point = if let Some(projected_id) = projected_id {
275 let (prefix, suffix) = eval_point.split_at(projected_n_vars);
276
277 let projected = partial_evals
278 .get(inner_id, suffix)
279 .expect("projected should exist if projected_id exist")
280 .clone();
281
282 witness_index.update_multilin_poly(vec![(
283 projected_id,
284 MLEDirectAdapter::from(projected).upcast_arc_dyn(),
285 )])?;
286 prefix
287 } else {
288 eval_point
289 };
290
291 let m = multiplier_witness_ctr(projected_eval_point)?;
292
293 if !witness_index.has(multiplier_id) {
294 witness_index.update_multilin_poly([(multiplier_id, m)])?;
295 }
296 Ok(())
297}
298
299pub struct OracleIdPartialEval<P: PackedField> {
300 pub id: OracleId,
301 pub suffix: EvalPoint<P::Scalar>,
302 pub partial_eval: MultilinearExtension<P>,
303}
304
305pub fn try_build_partial_eval<F: TowerField, P: PackedField<Scalar = F>>(
306 partial_evals: &EvalPointOracleIdMap<MultilinearExtension<P>, F>,
307 oracles: &MultilinearOracleSet<F>,
308 id: OracleId,
309 suffix: &[F],
310 acc: &mut [P],
311 coeff: P,
312) -> bool {
313 match &oracles[id].variant {
314 crate::oracle::MultilinearPolyVariant::LinearCombination(lc) => {
315 for (poly_id, internal_coeff) in izip!(lc.polys(), lc.coefficients()) {
316 let new_coeff = coeff * P::broadcast(internal_coeff);
317
318 if !try_build_partial_eval(partial_evals, oracles, poly_id, suffix, acc, new_coeff)
319 {
320 return false;
321 }
322 }
323
324 if lc.offset() != F::zero() {
325 let offset = P::broadcast(lc.offset());
326 for acc in acc.iter_mut() {
327 *acc += offset;
328 }
329 }
330 }
331 _ => {
332 let mle = match partial_evals.get(id, suffix) {
333 Some(mle) => mle,
334 None => return false,
335 };
336 for (acc, eval) in acc.iter_mut().zip(mle.evals()) {
337 *acc += if coeff == P::one() {
338 *eval
339 } else {
340 *eval * coeff
341 };
342 }
343 }
344 };
345 true
346}
347
348#[allow(clippy::type_complexity)]
351#[instrument(
352 skip_all,
353 name = "Evalcheck::calculate_projected_mles",
354 level = "debug"
355)]
356pub fn collect_projected_mles<F, P>(
357 metas: &[ProjectedBivariateMeta],
358 memoized_queries: &mut MemoizedData<P>,
359 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
360 oracles: &MultilinearOracleSet<F>,
361 witness_index: &MultilinearExtensionIndex<P>,
362 partial_evals: &mut EvalPointOracleIdMap<MultilinearExtension<P>, F>,
363) -> Result<(), Error>
364where
365 P: PackedField<Scalar = F>,
366 F: TowerField,
367{
368 let mut suffix_oracle_id = HashSet::new();
369
370 for (claim, meta) in projected_bivariate_claims.iter().zip(metas.iter()) {
371 if meta.projected_id.is_some() {
372 let suffix = &claim.eval_point[meta.projected_n_vars..];
373 suffix_oracle_id.insert((suffix, meta.inner_id));
374 }
375 }
376
377 let queries_to_memoize = suffix_oracle_id
378 .iter()
379 .copied()
380 .map(|(suffix, _)| suffix)
381 .collect::<Vec<_>>();
382
383 memoized_queries.memoize_query_par(queries_to_memoize)?;
384
385 let suffix_oracle_id = suffix_oracle_id.into_iter().collect::<Vec<_>>();
386
387 let new_partial_evals = suffix_oracle_id
388 .into_par_iter()
389 .map(|(suffix, inner_id)| {
390 let inner_multilin = witness_index.get_multilin_poly(inner_id)?;
391
392 let query = memoized_queries
393 .full_query_readonly(suffix)
394 .ok_or(Error::MissingQuery)?;
395
396 if partial_evals.get(inner_id, suffix).is_some() {
397 return Ok(None);
398 }
399
400 let n_vars = inner_multilin.n_vars() - suffix.len();
401
402 let mut buffer = zeroed_vec(1 << n_vars.saturating_sub(P::LOG_WIDTH));
403
404 let is_built = try_build_partial_eval(
405 partial_evals,
406 oracles,
407 inner_id,
408 suffix,
409 &mut buffer,
410 P::one(),
411 );
412
413 let partial_eval = if is_built {
414 MultilinearExtension::new(n_vars, buffer).unwrap()
415 } else {
416 inner_multilin
417 .evaluate_partial_high(query.to_ref())
418 .map_err(Error::from)?
419 };
420
421 Ok(Some(OracleIdPartialEval {
422 id: inner_id,
423 suffix: suffix.into(),
424 partial_eval,
425 }))
426 })
427 .collect::<Result<Vec<Option<_>>, Error>>();
428
429 for OracleIdPartialEval {
430 id,
431 suffix,
432 partial_eval,
433 } in new_partial_evals?.into_iter().flatten()
434 {
435 partial_evals.insert(id, suffix, partial_eval)
436 }
437
438 Ok(())
439}
440
441#[allow(clippy::type_complexity)]
444pub struct MemoizedData<'a, P: PackedField> {
445 query: Vec<(Vec<P::Scalar>, MultilinearQuery<P>)>,
446 partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
447}
448
449impl<'a, P: PackedField> MemoizedData<'a, P> {
450 #[allow(clippy::new_without_default)]
451 pub fn new() -> Self {
452 Self {
453 query: Vec::new(),
454 partial_evals: EvalPointOracleIdMap::new(),
455 }
456 }
457
458 pub fn full_query(
459 &mut self,
460 eval_point: &[P::Scalar],
461 ) -> Result<&MultilinearQuery<P>, binius_hal::Error> {
462 if let Some(index) = self
463 .query
464 .iter()
465 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
466 {
467 let (_, query) = &self.query[index];
468 return Ok(query);
469 }
470
471 let query = MultilinearQuery::expand(eval_point);
472 self.query.push((eval_point.to_vec(), query));
473
474 let (_, query) = self.query.last().expect("pushed query immediately above");
475 Ok(query)
476 }
477
478 pub fn full_query_readonly(&self, eval_point: &[P::Scalar]) -> Option<&MultilinearQuery<P>> {
480 self.query
481 .iter()
482 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
483 .map(|index| {
484 let (_, query) = &self.query[index];
485 query
486 })
487 }
488
489 #[instrument(skip_all, name = "Evalcheck::memoize_query_par", level = "debug")]
490 pub fn memoize_query_par<'b>(
491 &mut self,
492 eval_points: impl IntoIterator<Item = &'b [P::Scalar]>,
493 ) -> Result<(), binius_hal::Error> {
494 let deduplicated_eval_points = eval_points.into_iter().collect::<HashSet<_>>();
495
496 let new_queries = deduplicated_eval_points
497 .into_par_iter()
498 .filter(|ep| self.full_query_readonly(ep).is_none())
499 .map(|ep| {
500 let query = MultilinearQuery::<P>::expand(ep);
501 (ep.to_vec(), query)
502 })
503 .collect::<Vec<_>>();
504
505 self.query.extend(new_queries);
506
507 Ok(())
508 }
509
510 pub fn memoize_partial_evals(
511 &mut self,
512 metas: &[ProjectedBivariateMeta],
513 projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
514 oracles: &mut MultilinearOracleSet<P::Scalar>,
515 witness_index: &MultilinearExtensionIndex<'a, P>,
516 ) where
517 P::Scalar: TowerField,
518 {
519 projected_bivariate_claims
520 .iter()
521 .zip(metas)
522 .for_each(|(claim, meta)| {
523 let inner_id = meta.inner_id;
524 if oracles[inner_id].variant.is_committed() && meta.projected_id.is_some() {
525 let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
526
527 let projected_id = meta.projected_id.expect("checked above");
528
529 let projected = witness_index
530 .get_multilin_poly(projected_id)
531 .expect("witness_index contains projected if projected_id exist");
532
533 self.partial_evals.insert(inner_id, eval_point, projected);
534 }
535 });
536 }
537
538 pub fn partial_eval(
539 &self,
540 id: OracleId,
541 eval_point: &[P::Scalar],
542 ) -> Option<&MultilinearWitness<'a, P>> {
543 self.partial_evals.get(id, eval_point)
544 }
545}
546
547type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
548
549pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
550 witness: &MultilinearExtensionIndex<P>,
551 constraint_sets: Vec<SizedConstraintSet<F>>,
552 transcript: &mut ProverTranscript<Transcript>,
553 switchover_fn: impl Fn(usize) -> usize + 'static,
554 domain_factory: impl EvaluationDomainFactory<DomainField>,
555 backend: &Backend,
556) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
557where
558 P: PackedField<Scalar = F>
559 + PackedExtension<F, PackedSubfield = P>
560 + PackedExtension<DomainField>,
561 F: TowerField + ExtensionField<DomainField>,
562 DomainField: Field,
563 Transcript: Challenger,
564 Backend: ComputationBackend,
565{
566 let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
567 EvaluationOrder::HighToLow,
568 constraint_sets,
569 witness,
570 domain_factory,
571 &switchover_fn,
572 backend,
573 )?;
574
575 let batch_prover = front_loaded::BatchProver::new(provers, transcript)?;
576
577 let mut sumcheck_output = batch_prover.run(transcript)?;
578
579 sumcheck_output.challenges.reverse();
581
582 let evalcheck_claims =
583 sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
584
585 Ok(evalcheck_claims)
586}
587
588#[allow(clippy::too_many_arguments)]
589pub fn prove_mlecheck_with_switchover<'a, F, P, DomainField, Transcript, Backend>(
590 witness: &MultilinearExtensionIndex<P>,
591 constraint_set: SizedConstraintSet<F>,
592 eq_ind_challenges: EvalPoint<F>,
593 memoized_data: &mut MemoizedData<'a, P>,
594 transcript: &mut ProverTranscript<Transcript>,
595 switchover_fn: impl Fn(usize) -> usize + 'static,
596 domain_factory: impl EvaluationDomainFactory<DomainField>,
597 backend: &Backend,
598) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
599where
600 P: PackedField<Scalar = F>
601 + PackedExtension<F, PackedSubfield = P>
602 + PackedExtension<DomainField>,
603 F: TowerField + ExtensionField<DomainField>,
604 DomainField: Field,
605 Transcript: Challenger,
606 Backend: ComputationBackend,
607{
608 let MLECheckProverWithMeta { prover, meta } = constraint_sets_mlecheck_prover_meta(
609 EvaluationOrder::HighToLow,
610 constraint_set,
611 eq_ind_challenges,
612 memoized_data,
613 witness,
614 domain_factory,
615 &switchover_fn,
616 backend,
617 )?;
618
619 let batch_prover = front_loaded::BatchProver::new(vec![prover], transcript)?;
620
621 let mut sumcheck_output = batch_prover.run(transcript)?;
622
623 sumcheck_output.challenges.reverse();
625
626 sumcheck_output.multilinear_evals[0].pop();
628
629 let evalcheck_claims =
630 sumcheck::make_eval_claims(EvaluationOrder::HighToLow, vec![meta], sumcheck_output)?;
631
632 Ok(evalcheck_claims)
633}