binius_core/protocols/sumcheck/prove/
front_loaded.rs1use std::{collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_utils::sorting::is_sorted_ascending;
7use bytes::BufMut;
8
9use super::{batch_sumcheck::SumcheckProver, logging::PIOPCompilerFoldData};
10use crate::{
11 fiat_shamir::{CanSample, Challenger},
12 protocols::sumcheck::{BatchSumcheckOutput, Error, RoundCoeffs},
13 transcript::{ProverTranscript, TranscriptWriter},
14};
15
16#[derive(Debug)]
33pub struct BatchProver<F: Field, Prover> {
34 provers: VecDeque<(Prover, F)>,
35 multilinear_evals: Vec<Vec<F>>,
36 round: usize,
37}
38
39impl<F, Prover> BatchProver<F, Prover>
40where
41 F: TowerField,
42 Prover: SumcheckProver<F>,
43{
44 pub fn new<Transcript>(provers: Vec<Prover>, transcript: &mut Transcript) -> Result<Self, Error>
52 where
53 Transcript: CanSample<F>,
54 {
55 let batch_coeffs = transcript.sample_vec(provers.len());
57
58 Self::new_prebatched(batch_coeffs, provers)
59 }
60
61 pub fn total_rounds(&self) -> usize
63 where
64 Prover: SumcheckProver<F>,
65 {
66 self.provers
67 .back()
68 .map(|(prover, _)| prover.n_vars())
69 .unwrap_or(0)
70 }
71
72 pub fn new_prebatched(batch_coeffs: Vec<F>, provers: Vec<Prover>) -> Result<Self, Error> {
79 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
80 return Err(Error::ClaimsOutOfOrder);
81 }
82
83 if batch_coeffs.len() != provers.len() {
84 return Err(Error::IncorrectNumberOfBatchCoeffs);
85 }
86
87 if let Some(first_prover) = provers.first() {
88 if provers
89 .iter()
90 .any(|prover| prover.evaluation_order() != first_prover.evaluation_order())
91 {
92 return Err(Error::InconsistentEvaluationOrder);
93 }
94 }
95
96 let provers = iter::zip(provers, batch_coeffs).collect();
97
98 Ok(Self {
99 provers,
100 multilinear_evals: Vec::new(),
101 round: 0,
102 })
103 }
104
105 fn finish_claim_provers<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
106 where
107 B: BufMut,
108 {
109 while let Some((prover, _)) = self.provers.front() {
110 if prover.n_vars() != self.round {
111 break;
112 }
113 let (prover, _) = self.provers.pop_front().expect("front returned Some");
114 let claim_multilinear_evals = Box::new(prover).finish()?;
115 transcript.write_scalar_slice(&claim_multilinear_evals);
116 self.multilinear_evals.push(claim_multilinear_evals);
117 }
118 Ok(())
119 }
120
121 pub fn send_round_proof<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
123 where
124 B: BufMut,
125 {
126 self.finish_claim_provers(transcript)?;
127
128 let mut round_coeffs = RoundCoeffs::default();
129 for (prover, batch_coeff) in &mut self.provers {
130 let prover_coeffs = prover.execute(*batch_coeff)?;
131 round_coeffs += &(prover_coeffs * *batch_coeff);
132 }
133
134 let round_proof = round_coeffs.truncate();
135 transcript.write_scalar_slice(round_proof.coeffs());
136 Ok(())
137 }
138
139 pub fn receive_challenge(&mut self, challenge: F) -> Result<(), Error> {
141 for (prover, _) in &mut self.provers {
142 let dimensions_data = PIOPCompilerFoldData::new(prover);
143 let _span = tracing::debug_span!(
144 "[task] (PIOP Compiler) Fold",
145 phase = "piop_compiler",
146 round = self.round,
147 dimensions_data = ?dimensions_data,
148 )
149 .entered();
150 prover.fold(challenge)?;
151 }
152 self.round += 1;
153 Ok(())
154 }
155
156 pub fn finish<B>(mut self, transcript: &mut TranscriptWriter<B>) -> Result<Vec<Vec<F>>, Error>
158 where
159 B: BufMut,
160 {
161 self.finish_claim_provers(transcript)?;
162 if !self.provers.is_empty() {
163 return Err(Error::ExpectedFold);
164 }
165
166 Ok(self.multilinear_evals)
167 }
168
169 pub fn provers(&self) -> impl Iterator<Item = &Prover> {
171 self.provers.iter().map(|(prover, _)| prover)
172 }
173
174 pub fn run<Challenger_: Challenger>(
176 mut self,
177 transcript: &mut ProverTranscript<Challenger_>,
178 ) -> Result<BatchSumcheckOutput<F>, Error> {
179 let round_count = self.total_rounds();
180
181 let mut challenges = Vec::with_capacity(round_count);
182 for _round_no in 0..round_count {
183 self.send_round_proof(&mut transcript.message())?;
184
185 let challenge = transcript.sample();
186 challenges.push(challenge);
187
188 self.receive_challenge(challenge)?;
189 }
190
191 let multilinear_evals = self.finish(&mut transcript.message())?;
192
193 Ok(BatchSumcheckOutput {
194 challenges,
195 multilinear_evals,
196 })
197 }
198}