binius_core/protocols/sumcheck/prove/
batch_sumcheck.rs1use binius_field::{Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use itertools::izip;
7use tracing::instrument;
8
9use crate::{
10 fiat_shamir::{CanSample, Challenger},
11 protocols::sumcheck::{
12 common::{BatchSumcheckOutput, RoundCoeffs},
13 error::Error,
14 },
15 transcript::ProverTranscript,
16};
17
/// A prover for one sumcheck claim, driven round by round by [`batch_prove`].
///
/// [`batch_prove`] calls these methods in a fixed pattern: it reads `n_vars` and
/// `evaluation_order` before any round runs, then for every round in which this prover
/// is active it calls `execute` followed by `fold`, and finally consumes the prover
/// with `finish`.
pub trait SumcheckProver<F: Field> {
    /// Number of variables of this prover's claim.
    ///
    /// [`batch_prove`] requires the provers it is given to be sorted by this value in
    /// descending order. A prover takes part in exactly this many rounds, namely the
    /// last `n_vars` rounds of the batch. A prover reporting `0` never executes a round.
    fn n_vars(&self) -> usize;

    /// The evaluation order used by this prover.
    ///
    /// All provers passed to one [`batch_prove`] call must report the same value.
    /// When that value is `EvaluationOrder::HighToLow`, the challenges returned by
    /// [`batch_prove`] are reversed relative to the order in which they were sampled.
    fn evaluation_order(&self) -> EvaluationOrder;

    /// Computes this prover's coefficients for the current round.
    ///
    /// `batch_coeff` is the batching coefficient that [`batch_prove`] sampled for this
    /// prover when it joined the batch.
    ///
    /// NOTE(review): [`batch_prove`] itself multiplies the returned coefficients by
    /// `batch_coeff` before summing them with the other provers' coefficients. The
    /// result therefore presumably should not already be scaled; confirm against the
    /// implementations.
    fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;

    /// Folds the prover with the round challenge sampled after [`Self::execute`].
    ///
    /// Called once per round in which the prover is active. It presumably binds one
    /// variable to `challenge`, leaving one fewer for the next round.
    fn fold(&mut self, challenge: F) -> Result<(), Error>;

    /// Consumes the prover once all of its rounds are done and returns its final
    /// multilinear evaluations.
    ///
    /// [`batch_prove`] writes these evaluations to the transcript and returns them in
    /// `BatchSumcheckOutput::multilinear_evals`, in prover order.
    fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
}
67
68impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
70 fn n_vars(&self) -> usize {
71 (**self).n_vars()
72 }
73
74 fn evaluation_order(&self) -> EvaluationOrder {
75 (**self).evaluation_order()
76 }
77
78 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
79 (**self).execute(batch_coeff)
80 }
81
82 fn fold(&mut self, challenge: F) -> Result<(), Error> {
83 (**self).fold(challenge)
84 }
85
86 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
87 (*self).finish()
88 }
89}
90
91#[instrument(skip_all, name = "sumcheck::batch_prove")]
100pub fn batch_prove<F, Prover, Challenger_>(
101 mut provers: Vec<Prover>,
102 transcript: &mut ProverTranscript<Challenger_>,
103) -> Result<BatchSumcheckOutput<F>, Error>
104where
105 F: TowerField,
106 Prover: SumcheckProver<F>,
107 Challenger_: Challenger,
108{
109 let Some(first_prover) = provers.first() else {
110 return Ok(BatchSumcheckOutput {
111 challenges: Vec::new(),
112 multilinear_evals: Vec::new(),
113 });
114 };
115
116 let evaluation_order = first_prover.evaluation_order();
117
118 if provers
119 .iter()
120 .any(|prover| prover.evaluation_order() != evaluation_order)
121 {
122 bail!(Error::InconsistentEvaluationOrder);
123 }
124
125 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
127 bail!(Error::ClaimsOutOfOrder);
128 }
129
130 let n_rounds = provers
131 .iter()
132 .map(|prover| prover.n_vars())
133 .max()
134 .unwrap_or(0);
135
136 let mut batch_coeffs = Vec::with_capacity(provers.len());
137 let mut challenges = Vec::with_capacity(n_rounds);
138 for round_no in 0..n_rounds {
139 let n_vars = n_rounds - round_no;
140
141 while let Some(prover) = provers.get(batch_coeffs.len()) {
143 if prover.n_vars() != n_vars {
144 break;
145 }
146
147 let next_batch_coeff = transcript.sample();
148 batch_coeffs.push(next_batch_coeff);
149 }
150
151 let mut round_coeffs = RoundCoeffs::default();
153 for (&batch_coeff, prover) in izip!(&batch_coeffs, &mut provers) {
154 let prover_coeffs = prover.execute(batch_coeff)?;
155 round_coeffs += &(prover_coeffs * batch_coeff);
156 }
157
158 let round_proof = round_coeffs.truncate();
159 transcript
160 .message()
161 .write_scalar_slice(round_proof.coeffs());
162
163 let challenge = transcript.sample();
164 challenges.push(challenge);
165
166 for prover in &mut provers[..batch_coeffs.len()] {
167 prover.fold(challenge)?;
168 }
169 }
170
171 for prover in &provers[batch_coeffs.len()..] {
173 debug_assert_eq!(prover.n_vars(), 0);
174 let _next_batch_coeff: F = transcript.sample();
175 }
176
177 let multilinear_evals = provers
178 .into_iter()
179 .map(|prover| Box::new(prover).finish())
180 .collect::<Result<Vec<_>, _>>()?;
181
182 let mut writer = transcript.message();
183 for multilinear_evals in &multilinear_evals {
184 writer.write_scalar_slice(multilinear_evals);
185 }
186
187 if EvaluationOrder::HighToLow == evaluation_order {
188 challenges.reverse();
189 }
190
191 let output = BatchSumcheckOutput {
192 challenges,
193 multilinear_evals,
194 };
195
196 Ok(output)
197}