binius_core/protocols/sumcheck/
common.rs1use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6 ExtensionField, Field, PackedField,
7 util::{inner_product_unchecked, powers},
8};
9use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
15#[derive(Debug, Clone, Getters, CopyGetters)]
23pub struct CompositeSumClaim<F: Field, Composition> {
24 pub composition: Composition,
25 pub sum: F,
26}
27
28#[derive(Debug, Clone, CopyGetters)]
37pub struct SumcheckClaim<F: Field, C> {
38 #[getset(get_copy = "pub")]
39 n_vars: usize,
40 #[getset(get_copy = "pub")]
41 n_multilinears: usize,
42 composite_sums: Vec<CompositeSumClaim<F, C>>,
43}
44
45impl<F: Field, Composition> SumcheckClaim<F, Composition>
46where
47 Composition: CompositionPoly<F>,
48{
49 pub fn new(
56 n_vars: usize,
57 n_multilinears: usize,
58 composite_sums: Vec<CompositeSumClaim<F, Composition>>,
59 ) -> Result<Self, Error> {
60 for CompositeSumClaim { composition, .. } in &composite_sums {
61 if composition.n_vars() != n_multilinears {
62 bail!(Error::InvalidComposition {
63 actual: composition.n_vars(),
64 expected: n_multilinears,
65 });
66 }
67 }
68 Ok(Self {
69 n_vars,
70 n_multilinears,
71 composite_sums,
72 })
73 }
74
75 pub fn max_individual_degree(&self) -> usize {
77 self.composite_sums
78 .iter()
79 .map(|composite_sum| composite_sum.composition.degree())
80 .max()
81 .unwrap_or(0)
82 }
83
84 pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
85 &self.composite_sums
86 }
87}
88
89#[derive(Debug, Default, Clone, PartialEq, Eq)]
93pub struct RoundCoeffs<F: Field>(pub Vec<F>);
94
95impl<F: Field> RoundCoeffs<F> {
96 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
98 RoundCoeffs(self.0.into_iter().map(Into::into).collect())
99 }
100
101 pub fn truncate(mut self) -> RoundProof<F> {
103 self.0.pop();
104 RoundProof(self)
105 }
106}
107
108impl<F: Field> Add<&Self> for RoundCoeffs<F> {
109 type Output = Self;
110
111 fn add(mut self, rhs: &Self) -> Self::Output {
112 self += rhs;
113 self
114 }
115}
116
117impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
118 fn add_assign(&mut self, rhs: &Self) {
119 if self.0.len() < rhs.0.len() {
120 self.0.resize(rhs.0.len(), F::ZERO);
121 }
122
123 for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
124 *lhs_i += rhs_i;
125 }
126 }
127}
128
129impl<F: Field> Mul<F> for RoundCoeffs<F> {
130 type Output = Self;
131
132 fn mul(mut self, rhs: F) -> Self::Output {
133 self *= rhs;
134 self
135 }
136}
137
138impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
139 fn mul_assign(&mut self, rhs: F) {
140 for coeff in &mut self.0 {
141 *coeff *= rhs;
142 }
143 }
144}
145
146#[derive(Debug, Default, Clone, PartialEq, Eq)]
153pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
154
155impl<F: Field> RoundProof<F> {
156 pub fn recover(self, sum: F) -> RoundCoeffs<F> {
177 let Self(RoundCoeffs(mut coeffs)) = self;
178 let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
179 let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
180 coeffs.push(last_coeff);
181 RoundCoeffs(coeffs)
182 }
183
184 pub fn coeffs(&self) -> &[F] {
186 &self.0.0
187 }
188
189 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
191 RoundProof(self.0.isomorphic())
192 }
193}
194
195#[derive(Debug, Default, Clone, PartialEq, Eq)]
197pub struct Proof<F: Field> {
198 pub rounds: Vec<RoundProof<F>>,
200 pub multilinear_evals: Vec<Vec<F>>,
207}
208
209#[derive(Debug, PartialEq, Eq)]
211pub struct BatchSumcheckOutput<F: Field> {
212 pub challenges: Vec<F>,
214 pub multilinear_evals: Vec<Vec<F>>,
217}
218
219impl<F: Field> BatchSumcheckOutput<F> {
220 pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
221 BatchSumcheckOutput {
222 challenges: self.challenges.into_iter().map(Into::into).collect(),
223 multilinear_evals: self
224 .multilinear_evals
225 .into_iter()
226 .map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
227 .collect(),
228 }
229 }
230}
231
/// Builds a switchover function parameterised by the offset `k`.
///
/// The returned closure maps an extension degree `d` to `max(floor(log2 d) + k, 0) - 1`, with the
/// final subtraction saturating at `0`. NOTE(review): the result is presumably the round at
/// which a prover switches strategy — confirm against callers.
///
/// # Panics
///
/// The closure panics if called with `0`, because `usize::ilog2(0)` panics.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
    move |extension_degree: usize| {
        let shifted_log = extension_degree.ilog2() as isize + k;
        let clamped = if shifted_log < 0 { 0 } else { shifted_log as usize };
        clamped.saturating_sub(1)
    }
}
239
/// Switchover function that ignores the extension degree and always returns `0`.
///
/// Being a `const fn`, it can also be evaluated at compile time. NOTE(review): `0` presumably
/// means "switch over immediately" — confirm against callers.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
    0
}
244
245pub fn equal_n_vars_check<'a, P, M>(
247 multilinears: impl IntoIterator<Item = &'a M>,
248) -> Result<usize, Error>
249where
250 P: PackedField,
251 M: MultilinearPoly<P> + 'a,
252{
253 let mut multilinears = multilinears.into_iter();
254 let n_vars = multilinears
255 .next()
256 .map(|multilinear| multilinear.n_vars())
257 .unwrap_or_default();
258 for multilinear in multilinears {
259 if multilinear.n_vars() != n_vars {
260 bail!(Error::NumberOfVariablesMismatch);
261 }
262 }
263 Ok(n_vars)
264}
265
266pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
271where
272 F: Field + ExtensionField<FBase>,
273 FBase: Field,
274 P: PackedField<Scalar = F>,
275 M: MultilinearPoly<P>,
276{
277 for multilinear in multilinears {
278 if multilinear.log_extension_degree() < F::LOG_DEGREE {
279 bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
280 }
281 }
282
283 Ok(())
284}
285
286pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
288 batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
290}
291
292pub fn interpolation_domains_for_composition_degrees<FDomain>(
294 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
295 degrees: impl IntoIterator<Item = usize>,
296) -> Result<Vec<InterpolationDomain<FDomain>>, Error>
297where
298 FDomain: Field,
299{
300 degrees
301 .into_iter()
302 .map(|degree| Ok(evaluation_domain_factory.create(degree + 1)?.into()))
303 .collect()
304}
305
306pub fn get_nontrivial_evaluation_points<F: Field>(
311 domains: &[InterpolationDomain<F>],
312) -> Result<Vec<F>, Error> {
313 let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
314 return Ok(Vec::new());
315 };
316
317 #[allow(clippy::get_first)]
318 if !domains.iter().all(|domain| {
319 (domain.size() <= 2 || domain.with_infinity())
320 && domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
321 && domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
322 }) {
323 bail!(Error::IncorrectSumcheckEvaluationDomain);
324 }
325
326 let finite_points = largest_domain.finite_points();
327
328 if domains
329 .iter()
330 .any(|domain| !finite_points.starts_with(domain.finite_points()))
331 {
332 bail!(Error::NonProperPrefixEvaluationDomain);
333 }
334
335 let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
336 Ok(nontrivial_evaluation_points)
337}
338
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    type F = BinaryField64b;

    #[test]
    fn test_round_coeffs_truncate_non_empty() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        let truncated = coeffs.truncate();
        assert_eq!(truncated.0.0, vec![F::from(1), F::from(2)]);
    }

    #[test]
    fn test_round_coeffs_truncate_empty() {
        let coeffs = RoundCoeffs::<F>(vec![]);
        let truncated = coeffs.truncate();
        assert!(truncated.0.0.is_empty());
    }

    /// `recover` must restore exactly the coefficient that `truncate` dropped, given the
    /// round sum r(0) + r(1).
    #[test]
    fn test_round_proof_recover_inverts_truncate() {
        let coeffs = RoundCoeffs(vec![F::from(5), F::from(7), F::from(11)]);
        // r(0) = a_0 and r(1) = a_0 + a_1 + a_2.
        let sum = coeffs.0[0] + coeffs.0.iter().sum::<F>();
        let recovered = coeffs.clone().truncate().recover(sum);
        assert_eq!(recovered, coeffs);
    }

    /// A shorter left operand must be zero-padded rather than truncating the right operand.
    #[test]
    fn test_round_coeffs_add_assign_pads_shorter_lhs() {
        let mut lhs = RoundCoeffs(vec![F::from(1)]);
        lhs += &RoundCoeffs(vec![F::from(2), F::from(3)]);
        assert_eq!(lhs.0, vec![F::from(1) + F::from(2), F::from(3)]);
    }
}