binius_core/protocols/sumcheck/
common.rs1use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6 util::{inner_product_unchecked, powers},
7 ExtensionField, Field, PackedField,
8};
9use binius_math::{CompositionPoly, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12use tracing::instrument;
13
14use super::error::Error;
15
16#[derive(Debug, Clone, Getters, CopyGetters)]
24pub struct CompositeSumClaim<F: Field, Composition> {
25 pub composition: Composition,
26 pub sum: F,
27}
28
29#[derive(Debug, Clone, CopyGetters)]
38pub struct SumcheckClaim<F: Field, C> {
39 #[getset(get_copy = "pub")]
40 n_vars: usize,
41 #[getset(get_copy = "pub")]
42 n_multilinears: usize,
43 composite_sums: Vec<CompositeSumClaim<F, C>>,
44}
45
46impl<F: Field, Composition> SumcheckClaim<F, Composition>
47where
48 Composition: CompositionPoly<F>,
49{
50 pub fn new(
57 n_vars: usize,
58 n_multilinears: usize,
59 composite_sums: Vec<CompositeSumClaim<F, Composition>>,
60 ) -> Result<Self, Error> {
61 for CompositeSumClaim {
62 ref composition, ..
63 } in &composite_sums
64 {
65 if composition.n_vars() != n_multilinears {
66 bail!(Error::InvalidComposition {
67 actual: composition.n_vars(),
68 expected: n_multilinears,
69 });
70 }
71 }
72 Ok(Self {
73 n_vars,
74 n_multilinears,
75 composite_sums,
76 })
77 }
78
79 pub fn max_individual_degree(&self) -> usize {
81 self.composite_sums
82 .iter()
83 .map(|composite_sum| composite_sum.composition.degree())
84 .max()
85 .unwrap_or(0)
86 }
87
88 pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
89 &self.composite_sums
90 }
91}
92
93#[derive(Debug, Default, Clone, PartialEq, Eq)]
97pub struct RoundCoeffs<F: Field>(pub Vec<F>);
98
99impl<F: Field> RoundCoeffs<F> {
100 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
102 RoundCoeffs(self.0.into_iter().map(Into::into).collect())
103 }
104
105 pub fn truncate(mut self) -> RoundProof<F> {
107 self.0.pop();
108 RoundProof(self)
109 }
110}
111
112impl<F: Field> Add<&Self> for RoundCoeffs<F> {
113 type Output = Self;
114
115 fn add(mut self, rhs: &Self) -> Self::Output {
116 self += rhs;
117 self
118 }
119}
120
121impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
122 fn add_assign(&mut self, rhs: &Self) {
123 if self.0.len() < rhs.0.len() {
124 self.0.resize(rhs.0.len(), F::ZERO);
125 }
126
127 for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
128 *lhs_i += rhs_i;
129 }
130 }
131}
132
133impl<F: Field> Mul<F> for RoundCoeffs<F> {
134 type Output = Self;
135
136 fn mul(mut self, rhs: F) -> Self::Output {
137 self *= rhs;
138 self
139 }
140}
141
142impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
143 fn mul_assign(&mut self, rhs: F) {
144 for coeff in &mut self.0 {
145 *coeff *= rhs;
146 }
147 }
148}
149
150#[derive(Debug, Default, Clone, PartialEq, Eq)]
157pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
158
159impl<F: Field> RoundProof<F> {
160 pub fn recover(self, sum: F) -> RoundCoeffs<F> {
181 let Self(RoundCoeffs(mut coeffs)) = self;
182 let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
183 let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
184 coeffs.push(last_coeff);
185 RoundCoeffs(coeffs)
186 }
187
188 pub fn coeffs(&self) -> &[F] {
190 &self.0 .0
191 }
192
193 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
195 RoundProof(self.0.isomorphic())
196 }
197}
198
199#[derive(Debug, Default, Clone, PartialEq, Eq)]
201pub struct Proof<F: Field> {
202 pub rounds: Vec<RoundProof<F>>,
204 pub multilinear_evals: Vec<Vec<F>>,
211}
212
213#[derive(Debug, PartialEq, Eq)]
214pub struct BatchSumcheckOutput<F: Field> {
215 pub challenges: Vec<F>,
216 pub multilinear_evals: Vec<Vec<F>>,
217}
218
219impl<F: Field> BatchSumcheckOutput<F> {
220 pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
221 BatchSumcheckOutput {
222 challenges: self.challenges.into_iter().map(Into::into).collect(),
223 multilinear_evals: self
224 .multilinear_evals
225 .into_iter()
226 .map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
227 .collect(),
228 }
229 }
230}
231
/// Builds a switchover heuristic that maps an extension degree `d` to
/// `max(floor(log2(d)) + k, 0) - 1`, saturating at zero.
///
/// # Panics
///
/// The returned closure panics if called with `extension_degree == 0`, because `usize::ilog2`
/// is undefined for zero.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
    move |extension_degree: usize| {
        let shifted_log = extension_degree.ilog2() as isize + k;
        // Negative values clamp to zero, exactly like `max(0)` followed by a cast.
        let clamped = usize::try_from(shifted_log).unwrap_or(0);
        clamped.saturating_sub(1)
    }
}
239
/// Switchover heuristic that always returns round `0`, whatever the extension degree.
///
/// It is a `const fn`, so it can also be evaluated in constant contexts.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
    let switchover_round: usize = 0;
    switchover_round
}
244
245#[instrument(skip_all, level = "debug")]
247pub fn determine_switchovers<P, M>(
248 multilinears: &[M],
249 switchover_fn: impl Fn(usize) -> usize,
250) -> Vec<usize>
251where
252 P: PackedField,
253 M: MultilinearPoly<P>,
254{
255 multilinears
257 .iter()
258 .map(|multilinear| switchover_fn(1 << multilinear.log_extension_degree()))
259 .collect()
260}
261
262pub fn equal_n_vars_check<P, M>(multilinears: &[M]) -> Result<usize, Error>
264where
265 P: PackedField,
266 M: MultilinearPoly<P>,
267{
268 let n_vars = multilinears
269 .first()
270 .map(|multilinear| multilinear.n_vars())
271 .unwrap_or_default();
272 for multilinear in multilinears {
273 if multilinear.n_vars() != n_vars {
274 bail!(Error::NumberOfVariablesMismatch);
275 }
276 }
277 Ok(n_vars)
278}
279
280pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
285where
286 F: Field + ExtensionField<FBase>,
287 FBase: Field,
288 P: PackedField<Scalar = F>,
289 M: MultilinearPoly<P>,
290{
291 for multilinear in multilinears {
292 if multilinear.log_extension_degree() < F::LOG_DEGREE {
293 bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
294 }
295 }
296
297 Ok(())
298}
299
300pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
302 batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
304}
305
306pub fn get_nontrivial_evaluation_points<F: Field>(
311 domains: &[InterpolationDomain<F>],
312) -> Result<Vec<F>, Error> {
313 let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
314 return Ok(Vec::new());
315 };
316
317 #[allow(clippy::get_first)]
318 if !domains.iter().all(|domain| {
319 (domain.size() <= 2 || domain.with_infinity())
320 && domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
321 && domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
322 }) {
323 bail!(Error::IncorrectSumcheckEvaluationDomain);
324 }
325
326 let finite_points = largest_domain.finite_points();
327
328 if domains
329 .iter()
330 .any(|domain| !finite_points.starts_with(domain.finite_points()))
331 {
332 bail!(Error::NonProperPrefixEvaluationDomain);
333 }
334
335 let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
336 Ok(nontrivial_evaluation_points)
337}
338
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    type F = BinaryField64b;

    #[test]
    fn test_round_coeffs_truncate_non_empty() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        let truncated = coeffs.truncate();
        assert_eq!(truncated.0 .0, vec![F::from(1), F::from(2)]);
    }

    #[test]
    fn test_round_coeffs_truncate_empty() {
        let coeffs = RoundCoeffs::<F>(vec![]);
        let truncated = coeffs.truncate();
        assert!(truncated.0 .0.is_empty());
    }

    /// `recover` must undo `truncate` when given the true claimed sum `r(0) + r(1)`.
    #[test]
    fn test_round_proof_recover_inverts_truncate() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        // r(0) + r(1) = a_0 + (a_0 + a_1 + a_2)
        let sum = coeffs.0[0] + coeffs.0.iter().sum::<F>();
        let recovered = coeffs.clone().truncate().recover(sum);
        assert_eq!(recovered, coeffs);
    }

    /// `coeffs` exposes exactly the transmitted (truncated) coefficients.
    #[test]
    fn test_round_proof_coeffs_matches_truncation() {
        let proof = RoundCoeffs(vec![F::from(4), F::from(5), F::from(6)]).truncate();
        assert_eq!(proof.coeffs(), &[F::from(4), F::from(5)]);
    }

    /// Adding a longer right-hand side zero-pads the left-hand side first.
    #[test]
    fn test_round_coeffs_add_pads_shorter_lhs() {
        let lhs = RoundCoeffs(vec![F::from(1)]);
        let rhs = RoundCoeffs(vec![F::from(2), F::from(3)]);
        let result = lhs + &rhs;
        assert_eq!(result.0, vec![F::from(1) + F::from(2), F::from(3)]);
    }

    /// Adding a shorter right-hand side leaves the extra left-hand coefficients untouched.
    #[test]
    fn test_round_coeffs_add_keeps_longer_lhs_tail() {
        let lhs = RoundCoeffs(vec![F::from(1), F::from(7)]);
        let rhs = RoundCoeffs(vec![F::from(2)]);
        let result = lhs + &rhs;
        assert_eq!(result.0, vec![F::from(1) + F::from(2), F::from(7)]);
    }

    /// Scalar multiplication scales every coefficient.
    #[test]
    fn test_round_coeffs_mul_scales_every_coeff() {
        let scalar = F::from(9);
        let coeffs = RoundCoeffs(vec![F::from(2), F::from(3)]);
        let scaled = coeffs * scalar;
        assert_eq!(scaled.0, vec![F::from(2) * scalar, F::from(3) * scalar]);
    }
}