use super::Error;
use crate::{
polynomial::{
Error as PolynomialError, MultilinearExtensionSpecialized, MultilinearPoly,
MultilinearQuery, MultilinearQueryRef,
},
protocols::utils::deinterleave,
};
use binius_field::PackedField;
use binius_hal::ComputationBackend;
use binius_utils::{array_2d::Array2D, bail};
use rayon::prelude::*;
use std::{cmp::max, collections::HashMap, fmt::Debug, hash::Hash};
/// A multilinear tracked by the sumcheck prover, in one of two representations.
///
/// Entries start out as `Transparent` (the multilinear exactly as passed to `extend`) and are
/// replaced in place by `Folded` once `fold` partially evaluates them.
#[derive(Debug, Clone)]
enum SumcheckMultilinear<P: PackedField, M: MultilinearPoly<P> + Send + Sync> {
    /// The original multilinear, not yet partially evaluated.
    Transparent {
        /// The multilinear as supplied by the caller.
        multilinear: M,
        /// Value of the prover's round counter when this multilinear was registered.
        introduction_round: usize,
        /// Switchover parameter; `fold` converts this entry to `Folded` once
        /// `introduction_round + switchover_round < next_round`.
        switchover_round: usize,
    },
    /// The multilinear after partial evaluation at the challenges received so far.
    Folded {
        /// Result of the partial evaluation(s) performed by `fold`.
        large_field_folded_multilinear: MultilinearExtensionSpecialized<P, P>,
    },
}
/// Scratch buffers owned by one parallel fold task in `calculate_round_coeffs_helper`.
///
/// The buffers are reused for every subcube that the task processes. `round_evals` is the only
/// one whose contents accumulate across subcubes.
struct ParFoldStates<P>
where
    P: PackedField,
{
    /// Even half of the deinterleaved subcube evaluations: row = packed vertex, column = multilinear.
    evals_0: Array2D<P>,
    /// Odd half of the deinterleaved subcube evaluations: row = packed vertex, column = multilinear.
    evals_1: Array2D<P>,
    /// Per-vertex scratch row handed to the evaluator as `evals_z`.
    evals_z: Array2D<P>,
    /// Raw evaluations over `subcube_vars + 1` variables, before deinterleaving.
    interleaved_evals: Vec<P>,
    /// Per-vertex accumulators for round evaluations; summed over rows at the end.
    round_evals: Array2D<P>,
}

impl<P> ParFoldStates<P>
where
    P: PackedField,
{
    /// Allocates zero-initialised buffers sized for subcubes of `subcube_vars` variables.
    ///
    /// Row counts are the number of packed elements needed to hold `2^subcube_vars` scalars,
    /// and never fewer than one.
    fn new(n_multilinears: usize, n_round_evals: usize, subcube_vars: usize) -> Self {
        // Number of packed elements holding 2^vars scalars (at least one).
        let packed_len = |vars: usize| -> usize { 1 << vars.saturating_sub(P::LOG_WIDTH) };
        let rows = packed_len(subcube_vars);

        Self {
            interleaved_evals: vec![P::default(); packed_len(subcube_vars + 1)],
            evals_0: Array2D::zeroes(rows, n_multilinears),
            evals_1: Array2D::zeroes(rows, n_multilinears),
            evals_z: Array2D::zeroes(rows, n_multilinears),
            round_evals: Array2D::zeroes(rows, n_round_evals),
        }
    }
}
/// Per-vertex evaluation logic driven by `CommonProversState::calculate_round_coeffs`.
///
/// The prover computes the multilinear values on each packed vertex and passes them to
/// [`Self::process_vertex`]. It sums the accumulated per-vertex results and passes the totals to
/// [`Self::round_evals_to_coeffs`].
pub trait AbstractSumcheckEvaluator<P>: Sync
where
    P: PackedField,
{
    /// Caller-supplied state consumed by `process_vertex`, one value per packed vertex.
    type VertexState;

    /// Number of round evaluations accumulated per vertex; sizes the `round_evals` slice.
    fn n_round_evals(&self) -> usize;

    /// Processes one packed vertex.
    ///
    /// * `i` - index of the packed vertex within the current round's hypercube.
    /// * `vertex_state` - the state yielded for this vertex by the caller's iterator.
    /// * `evals_0` / `evals_1` - even / odd halves of the deinterleaved evaluations, one entry
    ///   per multilinear, in the order the multilinear ids were given.
    /// * `evals_z` - scratch space with one slot per multilinear.
    /// * `round_evals` - accumulator with `n_round_evals()` slots; the prover sums these over
    ///   all vertices. Implementations presumably add into it rather than overwrite it.
    fn process_vertex(
        &self,
        i: usize,
        vertex_state: Self::VertexState,
        evals_0: &[P],
        evals_1: &[P],
        evals_z: &mut [P],
        round_evals: &mut [P],
    );

    /// Converts the summed round evaluations into round polynomial coefficients.
    ///
    /// `current_round_sum` is the claimed sum for this round, forwarded unchanged from
    /// `calculate_round_coeffs`.
    fn round_evals_to_coeffs(
        &self,
        current_round_sum: P::Scalar,
        round_evals: Vec<P::Scalar>,
    ) -> Result<Vec<P::Scalar>, PolynomialError>;
}
/// Bookkeeping shared by abstract sumcheck provers.
///
/// The state tracks every registered multilinear, the per-round tensor queries used to evaluate
/// multilinears that are not yet folded, and the round counter.
pub struct CommonProversState<
    MultilinearId: Hash + Eq + Sync,
    PW: PackedField,
    M: MultilinearPoly<PW> + Send + Sync,
    Backend: ComputationBackend,
> {
    /// Total number of variables of the sumcheck claim.
    n_vars: usize,
    /// Maps a multilinear's extension degree to its (unclamped) switchover round.
    switchover_fn: Box<dyn Fn(usize) -> usize>,
    /// Round counter; incremented by every `pre_execute_rounds` call.
    next_round: usize,
    /// Registered multilinears keyed by caller-chosen id.
    multilinears: HashMap<MultilinearId, SumcheckMultilinear<PW, M>>,
    /// Size for the query to be allocated by the next `pre_execute_rounds`; `None` when no
    /// multilinear was added since the previous call.
    max_query_vars: Option<usize>,
    /// One slot per round, indexed by introduction round; `None` when no live transparent
    /// multilinear needs it.
    queries: Vec<Option<MultilinearQuery<PW, Backend>>>,
    /// Backend passed to query construction during folding.
    backend: Backend,
}
impl<MultilinearId, PW, M, Backend> CommonProversState<MultilinearId, PW, M, Backend>
where
    MultilinearId: Clone + Hash + Eq + Sync + Debug,
    PW: PackedField,
    M: MultilinearPoly<PW> + Sync + Send,
    Backend: ComputationBackend,
{
    /// Creates an empty prover state for a sumcheck over `n_vars` variables.
    ///
    /// `switchover_fn` maps a multilinear's extension degree to its switchover round; the
    /// result is clamped to at least 1 in [`Self::extend`]. `backend` is used when building the
    /// single-variable query during folding.
    pub fn new(
        n_vars: usize,
        switchover_fn: impl Fn(usize) -> usize + 'static,
        backend: Backend,
    ) -> Self {
        Self {
            n_vars,
            switchover_fn: Box::new(switchover_fn),
            next_round: 0,
            multilinears: HashMap::new(),
            max_query_vars: None,
            queries: Vec::new(),
            backend,
        }
    }

    /// Registers `multilinears` as introduced at the current round (`next_round`).
    ///
    /// Each multilinear must have exactly `n_vars - next_round` variables. Its switchover round
    /// is `max(1, switchover_fn(extension_degree))`. The largest accepted switchover round is
    /// remembered so that the next [`Self::pre_execute_rounds`] can allocate a query for this
    /// introduction round. Inserting an id that is already registered replaces the old entry.
    ///
    /// # Errors
    ///
    /// Fails (via `bail!`) with `PolynomialError::IncorrectNumberOfVariables` on the first
    /// multilinear with the wrong number of variables. Multilinears yielded before it stay
    /// registered. The rejected one does not affect the pending query size.
    pub fn extend(
        &mut self,
        multilinears: impl IntoIterator<Item = (MultilinearId, M)>,
    ) -> Result<(), Error> {
        let introduction_round = self.next_round;
        for (multilinear_id, multilinear) in multilinears {
            // Validate before touching `max_query_vars` so a rejected multilinear cannot
            // enlarge (or create) the query allocated in the next round.
            if introduction_round + multilinear.n_vars() != self.n_vars {
                bail!(PolynomialError::IncorrectNumberOfVariables {
                    expected: self.n_vars.saturating_sub(introduction_round),
                    actual: multilinear.n_vars(),
                });
            }

            let switchover_round = max(1, (self.switchover_fn)(multilinear.extension_degree()));
            self.max_query_vars = Some(max(self.max_query_vars.unwrap_or(1), switchover_round));

            self.multilinears.insert(
                multilinear_id,
                SumcheckMultilinear::Transparent {
                    multilinear,
                    introduction_round,
                    switchover_round,
                },
            );
        }
        Ok(())
    }

    /// Advances the prover to the next round.
    ///
    /// `prev_rd_challenge` must be `None` on the very first call and `Some(challenge)` on every
    /// later call. The method then does four things, in order:
    /// 1. extends every live query with the challenge;
    /// 2. pushes a new query slot for multilinears added via `extend` since the last call,
    ///    built with `MultilinearQuery::new(max_query_vars)`, or `None` if none were added;
    /// 3. increments `next_round`;
    /// 4. when a challenge was given, folds the multilinears (see `fold`).
    ///
    /// # Panics
    ///
    /// Panics if the presence of `prev_rd_challenge` does not match whether this is round 0.
    ///
    /// # Errors
    ///
    /// Propagates errors from query update/construction and from folding.
    pub fn pre_execute_rounds(
        &mut self,
        prev_rd_challenge: Option<PW::Scalar>,
    ) -> Result<(), PolynomialError> {
        assert_eq!(self.next_round == 0, prev_rd_challenge.is_none());

        if let Some(prev_rd_challenge) = prev_rd_challenge {
            for query in self.queries.iter_mut() {
                if let Some(prev_query) = query.take() {
                    let expanded_query = prev_query.update(&[prev_rd_challenge])?;
                    query.replace(expanded_query);
                }
            }
        }

        // `take()` resets the pending size so the next round starts with no pending query.
        let new_query = self
            .max_query_vars
            .take()
            .map(|max_query_vars| MultilinearQuery::new(max_query_vars))
            .transpose()?;
        self.queries.push(new_query);

        self.next_round += 1;
        // Invariant: `queries[r]` belongs to multilinears introduced at round `r`.
        debug_assert_eq!(self.next_round, self.queries.len());

        if let Some(prev_rd_challenge) = prev_rd_challenge {
            self.fold(prev_rd_challenge)?;
        }

        Ok(())
    }

    /// Partially evaluates multilinears at `prev_rd_challenge` and drops queries that are no
    /// longer needed.
    ///
    /// Each multilinear is handled according to its state:
    /// * `Folded` entries are partially evaluated at the single challenge.
    /// * `Transparent` entries with `switchover_round + introduction_round < next_round` are
    ///   replaced by a `Folded` entry. That entry comes from partially evaluating at the query
    ///   of their introduction round, which `pre_execute_rounds` has extended by every challenge
    ///   since introduction.
    /// * Other transparent entries are left untouched. Their introduction round's query is kept.
    ///
    /// Every query slot with no remaining transparent multilinear is set to `None`.
    fn fold(&mut self, prev_rd_challenge: PW::Scalar) -> Result<(), PolynomialError> {
        let &mut Self {
            ref mut multilinears,
            ref mut queries,
            next_round,
            ..
        } = self;

        let single_variable_partial_query =
            MultilinearQuery::with_full_query(&[prev_rd_challenge], self.backend.clone())?;

        // Fold in parallel; each task reports which introduction round (if any) still has a
        // transparent multilinear relying on its query.
        let any_transparent_left = multilinears
            .par_iter_mut()
            .map(|(_, sc_multilinear)| -> Result<Option<usize>, PolynomialError> {
                match sc_multilinear {
                    &mut SumcheckMultilinear::Transparent {
                        ref mut multilinear,
                        introduction_round,
                        switchover_round,
                    } => {
                        if switchover_round + introduction_round < next_round {
                            let query_ref = queries[introduction_round].as_ref().expect(
                                "query is guaranteed to be Some while there are transparent \
                                multilinears remaining",
                            );
                            let large_field_folded_multilinear =
                                multilinear.evaluate_partial_low(query_ref.to_ref())?;
                            *sc_multilinear = SumcheckMultilinear::Folded {
                                large_field_folded_multilinear,
                            };
                            Ok(None)
                        } else {
                            Ok(Some(introduction_round))
                        }
                    }
                    SumcheckMultilinear::Folded {
                        ref mut large_field_folded_multilinear,
                    } => {
                        *large_field_folded_multilinear = large_field_folded_multilinear
                            .evaluate_partial_low(single_variable_partial_query.to_ref())?;
                        Ok(None)
                    }
                }
            })
            .try_fold(
                || vec![false; queries.len()],
                |mut any_transparent_left,
                 opt_round: Result<Option<usize>, PolynomialError>|
                 -> Result<Vec<bool>, PolynomialError> {
                    if let Some(round) = opt_round? {
                        any_transparent_left[round] = true;
                    }
                    Ok(any_transparent_left)
                },
            )
            .try_reduce(
                || vec![false; queries.len()],
                |mut any_transparent_lhs, any_transparent_rhs| {
                    any_transparent_lhs
                        .iter_mut()
                        .zip(any_transparent_rhs)
                        .for_each(|(lhs, rhs)| *lhs |= rhs);
                    Ok(any_transparent_lhs)
                },
            )?;

        for (query, keep) in queries.iter_mut().zip(any_transparent_left) {
            if !keep {
                *query = None;
            }
        }

        Ok(())
    }

    /// Computes the current round polynomial's coefficients for the multilinears
    /// `multilinear_ids`.
    ///
    /// How each multilinear's values on the remaining hypercube are obtained depends on the
    /// state of the requested multilinears:
    /// * all transparent, in the round right after their introduction: sampled directly;
    /// * all transparent, in a later round: inner products with their introduction round's query;
    /// * all folded: sampled directly from the folded multilinears;
    /// * a mix: each multilinear uses whichever of the above applies to it.
    ///
    /// `vertex_state_iterator` must yield one state per packed vertex. The values are passed to
    /// `evaluator.process_vertex`, summed, and converted with `evaluator.round_evals_to_coeffs`.
    ///
    /// # Errors
    ///
    /// Returns `Error::WitnessNotFound` for an unregistered id, and propagates errors from
    /// `round_evals_to_coeffs`.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// * `extend` was called after `pre_execute_rounds` in the current round;
    /// * the transparent multilinears do not share one introduction round with a live query
    ///   (this also covers an empty `multilinear_ids`).
    pub fn calculate_round_coeffs<VS>(
        &self,
        multilinear_ids: &[MultilinearId],
        evaluator: impl AbstractSumcheckEvaluator<PW, VertexState = VS>,
        current_round_sum: PW::Scalar,
        vertex_state_iterator: impl IndexedParallelIterator<Item = VS>,
    ) -> Result<Vec<PW::Scalar>, Error> {
        assert!(
            self.max_query_vars.is_none(),
            "extend() called after pre_execute_rounds() but before calculate_round_coeffs()"
        );

        let &Self {
            ref multilinears,
            next_round,
            ..
        } = self;

        // Classify the requested multilinears.
        let mut any_transparent = false;
        let mut any_folded = false;
        for multilinear_id in multilinear_ids {
            let multilinear = multilinears
                .get(multilinear_id)
                .ok_or(Error::WitnessNotFound)?;
            match multilinear {
                SumcheckMultilinear::Transparent { .. } => {
                    any_transparent = true;
                }
                SumcheckMultilinear::Folded { .. } => {
                    any_folded = true;
                }
            }
        }

        let opt_query = self.get_subset_query(multilinear_ids);

        match (any_transparent, any_folded, opt_query) {
            // The query has not absorbed any challenge yet, so the transparent values are used as is.
            (true, false, Some((introduction_round, _)))
                if introduction_round + 1 == next_round =>
            {
                self.calculate_round_coeffs_helper(
                    multilinear_ids,
                    Self::only_transparent,
                    Self::direct_sample,
                    evaluator,
                    vertex_state_iterator,
                    current_round_sum,
                )
            }
            (true, false, Some((introduction_round, query)))
                if introduction_round + 1 < next_round =>
            {
                self.calculate_round_coeffs_helper(
                    multilinear_ids,
                    Self::only_transparent,
                    |multilin, subcube_vars, subcube_index, evals| {
                        Self::subcube_inner_product(
                            multilin,
                            query.to_ref(),
                            subcube_vars,
                            subcube_index,
                            evals,
                        )
                    },
                    evaluator,
                    vertex_state_iterator,
                    current_round_sum,
                )
            }
            (false, true, _) => self.calculate_round_coeffs_helper(
                multilinear_ids,
                Self::only_folded,
                Self::direct_sample,
                evaluator,
                vertex_state_iterator,
                current_round_sum,
            ),
            (_, _, Some((_, query))) => self.calculate_round_coeffs_helper(
                multilinear_ids,
                |x| x,
                |sc_multilin, subcube_vars, subcube_index, evals| match sc_multilin {
                    SumcheckMultilinear::Transparent { multilinear, .. } => {
                        Self::subcube_inner_product(
                            multilinear,
                            query.to_ref(),
                            subcube_vars,
                            subcube_index,
                            evals,
                        )
                    }
                    SumcheckMultilinear::Folded {
                        large_field_folded_multilinear,
                    } => Self::direct_sample(
                        large_field_folded_multilinear,
                        subcube_vars,
                        subcube_index,
                        evals,
                    ),
                },
                evaluator,
                vertex_state_iterator,
                current_round_sum,
            ),
            _ => panic!("tensor not present during sumcheck, or some other invalid case"),
        }
    }

    /// Shared driver for [`Self::calculate_round_coeffs`].
    ///
    /// The driver works in these steps:
    /// 1. Resolve each id through `precomp`.
    /// 2. Split the remaining hypercube into subcubes of `subcube_vars` (at most 4) variables
    ///    and process them in parallel. There is one chunk of `vertex_state_iterator` per
    ///    subcube.
    /// 3. For each subcube and multilinear, let `eval01` fill the evaluations over
    ///    `subcube_vars + 1` variables.
    /// 4. Deinterleave those evaluations into even/odd halves and feed them vertex by vertex
    ///    to `evaluator.process_vertex`.
    /// 5. Sum the per-vertex accumulators over all vertices and convert them to coefficients.
    fn calculate_round_coeffs_helper<'b, T, VS>(
        &'b self,
        multilinear_ids: &[MultilinearId],
        precomp: impl Fn(&'b SumcheckMultilinear<PW, M>) -> T,
        eval01: impl Fn(T, usize, usize, &mut [PW]) + Sync,
        evaluator: impl AbstractSumcheckEvaluator<PW, VertexState = VS>,
        vertex_state_iterator: impl IndexedParallelIterator<Item = VS>,
        current_round_sum: PW::Scalar,
    ) -> Result<Vec<PW::Scalar>, Error>
    where
        T: Copy + Sync + 'b,
        M: 'b,
    {
        let precomps = multilinear_ids
            .iter()
            .map(|multilinear_id| -> Result<T, Error> {
                let sc_multilinear = self
                    .multilinears
                    .get(multilinear_id)
                    .ok_or(Error::WitnessNotFound)?;
                Ok(precomp(sc_multilinear))
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Variables still unbound in the current round.
        let n_vars = self.n_vars - self.next_round + 1;
        let n_multilinears = precomps.len();
        let n_round_evals = evaluator.n_round_evals();

        const MAX_SUBCUBE_VARS: usize = 5;
        let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars) - 1;
        // Packed vertices per subcube. This is the chunk length, and therefore the stride
        // between consecutive subcubes' vertex indices.
        let packed_subcube_len = 1 << subcube_vars.saturating_sub(PW::LOG_WIDTH);

        let evals = vertex_state_iterator
            .chunks(packed_subcube_len)
            .enumerate()
            .fold(
                || ParFoldStates::new(n_multilinears, n_round_evals, subcube_vars),
                |mut par_fold_states, (subcube_index, vertex_states)| {
                    for (j, precomp) in precomps.iter().enumerate() {
                        eval01(
                            *precomp,
                            subcube_vars + 1,
                            subcube_index,
                            &mut par_fold_states.interleaved_evals,
                        );

                        deinterleave(subcube_vars, &par_fold_states.interleaved_evals).for_each(
                            |(i, even, odd)| {
                                par_fold_states.evals_0[(i, j)] = even;
                                par_fold_states.evals_1[(i, j)] = odd;
                            },
                        );
                    }

                    // Use the fixed chunk stride rather than this chunk's length. A shorter
                    // final chunk would otherwise yield wrong vertex indices.
                    let subcube_start = subcube_index * packed_subcube_len;
                    for (k, vertex_state) in vertex_states.into_iter().enumerate() {
                        evaluator.process_vertex(
                            subcube_start + k,
                            vertex_state,
                            par_fold_states.evals_0.get_row(k),
                            par_fold_states.evals_1.get_row(k),
                            par_fold_states.evals_z.get_row_mut(k),
                            par_fold_states.round_evals.get_row_mut(k),
                        );
                    }

                    par_fold_states
                },
            )
            .map(|states| states.round_evals.sum_rows())
            .reduce(
                || vec![PW::zero(); n_round_evals],
                |mut overall_round_evals, partial_round_evals| {
                    overall_round_evals
                        .iter_mut()
                        .zip(partial_round_evals.iter())
                        .for_each(|(f, s)| *f += *s);
                    overall_round_evals
                },
            )
            .iter()
            // Collapse each packed accumulator into a single scalar.
            .map(|x| x.iter().sum())
            .collect();

        Ok(evaluator.round_evals_to_coeffs(current_round_sum, evals)?)
    }

    /// Returns the introduction round and live query shared by the transparent multilinears
    /// among `oracle_ids`.
    ///
    /// Ids that are unregistered or already folded are ignored. Returns `None` in three cases:
    /// * no transparent multilinear is referenced;
    /// * the transparent ones were introduced in different rounds;
    /// * that round's query has been dropped.
    pub(crate) fn get_subset_query(
        &self,
        oracle_ids: &[MultilinearId],
    ) -> Option<(usize, &MultilinearQuery<PW, Backend>)> {
        let introduction_rounds = oracle_ids
            .iter()
            .flat_map(|oracle_id| match self.multilinears.get(oracle_id)? {
                SumcheckMultilinear::Transparent {
                    introduction_round, ..
                } => Some(*introduction_round),
                _ => None,
            })
            .collect::<Vec<_>>();

        let first_introduction_round = *introduction_rounds.first()?;
        if introduction_rounds
            .iter()
            .any(|&round| round != first_introduction_round)
        {
            return None;
        }

        self.queries
            .get(first_introduction_round)?
            .as_ref()
            .map(|query| (first_introduction_round, query))
    }

    /// Writes the evaluations of `multilin` on subcube `subcube_index` of `subcube_vars`
    /// variables into `evals`.
    ///
    /// # Panics
    ///
    /// Panics if `subcube_evals` reports an error.
    #[inline]
    fn direct_sample<MD: MultilinearPoly<PW>>(
        multilin: MD,
        subcube_vars: usize,
        subcube_index: usize,
        evals: &mut [PW],
    ) {
        multilin
            .subcube_evals(subcube_vars, subcube_index, evals)
            .expect("indices within range");
    }

    /// Writes the inner products of `multilin` with `query` over subcube `subcube_index` of
    /// `subcube_vars` variables into `inner_products`.
    ///
    /// # Panics
    ///
    /// Panics if `subcube_inner_products` reports an error.
    #[inline]
    fn subcube_inner_product(
        multilin: &M,
        query: MultilinearQueryRef<PW>,
        subcube_vars: usize,
        subcube_index: usize,
        inner_products: &mut [PW],
    ) {
        multilin
            .subcube_inner_products(query, subcube_vars, subcube_index, inner_products)
            .expect("indices within range");
    }

    /// Unwraps a `Transparent` entry.
    ///
    /// # Panics
    ///
    /// Panics on a `Folded` entry. Callers only use it after checking that all entries are
    /// transparent.
    fn only_transparent(sc_multilin: &SumcheckMultilinear<PW, M>) -> &M {
        match sc_multilin {
            SumcheckMultilinear::Transparent { multilinear, .. } => multilinear,
            _ => panic!("all transparent by invariant"),
        }
    }

    /// Unwraps a `Folded` entry.
    ///
    /// # Panics
    ///
    /// Panics on a `Transparent` entry. Callers only use it after checking that all entries
    /// are folded.
    fn only_folded(
        sc_multilin: &SumcheckMultilinear<PW, M>,
    ) -> &MultilinearExtensionSpecialized<PW, PW> {
        match sc_multilin {
            SumcheckMultilinear::Folded {
                large_field_folded_multilinear,
            } => large_field_folded_multilinear,
            _ => panic!("all folded by invariant"),
        }
    }
}