1use std::{
13 collections::{HashMap, HashSet},
14 iter,
15};
16
17use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
18use binius_hal::{ComputationBackend, ComputationBackendExt};
19use binius_math::{
20 ArithCircuit, ArithExpr, CompositionPoly, EvaluationDomainFactory, EvaluationOrder,
21 MLEDirectAdapter, MultilinearExtension, MultilinearQuery,
22};
23use binius_maybe_rayon::prelude::*;
24use binius_utils::bail;
25use tracing::instrument;
26
27use super::{error::Error, evalcheck::EvalcheckMultilinearClaim, EvalPointOracleIdMap};
28use crate::{
29 fiat_shamir::Challenger,
30 oracle::{
31 CompositeMLE, ConstraintSet, ConstraintSetBuilder, Error as OracleError,
32 MultilinearOracleSet, MultilinearPolyVariant, OracleId, Packed, Shifted,
33 },
34 polynomial::MultivariatePoly,
35 protocols::sumcheck::{
36 self,
37 prove::{
38 front_loaded,
39 oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas},
40 },
41 Error as SumcheckError,
42 },
43 transcript::ProverTranscript,
44 transparent::{
45 eq_ind::EqIndPartialEval, shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis,
46 },
47 witness::{MultilinearExtensionIndex, MultilinearWitness},
48};
49
50pub fn shifted_sumcheck_meta<F: TowerField>(
54 oracles: &mut MultilinearOracleSet<F>,
55 shifted: &Shifted,
56 eval_point: &[F],
57) -> Result<ProjectedBivariateMeta, Error> {
58 projected_bivariate_meta(
59 oracles,
60 shifted.id(),
61 shifted.block_size(),
62 eval_point,
63 |projected_eval_point| {
64 Ok(ShiftIndPartialEval::new(
65 shifted.block_size(),
66 shifted.shift_offset(),
67 shifted.shift_variant(),
68 projected_eval_point.to_vec(),
69 )?)
70 },
71 )
72}
73
74#[allow(clippy::too_many_arguments)]
76pub fn process_shifted_sumcheck<F, P>(
77 shifted: &Shifted,
78 meta: &ProjectedBivariateMeta,
79 eval_point: &[F],
80 eval: F,
81 witness_index: &mut MultilinearExtensionIndex<P>,
82 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
83 projected: Option<MultilinearExtension<P>>,
84) -> Result<(), Error>
85where
86 P: PackedField<Scalar = F>,
87 F: TowerField,
88{
89 process_projected_bivariate_witness(
90 witness_index,
91 meta,
92 eval_point,
93 |projected_eval_point| {
94 let shift_ind = ShiftIndPartialEval::new(
95 projected_eval_point.len(),
96 shifted.shift_offset(),
97 shifted.shift_variant(),
98 projected_eval_point.to_vec(),
99 )?;
100
101 let shift_ind_mle = shift_ind.multilinear_extension::<P>()?;
102 Ok(MLEDirectAdapter::from(shift_ind_mle).upcast_arc_dyn())
103 },
104 projected,
105 )?;
106 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, shifted.block_size(), eval);
107
108 Ok(())
109}
110
111pub fn packed_sumcheck_meta<F: TowerField>(
116 oracles: &mut MultilinearOracleSet<F>,
117 packed: &Packed,
118 eval_point: &[F],
119) -> Result<ProjectedBivariateMeta, Error> {
120 let n_vars = oracles.n_vars(packed.id());
121 let log_degree = packed.log_degree();
122 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
123
124 if log_degree > n_vars {
125 bail!(OracleError::NotEnoughVarsForPacking { n_vars, log_degree });
126 }
127
128 projected_bivariate_meta(oracles, packed.id(), 0, eval_point, |_| {
130 Ok(TowerBasis::new(log_degree, binary_tower_level)?)
131 })
132}
133
134pub fn composite_mlecheck_meta<F: TowerField>(
135 oracles: &mut MultilinearOracleSet<F>,
136 eval_point: &[F],
137) -> Result<CompositeMLECheckMeta, Error> {
138 let eq_ind_id = oracles.add_transparent(EqIndPartialEval::new(eval_point.to_vec()))?;
139 Ok(CompositeMLECheckMeta { eq_ind_id })
140}
141
142pub fn add_bivariate_sumcheck_to_constraints<F: TowerField>(
143 meta: &ProjectedBivariateMeta,
144 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
145 n_vars: usize,
146 eval: F,
147) {
148 if n_vars >= constraint_builders.len() {
149 constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
150 }
151 let bivariate_product = ArithExpr::Var(0) * ArithExpr::Var(1);
152 constraint_builders[n_vars].add_sumcheck(meta.oracle_ids(), bivariate_product.into(), eval);
153}
154
155pub fn add_composite_sumcheck_to_constraints<F: TowerField>(
156 meta: CompositeMLECheckMeta,
157 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
158 comp: &CompositeMLE<F>,
159 eval: F,
160) {
161 let n_vars = comp.n_vars();
162 let mut oracle_ids = comp.inner().clone();
163 oracle_ids.push(meta.eq_ind_id); let expr = <_ as CompositionPoly<F>>::expression(comp.c()) * ArithCircuit::var(comp.n_polys());
167 if n_vars >= constraint_builders.len() {
168 constraint_builders.resize_with(n_vars + 1, || ConstraintSetBuilder::new());
169 }
170 constraint_builders[n_vars].add_sumcheck(oracle_ids, expr, eval);
171}
172
173#[allow(clippy::too_many_arguments)]
175pub fn process_packed_sumcheck<F, P>(
176 oracles: &MultilinearOracleSet<F>,
177 packed: &Packed,
178 meta: &ProjectedBivariateMeta,
179 eval_point: &[F],
180 eval: F,
181 witness_index: &mut MultilinearExtensionIndex<P>,
182 constraint_builders: &mut Vec<ConstraintSetBuilder<F>>,
183 projected: Option<MultilinearExtension<P>>,
184) -> Result<(), Error>
185where
186 P: PackedField<Scalar = F>,
187 F: TowerField,
188{
189 let log_degree = packed.log_degree();
190 let binary_tower_level = oracles.oracle(packed.id()).binary_tower_level();
191
192 process_projected_bivariate_witness(
193 witness_index,
194 meta,
195 eval_point,
196 |_projected_eval_point| {
197 let tower_basis = TowerBasis::new(log_degree, binary_tower_level)?;
198 let tower_basis_mle = tower_basis.multilinear_extension::<P>()?;
199 Ok(MLEDirectAdapter::from(tower_basis_mle).upcast_arc_dyn())
200 },
201 projected,
202 )?;
203
204 add_bivariate_sumcheck_to_constraints(meta, constraint_builders, packed.log_degree(), eval);
205 Ok(())
206}
207
/// Oracle bookkeeping produced by `composite_mlecheck_meta` for one composite evaluation claim.
#[derive(Debug, Clone, Copy)]
pub struct CompositeMLECheckMeta {
	/// Id of the [`EqIndPartialEval`] transparent oracle registered at the claim's evaluation
	/// point.
	pub eq_ind_id: OracleId,
}
212
213#[derive(Clone, Copy)]
215pub struct ProjectedBivariateMeta {
216 inner_id: OracleId,
217 projected_id: Option<OracleId>,
218 multiplier_id: OracleId,
219 projected_n_vars: usize,
220}
221
222impl ProjectedBivariateMeta {
223 pub fn oracle_ids(&self) -> [OracleId; 2] {
224 [
225 self.projected_id.unwrap_or(self.inner_id),
226 self.multiplier_id,
227 ]
228 }
229}
230
231fn projected_bivariate_meta<F: TowerField, T: MultivariatePoly<F> + 'static>(
232 oracles: &mut MultilinearOracleSet<F>,
233 inner_id: OracleId,
234 projected_n_vars: usize,
235 eval_point: &[F],
236 multiplier_transparent_ctr: impl FnOnce(&[F]) -> Result<T, Error>,
237) -> Result<ProjectedBivariateMeta, Error> {
238 let inner = oracles.oracle(inner_id);
239
240 let (projected_eval_point, projected_id) = if projected_n_vars < inner.n_vars() {
241 let projected_id =
242 oracles.add_projected_last_vars(inner_id, eval_point[projected_n_vars..].to_vec())?;
243
244 (&eval_point[..projected_n_vars], Some(projected_id))
245 } else {
246 (eval_point, None)
247 };
248
249 let projected_n_vars = projected_eval_point.len();
250
251 let multiplier_id =
252 oracles.add_transparent(multiplier_transparent_ctr(projected_eval_point)?)?;
253
254 let meta = ProjectedBivariateMeta {
255 inner_id,
256 projected_id,
257 multiplier_id,
258 projected_n_vars,
259 };
260
261 Ok(meta)
262}
263
264fn process_projected_bivariate_witness<'a, F, P>(
265 witness_index: &mut MultilinearExtensionIndex<'a, P>,
266 meta: &ProjectedBivariateMeta,
267 eval_point: &[F],
268 multiplier_witness_ctr: impl FnOnce(&[F]) -> Result<MultilinearWitness<'a, P>, Error>,
269 projected: Option<MultilinearExtension<P>>,
270) -> Result<(), Error>
271where
272 P: PackedField<Scalar = F>,
273 F: TowerField,
274{
275 let &ProjectedBivariateMeta {
276 projected_id,
277 multiplier_id,
278 projected_n_vars,
279 ..
280 } = meta;
281
282 let projected_eval_point = if let Some(projected_id) = projected_id {
283 witness_index.update_multilin_poly(vec![(
284 projected_id,
285 MLEDirectAdapter::from(
286 projected.expect("projected should exist if projected_id exist"),
287 )
288 .upcast_arc_dyn(),
289 )])?;
290
291 &eval_point[..projected_n_vars]
292 } else {
293 eval_point
294 };
295
296 let m = multiplier_witness_ctr(projected_eval_point)?;
297
298 if !witness_index.has(multiplier_id) {
299 witness_index.update_multilin_poly([(multiplier_id, m)])?;
300 }
301 Ok(())
302}
303
304#[allow(clippy::type_complexity)]
307#[instrument(
308 skip_all,
309 name = "Evalcheck::calculate_projected_mles",
310 level = "debug"
311)]
312pub fn calculate_projected_mles<F, P, Backend>(
313 metas: &[ProjectedBivariateMeta],
314 memoized_queries: &mut MemoizedData<P, Backend>,
315 projected_bivariate_claims: &[EvalcheckMultilinearClaim<F>],
316 witness_index: &MultilinearExtensionIndex<P>,
317 backend: &Backend,
318) -> Result<Vec<Option<MultilinearExtension<P>>>, Error>
319where
320 P: PackedField<Scalar = F>,
321 F: TowerField,
322 Backend: ComputationBackend,
323{
324 let mut queries_to_memoize = Vec::new();
325 for (meta, claim) in metas.iter().zip(projected_bivariate_claims) {
326 queries_to_memoize.push(&claim.eval_point[meta.projected_n_vars..]);
327 }
328 memoized_queries.memoize_query_par(queries_to_memoize, backend)?;
329
330 projected_bivariate_claims
331 .par_iter()
332 .zip(metas)
333 .map(|(claim, meta)| match meta.projected_id {
334 Some(_) => {
335 let inner_multilin = witness_index.get_multilin_poly(meta.inner_id)?;
336 let eval_point = &claim.eval_point[meta.projected_n_vars..];
337 let query = memoized_queries
338 .full_query_readonly(eval_point)
339 .ok_or(Error::MissingQuery)?;
340 Ok(Some(
341 backend
342 .evaluate_partial_high(&inner_multilin, query.to_ref())
343 .map_err(Error::from)?,
344 ))
345 }
346 _ => Ok(None),
347 })
348 .collect::<Result<Vec<Option<_>>, Error>>()
349}
350
351pub fn fill_eq_witness_for_composites<F, P, Backend>(
353 metas: &[CompositeMLECheckMeta],
354 memoized_queries: &mut MemoizedData<P, Backend>,
355 composite_mle_claims: &[EvalcheckMultilinearClaim<F>],
356 witness_index: &mut MultilinearExtensionIndex<P>,
357 backend: &Backend,
358) -> Result<(), Error>
359where
360 P: PackedField<Scalar = F>,
361 F: TowerField,
362 Backend: ComputationBackend,
363{
364 let dedup_eval_points = composite_mle_claims
365 .iter()
366 .map(|claim| claim.eval_point.as_ref())
367 .collect::<HashSet<_>>();
368
369 memoized_queries.memoize_query_par(dedup_eval_points.iter().copied(), backend)?;
370
371 let eq_indicators = dedup_eval_points
372 .into_iter()
373 .map(|eval_point| {
374 let mle = MLEDirectAdapter::from(MultilinearExtension::new(
375 eval_point.len(),
376 memoized_queries
377 .full_query_readonly(eval_point)
378 .expect("computed above")
379 .expansion()
380 .to_vec(),
381 )?)
382 .upcast_arc_dyn();
383 Ok((eval_point, mle))
384 })
385 .collect::<Result<HashMap<_, _>, Error>>()?;
386
387 for (claim, meta) in iter::zip(composite_mle_claims, metas) {
388 let eq_ind = eq_indicators
389 .get(claim.eval_point.as_ref())
390 .expect("was added above");
391
392 witness_index.update_multilin_poly(vec![(meta.eq_ind_id, eq_ind.clone())])?;
393 }
394
395 Ok(())
396}
397
/// Prover-side cache of multilinear queries and oracle partial evaluations.
#[allow(clippy::type_complexity)]
pub struct MemoizedData<'a, P: PackedField, Backend: ComputationBackend> {
	/// Memoized `(eval_point, query)` pairs. Lookups compare the stored point against the
	/// requested one by linear scan.
	query: Vec<(Vec<P::Scalar>, MultilinearQuery<P, Backend::Vec<P>>)>,
	/// Partial-evaluation witnesses keyed by oracle id and evaluation point.
	/// `memoize_partial_evals` fills this with projected witnesses of committed oracles.
	partial_evals: EvalPointOracleIdMap<MultilinearWitness<'a, P>, P::Scalar>,
}
404
405impl<'a, P: PackedField, Backend: ComputationBackend> MemoizedData<'a, P, Backend> {
406 #[allow(clippy::new_without_default)]
407 pub fn new() -> Self {
408 Self {
409 query: Vec::new(),
410 partial_evals: EvalPointOracleIdMap::new(),
411 }
412 }
413
414 pub fn full_query(
415 &mut self,
416 eval_point: &[P::Scalar],
417 backend: &Backend,
418 ) -> Result<&MultilinearQuery<P, Backend::Vec<P>>, Error> {
419 if let Some(index) = self
420 .query
421 .iter()
422 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
423 {
424 let (_, ref query) = &self.query[index];
425 return Ok(query);
426 }
427
428 let query = backend.multilinear_query(eval_point)?;
429 self.query.push((eval_point.to_vec(), query));
430
431 let (_, ref query) = self.query.last().expect("pushed query immediately above");
432 Ok(query)
433 }
434
435 pub fn full_query_readonly(
437 &self,
438 eval_point: &[P::Scalar],
439 ) -> Option<&MultilinearQuery<P, Backend::Vec<P>>> {
440 self.query
441 .iter()
442 .position(|(memo_eval_point, _)| memo_eval_point.as_slice() == eval_point)
443 .map(|index| {
444 let (_, ref query) = &self.query[index];
445 query
446 })
447 }
448
449 #[instrument(skip_all, name = "Evalcheck::memoize_query_par", level = "debug")]
450 pub fn memoize_query_par<'b>(
451 &mut self,
452 eval_points: impl IntoIterator<Item = &'b [P::Scalar]>,
453 backend: &Backend,
454 ) -> Result<(), binius_hal::Error> {
455 let deduplicated_eval_points = eval_points.into_iter().collect::<HashSet<_>>();
456
457 let new_queries = deduplicated_eval_points
458 .into_par_iter()
459 .filter(|ep| self.full_query_readonly(ep).is_none())
460 .map(|ep| {
461 backend
462 .multilinear_query::<P>(ep)
463 .map(|res| (ep.to_vec(), res))
464 })
465 .collect::<Result<Vec<_>, binius_hal::Error>>()?;
466
467 self.query.extend(new_queries);
468
469 Ok(())
470 }
471
472 pub fn memoize_partial_evals(
473 &mut self,
474 metas: &[ProjectedBivariateMeta],
475 projected_bivariate_claims: &[EvalcheckMultilinearClaim<P::Scalar>],
476 oracles: &mut MultilinearOracleSet<P::Scalar>,
477 witness_index: &MultilinearExtensionIndex<'a, P>,
478 ) where
479 P::Scalar: TowerField,
480 {
481 projected_bivariate_claims
482 .iter()
483 .zip(metas)
484 .for_each(|(claim, meta)| {
485 let inner_id = meta.inner_id;
486 if matches!(oracles.oracle(inner_id).variant, MultilinearPolyVariant::Committed)
487 && meta.projected_id.is_some()
488 {
489 let eval_point = claim.eval_point[meta.projected_n_vars..].to_vec().into();
490
491 let projected_id = meta.projected_id.expect("checked above");
492
493 let projected = witness_index
494 .get_multilin_poly(projected_id)
495 .expect("witness_index contains projected if projected_id exist");
496
497 self.partial_evals.insert(inner_id, eval_point, projected);
498 }
499 });
500 }
501
502 pub fn partial_eval(
503 &self,
504 id: OracleId,
505 eval_point: &[P::Scalar],
506 ) -> Option<&MultilinearWitness<'a, P>> {
507 self.partial_evals.get(id, eval_point)
508 }
509}
510
/// Evalcheck claims returned by `prove_bivariate_sumchecks_with_switchover`.
type SumcheckProofEvalcheckClaims<F> = Vec<EvalcheckMultilinearClaim<F>>;
512
513pub fn prove_bivariate_sumchecks_with_switchover<F, P, DomainField, Transcript, Backend>(
514 witness: &MultilinearExtensionIndex<P>,
515 constraint_sets: Vec<ConstraintSet<F>>,
516 transcript: &mut ProverTranscript<Transcript>,
517 switchover_fn: impl Fn(usize) -> usize + 'static,
518 domain_factory: impl EvaluationDomainFactory<DomainField>,
519 backend: &Backend,
520) -> Result<SumcheckProofEvalcheckClaims<F>, SumcheckError>
521where
522 P: PackedField<Scalar = F>
523 + PackedExtension<F, PackedSubfield = P>
524 + PackedExtension<DomainField>,
525 F: TowerField + ExtensionField<DomainField>,
526 DomainField: Field,
527 Transcript: Challenger,
528 Backend: ComputationBackend,
529{
530 let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas(
531 EvaluationOrder::HighToLow,
532 constraint_sets,
533 witness,
534 domain_factory,
535 &switchover_fn,
536 backend,
537 )?;
538
539 let batch_prover = front_loaded::BatchProver::new(provers, transcript)?;
540
541 let mut sumcheck_output = batch_prover.run(transcript)?;
542
543 sumcheck_output.challenges.reverse();
545
546 let evalcheck_claims =
547 sumcheck::make_eval_claims(EvaluationOrder::HighToLow, metas, sumcheck_output)?;
548
549 Ok(evalcheck_claims)
550}
551
/// An evaluation claim tagged by the kind of sumcheck reduction applied to it.
#[derive(Clone)]
pub enum SumcheckClaims<F: Field> {
	/// Claim reduced through a projected bivariate product (see `ProjectedBivariateMeta`).
	/// NOTE(review): presumably used for shifted and packed oracles — confirm at call sites.
	Projected(EvalcheckMultilinearClaim<F>),
	/// Claim on a composite oracle, reduced through an MLE-check against an equality indicator
	/// (see `CompositeMLECheckMeta`).
	Composite(EvalcheckMultilinearClaim<F>),
}