binius_core/protocols/sumcheck/prove/
batch_prove.rs1use std::iter;
4
5use binius_field::{Field, TowerField};
6use binius_math::EvaluationOrder;
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use tracing::instrument;
9
10use crate::{
11 fiat_shamir::{CanSample, Challenger},
12 protocols::sumcheck::{
13 common::{BatchSumcheckOutput, RoundCoeffs},
14 error::Error,
15 },
16 transcript::ProverTranscript,
17};
18
19pub trait SumcheckProver<F: Field> {
40 fn n_vars(&self) -> usize;
42
43 fn evaluation_order(&self) -> EvaluationOrder;
45
46 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;
60
61 fn fold(&mut self, challenge: F) -> Result<(), Error>;
63
64 fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
67}
68
69impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
71 fn n_vars(&self) -> usize {
72 (**self).n_vars()
73 }
74
75 fn evaluation_order(&self) -> EvaluationOrder {
76 (**self).evaluation_order()
77 }
78
79 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
80 (**self).execute(batch_coeff)
81 }
82
83 fn fold(&mut self, challenge: F) -> Result<(), Error> {
84 (**self).fold(challenge)
85 }
86
87 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
88 (*self).finish()
89 }
90}
91
92#[instrument(skip_all, name = "sumcheck::batch_prove")]
101pub fn batch_prove<F, Prover, Challenger_>(
102 provers: Vec<Prover>,
103 transcript: &mut ProverTranscript<Challenger_>,
104) -> Result<BatchSumcheckOutput<F>, Error>
105where
106 F: TowerField,
107 Prover: SumcheckProver<F>,
108 Challenger_: Challenger,
109{
110 let start = BatchProveStart {
111 batch_coeffs: Vec::new(),
112 reduction_provers: Vec::<Prover>::new(),
113 };
114
115 batch_prove_with_start(start, provers, transcript)
116}
117
118#[derive(Debug)]
120pub struct BatchProveStart<F: Field, Prover> {
121 pub batch_coeffs: Vec<F>,
123 pub reduction_provers: Vec<Prover>,
125}
126
127#[instrument(skip_all, name = "sumcheck::batch_prove")]
129pub fn batch_prove_with_start<F, Prover, Challenger_>(
130 start: BatchProveStart<F, Prover>,
131 mut provers: Vec<Prover>,
132 transcript: &mut ProverTranscript<Challenger_>,
133) -> Result<BatchSumcheckOutput<F>, Error>
134where
135 F: TowerField,
136 Prover: SumcheckProver<F>,
137 Challenger_: Challenger,
138{
139 let BatchProveStart {
140 mut batch_coeffs,
141 reduction_provers,
142 } = start;
143
144 provers.splice(0..0, reduction_provers);
145
146 let Some(first_prover) = provers.first() else {
147 return Ok(BatchSumcheckOutput {
148 challenges: Vec::new(),
149 multilinear_evals: Vec::new(),
150 });
151 };
152
153 let evaluation_order = first_prover.evaluation_order();
154
155 if provers
156 .iter()
157 .any(|prover| prover.evaluation_order() != evaluation_order)
158 {
159 bail!(Error::InconsistentEvaluationOrder);
160 }
161
162 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
164 bail!(Error::ClaimsOutOfOrder);
165 }
166
167 if batch_coeffs.len() > provers.len() {
168 bail!(Error::TooManyPrebatchedCoeffs);
169 }
170
171 let n_rounds = provers
172 .iter()
173 .map(|prover| prover.n_vars())
174 .max()
175 .unwrap_or(0);
176
177 let mut active_index = batch_coeffs.len();
179 let mut challenges = Vec::with_capacity(n_rounds);
180 for round_no in 0..n_rounds {
181 let n_vars = n_rounds - round_no;
182
183 while let Some(prover) = provers.get(active_index) {
185 if prover.n_vars() != n_vars {
186 break;
187 }
188
189 let next_batch_coeff: F = transcript.sample();
190 batch_coeffs.push(next_batch_coeff);
191 active_index += 1;
192 }
193
194 let mut round_coeffs = RoundCoeffs::default();
196 for (&batch_coeff, prover) in
197 iter::zip(batch_coeffs.iter(), provers[..active_index].iter_mut())
198 {
199 let prover_coeffs = prover.execute(batch_coeff)?;
200 round_coeffs += &(prover_coeffs * batch_coeff);
201 }
202
203 let round_proof = round_coeffs.truncate();
204 transcript
205 .message()
206 .write_scalar_slice(round_proof.coeffs());
207
208 let challenge = transcript.sample();
209 challenges.push(challenge);
210
211 for prover in &mut provers[..active_index] {
212 prover.fold(challenge)?;
213 }
214 }
215
216 while let Some(prover) = provers.get(active_index) {
218 debug_assert_eq!(prover.n_vars(), 0);
219
220 let _next_batch_coeff: F = transcript.sample();
221 active_index += 1;
222 }
223
224 let multilinear_evals = provers
225 .into_iter()
226 .map(|prover| Box::new(prover).finish())
227 .collect::<Result<Vec<_>, _>>()?;
228
229 let mut writer = transcript.message();
230 for multilinear_evals in &multilinear_evals {
231 writer.write_scalar_slice(multilinear_evals);
232 }
233
234 if EvaluationOrder::HighToLow == evaluation_order {
235 challenges.reverse();
236 }
237
238 let output = BatchSumcheckOutput {
239 challenges,
240 multilinear_evals,
241 };
242
243 Ok(output)
244}