binius_core/protocols/sumcheck/
common.rs1use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6 util::{inner_product_unchecked, powers},
7 ExtensionField, Field, PackedField,
8};
9use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
/// A claimed `sum` paired with the `composition` it is claimed for.
///
/// Both fields are public. `SumcheckClaim::new` checks, for every one of these it receives,
/// that `composition.n_vars()` equals the claim's number of multilinears. The value being
/// summed is presumably the composition evaluated over the boolean hypercube (standard
/// sumcheck semantics); nothing in this type enforces that.
///
/// NOTE(review): no `#[getset(...)]` attributes are visible on this struct or its fields, so
/// the `Getters`/`CopyGetters` derives may generate no accessors. Confirm before removing them.
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct CompositeSumClaim<F: Field, Composition> {
    /// The composition polynomial applied to the claim's multilinears.
    pub composition: Composition,
    /// The claimed sum for `composition`.
    pub sum: F,
}
27
/// A sumcheck claim: several composite sums that share one list of multilinears.
///
/// The fields are private. Build values with `SumcheckClaim::new`, which rejects any
/// composition whose arity differs from `n_multilinears`.
#[derive(Debug, Clone, CopyGetters)]
pub struct SumcheckClaim<F: Field, C> {
    /// Number of variables. It is stored as given and not validated by `new`; presumably
    /// every multilinear has this many variables (confirm with callers).
    #[getset(get_copy = "pub")]
    n_vars: usize,
    /// Number of multilinears. Every composition must take exactly this many inputs.
    #[getset(get_copy = "pub")]
    n_multilinears: usize,
    /// The individual composite sum claims, in construction order.
    composite_sums: Vec<CompositeSumClaim<F, C>>,
}
44
45impl<F: Field, Composition> SumcheckClaim<F, Composition>
46where
47 Composition: CompositionPoly<F>,
48{
49 pub fn new(
56 n_vars: usize,
57 n_multilinears: usize,
58 composite_sums: Vec<CompositeSumClaim<F, Composition>>,
59 ) -> Result<Self, Error> {
60 for CompositeSumClaim {
61 ref composition, ..
62 } in &composite_sums
63 {
64 if composition.n_vars() != n_multilinears {
65 bail!(Error::InvalidComposition {
66 actual: composition.n_vars(),
67 expected: n_multilinears,
68 });
69 }
70 }
71 Ok(Self {
72 n_vars,
73 n_multilinears,
74 composite_sums,
75 })
76 }
77
78 pub fn max_individual_degree(&self) -> usize {
80 self.composite_sums
81 .iter()
82 .map(|composite_sum| composite_sum.composition.degree())
83 .max()
84 .unwrap_or(0)
85 }
86
87 pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
88 &self.composite_sums
89 }
90}
91
/// Coefficients of a univariate round polynomial.
///
/// NOTE(review): presumably ordered from the constant term upward. `RoundProof::recover`
/// treats index 0 as the constant term and appends the leading coefficient at the end. Confirm
/// before relying on the ordering elsewhere.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RoundCoeffs<F: Field>(pub Vec<F>);
97
98impl<F: Field> RoundCoeffs<F> {
99 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
101 RoundCoeffs(self.0.into_iter().map(Into::into).collect())
102 }
103
104 pub fn truncate(mut self) -> RoundProof<F> {
106 self.0.pop();
107 RoundProof(self)
108 }
109}
110
111impl<F: Field> Add<&Self> for RoundCoeffs<F> {
112 type Output = Self;
113
114 fn add(mut self, rhs: &Self) -> Self::Output {
115 self += rhs;
116 self
117 }
118}
119
120impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
121 fn add_assign(&mut self, rhs: &Self) {
122 if self.0.len() < rhs.0.len() {
123 self.0.resize(rhs.0.len(), F::ZERO);
124 }
125
126 for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
127 *lhs_i += rhs_i;
128 }
129 }
130}
131
132impl<F: Field> Mul<F> for RoundCoeffs<F> {
133 type Output = Self;
134
135 fn mul(mut self, rhs: F) -> Self::Output {
136 self *= rhs;
137 self
138 }
139}
140
141impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
142 fn mul_assign(&mut self, rhs: F) {
143 for coeff in &mut self.0 {
144 *coeff *= rhs;
145 }
146 }
147}
148
/// A compressed round message: round polynomial coefficients with the last one omitted.
///
/// `RoundCoeffs::truncate` produces one by popping the last coefficient.
/// `RoundProof::recover` rebuilds that coefficient from the claimed round sum.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
157
158impl<F: Field> RoundProof<F> {
159 pub fn recover(self, sum: F) -> RoundCoeffs<F> {
180 let Self(RoundCoeffs(mut coeffs)) = self;
181 let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
182 let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
183 coeffs.push(last_coeff);
184 RoundCoeffs(coeffs)
185 }
186
187 pub fn coeffs(&self) -> &[F] {
189 &self.0 .0
190 }
191
192 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
194 RoundProof(self.0.isomorphic())
195 }
196}
197
/// A sumcheck proof, held as plain data.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Proof<F: Field> {
    /// Compressed round messages. Presumably there is one per round (one per variable);
    /// confirm against the prover.
    pub rounds: Vec<RoundProof<F>>,
    /// Multilinear evaluations, grouped by an outer `Vec`.
    /// NOTE(review): presumably one inner `Vec` per batched claim or prover; confirm.
    pub multilinear_evals: Vec<Vec<F>>,
}
211
/// The result of a batched sumcheck: the challenges plus grouped multilinear evaluations.
#[derive(Debug, PartialEq, Eq)]
pub struct BatchSumcheckOutput<F: Field> {
    /// Challenges. Presumably the random round challenges, one per round; confirm.
    pub challenges: Vec<F>,
    /// Multilinear evaluations, grouped by an outer `Vec`.
    /// NOTE(review): presumably one inner `Vec` per batched prover; confirm.
    pub multilinear_evals: Vec<Vec<F>>,
}
221
222impl<F: Field> BatchSumcheckOutput<F> {
223 pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
224 BatchSumcheckOutput {
225 challenges: self.challenges.into_iter().map(Into::into).collect(),
226 multilinear_evals: self
227 .multilinear_evals
228 .into_iter()
229 .map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
230 .collect(),
231 }
232 }
233}
234
/// Builds a switchover heuristic parameterized by `k`.
///
/// The returned closure maps `extension_degree` to `floor(log2(extension_degree)) + k - 1`,
/// clamped below at `0`. Callers presumably use the result as the round index at which to
/// switch strategies; confirm at the call sites.
///
/// # Panics
///
/// The closure panics when called with `extension_degree == 0`, because `usize::ilog2` is
/// undefined for zero.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
    move |extension_degree: usize| {
        let exponent = extension_degree.ilog2() as isize + k;
        // Any exponent of 1 or less maps to round 0.
        if exponent <= 1 {
            0
        } else {
            exponent as usize - 1
        }
    }
}
242
/// A switchover heuristic that ignores the extension degree and always returns `0`.
///
/// It is a `const fn`, so it can also be evaluated in const contexts.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
    // Always round 0, i.e. (presumably) switch over immediately.
    usize::MIN
}
247
248pub fn equal_n_vars_check<'a, P, M>(
250 multilinears: impl IntoIterator<Item = &'a M>,
251) -> Result<usize, Error>
252where
253 P: PackedField,
254 M: MultilinearPoly<P> + 'a,
255{
256 let mut multilinears = multilinears.into_iter();
257 let n_vars = multilinears
258 .next()
259 .map(|multilinear| multilinear.n_vars())
260 .unwrap_or_default();
261 for multilinear in multilinears {
262 if multilinear.n_vars() != n_vars {
263 bail!(Error::NumberOfVariablesMismatch);
264 }
265 }
266 Ok(n_vars)
267}
268
269pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
274where
275 F: Field + ExtensionField<FBase>,
276 FBase: Field,
277 P: PackedField<Scalar = F>,
278 M: MultilinearPoly<P>,
279{
280 for multilinear in multilinears {
281 if multilinear.log_extension_degree() < F::LOG_DEGREE {
282 bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
283 }
284 }
285
286 Ok(())
287}
288
289pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
291 batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
293}
294
295pub fn interpolation_domains_for_composition_degrees<FDomain>(
297 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
298 degrees: impl IntoIterator<Item = usize>,
299) -> Result<Vec<InterpolationDomain<FDomain>>, Error>
300where
301 FDomain: Field,
302{
303 degrees
304 .into_iter()
305 .map(|degree| Ok(evaluation_domain_factory.create(degree + 1)?.into()))
306 .collect()
307}
308
309pub fn get_nontrivial_evaluation_points<F: Field>(
314 domains: &[InterpolationDomain<F>],
315) -> Result<Vec<F>, Error> {
316 let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
317 return Ok(Vec::new());
318 };
319
320 #[allow(clippy::get_first)]
321 if !domains.iter().all(|domain| {
322 (domain.size() <= 2 || domain.with_infinity())
323 && domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
324 && domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
325 }) {
326 bail!(Error::IncorrectSumcheckEvaluationDomain);
327 }
328
329 let finite_points = largest_domain.finite_points();
330
331 if domains
332 .iter()
333 .any(|domain| !finite_points.starts_with(domain.finite_points()))
334 {
335 bail!(Error::NonProperPrefixEvaluationDomain);
336 }
337
338 let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
339 Ok(nontrivial_evaluation_points)
340}
341
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    type F = BinaryField64b;

    #[test]
    fn test_round_coeffs_truncate_non_empty() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        let truncated = coeffs.truncate();
        assert_eq!(truncated.0 .0, vec![F::from(1), F::from(2)]);
    }

    #[test]
    fn test_round_coeffs_truncate_empty() {
        let coeffs = RoundCoeffs::<F>(vec![]);
        let truncated = coeffs.truncate();
        assert!(truncated.0 .0.is_empty());
    }

    /// `recover` must undo `truncate` when given the full polynomial's `r(0) + r(1)`.
    #[test]
    fn test_round_proof_recover_inverts_truncate() {
        let (a0, a1, a2) = (F::from(3), F::from(5), F::from(7));
        let coeffs = RoundCoeffs(vec![a0, a1, a2]);
        // r(0) = a0 and r(1) = a0 + a1 + a2.
        let sum = a0 + (a0 + a1 + a2);
        let recovered = coeffs.clone().truncate().recover(sum);
        assert_eq!(recovered, coeffs);
    }

    /// With no transmitted coefficients, the recovered polynomial is just `[sum]`.
    #[test]
    fn test_round_proof_recover_empty() {
        let proof = RoundProof(RoundCoeffs::<F>(vec![]));
        assert_eq!(proof.recover(F::from(9)).0, vec![F::from(9)]);
    }

    /// Adding a longer right-hand side zero-extends the left-hand side first.
    #[test]
    fn test_round_coeffs_add_extends_shorter_lhs() {
        let lhs = RoundCoeffs(vec![F::from(1)]);
        let rhs = RoundCoeffs(vec![F::from(2), F::from(4)]);
        assert_eq!(
            lhs + &rhs,
            RoundCoeffs(vec![F::from(1) + F::from(2), F::from(4)])
        );
    }
}