binius_core/protocols/sumcheck/
front_loaded.rs1use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_math::{evaluate_univariate, CompositionPoly};
7use binius_utils::sorting::is_sorted_ascending;
8use bytes::Buf;
9
10use super::{
11 common::batch_weighted_value,
12 error::{Error, VerificationError},
13 verify::compute_expected_batch_composite_evaluation_single_claim,
14 RoundCoeffs, RoundProof,
15};
16use crate::{
17 fiat_shamir::CanSample, protocols::sumcheck::SumcheckClaim, transcript::TranscriptReader,
18};
19
20#[derive(Debug)]
21enum CoeffsOrSums<F: Field> {
22 Coeffs(RoundCoeffs<F>),
23 Sum(F),
24}
25
26#[derive(Debug)]
54pub struct BatchVerifier<F: Field, C> {
55 claims: VecDeque<SumcheckClaimWithContext<F, C>>,
56 round: usize,
57 last_coeffs_or_sum: CoeffsOrSums<F>,
58}
59
60impl<F, C> BatchVerifier<F, C>
61where
62 F: TowerField,
63 C: CompositionPoly<F> + Clone,
64{
65 pub fn new<Transcript>(
73 claims: &[SumcheckClaim<F, C>],
74 transcript: &mut Transcript,
75 ) -> Result<Self, Error>
76 where
77 Transcript: CanSample<F>,
78 {
79 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
80 return Err(Error::ClaimsOutOfOrder);
81 }
82
83 let batch_coeffs = transcript.sample_vec(claims.len());
85
86 let sum = iter::zip(claims, &batch_coeffs)
88 .map(|(claim, &batch_coeff)| {
89 batch_weighted_value(
90 batch_coeff,
91 claim
92 .composite_sums()
93 .iter()
94 .map(|composite_claim| composite_claim.sum),
95 )
96 })
97 .sum();
98
99 let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
100 .map(|(claim, batch_coeff)| {
101 let degree = claim
102 .composite_sums()
103 .iter()
104 .map(|composite_claim| composite_claim.composition.degree())
105 .max()
106 .unwrap_or(0);
107 SumcheckClaimWithContext {
108 claim,
109 batch_coeff,
110 max_degree_remaining: degree,
111 }
112 })
113 .collect::<VecDeque<_>>();
114
115 for i in (0..claims.len()).rev().skip(1) {
117 claims[i].max_degree_remaining =
118 cmp::max(claims[i].max_degree_remaining, claims[i + 1].max_degree_remaining);
119 }
120
121 Ok(Self {
122 claims,
123 round: 0,
124 last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
125 })
126 }
127
128 pub fn remaining_claims(&self) -> usize {
130 self.claims.len()
131 }
132
133 pub fn try_finish_claim<B>(
135 &mut self,
136 transcript: &mut TranscriptReader<B>,
137 ) -> Result<Option<Vec<F>>, Error>
138 where
139 B: Buf,
140 {
141 let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else {
142 return Ok(None);
143 };
144 let multilinear_evals = match claim.n_vars().cmp(&self.round) {
145 Ordering::Equal => {
146 let SumcheckClaimWithContext {
147 claim, batch_coeff, ..
148 } = self.claims.pop_front().expect("front returned Some");
149 let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
150
151 match self.last_coeffs_or_sum {
152 CoeffsOrSums::Coeffs(_) => {
153 return Err(Error::ExpectedFinishRound);
154 }
155 CoeffsOrSums::Sum(ref mut sum) => {
156 *sum -= compute_expected_batch_composite_evaluation_single_claim(
161 batch_coeff,
162 &claim,
163 &multilinear_evals,
164 )?;
165 }
166 }
167 Some(multilinear_evals)
168 }
169 Ordering::Less => {
170 unreachable!(
171 "round is incremented in finish_round; \
172 finish_round does not increment round until receive_round_proof is called; \
173 receive_round_proof errors unless the claim at the active index has enough \
174 variables"
175 );
176 }
177 Ordering::Greater => None,
178 };
179 Ok(multilinear_evals)
180 }
181
182 pub fn receive_round_proof<B>(
184 &mut self,
185 transcript: &mut TranscriptReader<B>,
186 ) -> Result<(), Error>
187 where
188 B: Buf,
189 {
190 match self.last_coeffs_or_sum {
191 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
192 CoeffsOrSums::Sum(sum) => {
193 let degree = match self.claims.front() {
194 Some(SumcheckClaimWithContext {
195 claim,
196 max_degree_remaining,
197 ..
198 }) => {
199 if claim.n_vars() == self.round {
201 return Err(Error::ExpectedFinishClaim);
202 }
203 *max_degree_remaining
204 }
205 None => 0,
206 };
207
208 let proof_vals = transcript.read_scalar_slice(degree)?;
209 let round_proof = RoundProof(RoundCoeffs(proof_vals));
210 self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
211 Ok(())
212 }
213 }
214 }
215
216 pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
218 match self.last_coeffs_or_sum {
219 CoeffsOrSums::Coeffs(ref coeffs) => {
220 let sum = evaluate_univariate(&coeffs.0, challenge);
221 self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
222 self.round += 1;
223 Ok(())
224 }
225 CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
226 }
227 }
228
229 pub fn finish(self) -> Result<(), Error> {
231 if !self.claims.is_empty() {
232 return Err(Error::ExpectedFinishRound);
233 }
234
235 match self.last_coeffs_or_sum {
236 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
237 CoeffsOrSums::Sum(sum) => {
238 if sum != F::ZERO {
239 return Err(VerificationError::IncorrectBatchEvaluation.into());
240 }
241 Ok(())
242 }
243 }
244 }
245}
246
247#[derive(Debug)]
248struct SumcheckClaimWithContext<F: Field, C> {
249 claim: SumcheckClaim<F, C>,
250 batch_coeff: F,
251 max_degree_remaining: usize,
252}