1use binius_field::{BinaryField, ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use itertools::izip;
8use tracing::instrument;
9
10use super::{
11 common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, GKRExpProverBuilder, LayerClaim},
12 compositions::IndexedExpComposition,
13 error::Error,
14 provers::{
15 CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, StaticExpProver,
16 },
17 witness::BaseExpWitness,
18};
19use crate::{
20 fiat_shamir::Challenger,
21 protocols::sumcheck::{
22 self, immediate_switchover_heuristic, BatchSumcheckOutput, CompositeSumClaim,
23 },
24 transcript::ProverTranscript,
25 witness::MultilinearWitness,
26};
27
28#[instrument(skip_all, name = "gkr_exp::batch_prove")]
45pub fn batch_prove<'a, F, P, FDomain, Challenger_, Backend>(
46 evaluation_order: EvaluationOrder,
47 witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P>>,
48 claims: &[ExpClaim<F>],
49 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
50 transcript: &mut ProverTranscript<Challenger_>,
51 backend: &Backend,
52) -> Result<BaseExpReductionOutput<F>, Error>
53where
54 F: ExtensionField<FDomain> + TowerField,
55 FDomain: Field,
56 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
57 Backend: ComputationBackend,
58 Challenger_: Challenger,
59{
60 let witnesses = witnesses.into_iter().collect::<Vec<_>>();
61
62 if witnesses.len() != claims.len() {
63 bail!(Error::MismatchedWitnessClaimLength);
64 }
65
66 let mut layers_claims = Vec::new();
67
68 if witnesses.is_empty() {
69 return Ok(BaseExpReductionOutput { layers_claims });
70 }
71
72 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
74 bail!(Error::ClaimsOutOfOrder);
75 }
76
77 let mut provers = make_provers(witnesses, claims)?;
78
79 let max_exponent_bit_number = provers
80 .iter()
81 .map(|p| p.exponent_bit_width())
82 .max()
83 .unwrap_or(0);
84
85 for layer_no in 0..max_exponent_bit_number {
86 let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
87 evaluation_order,
88 &mut provers,
89 layer_no,
90 evaluation_domain_factory.clone(),
91 backend,
92 )?;
93
94 let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
95
96 let layer_exponent_claims = build_layer_exponent_bit_claims(
97 evaluation_order,
98 &mut provers,
99 sumcheck_proof_output,
100 layer_no,
101 )?;
102
103 layers_claims.push(layer_exponent_claims);
104
105 provers.retain(|prover| !prover.is_last_layer(layer_no));
106 }
107
108 Ok(BaseExpReductionOutput { layers_claims })
109}
110
/// The GKR exponentiation sumcheck provers built for a single layer.
///
/// `build_layer_gkr_sumcheck_provers` returns this, and `batch_prove` passes it to
/// `sumcheck::batch_prove`. `F` is the scalar field used by the composition, and `P` is the
/// packed field of the multilinear witnesses.
type GKRExpProvers<'a, F, P, FDomain, Backend> =
    Vec<GKRExpProver<'a, FDomain, P, IndexedExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
113
114#[instrument(skip_all, level = "debug")]
116fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
117 evaluation_order: EvaluationOrder,
118 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
119 layer_no: usize,
120 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
121 backend: &'a Backend,
122) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
123where
124 FDomain: Field,
125 P: PackedField + PackedExtension<FDomain>,
126 P::Scalar: TowerField + ExtensionField<FDomain>,
127 Backend: ComputationBackend,
128{
129 assert!(!provers.is_empty());
130
131 let mut composite_claims = Vec::new();
132 let mut multilinears = Vec::new();
133
134 let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
135 let mut eval_points = vec![first_eval_point];
136
137 let mut active_index = 0;
138
139 for i in 0..provers.len() {
141 if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
142 let CompositeSumClaimsWithMultilinears {
143 composite_claims: eval_point_composite_claims,
144 multilinears: eval_point_multilinears,
145 } = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
146
147 if eval_point_composite_claims.is_empty() {
148 eval_points.pop();
150 } else {
151 composite_claims.push(eval_point_composite_claims);
152 multilinears.push(eval_point_multilinears);
153 }
154
155 eval_points.push(provers[i].layer_claim_eval_point().to_vec());
156 active_index = i;
157 }
158
159 if i == provers.len() - 1 {
160 let CompositeSumClaimsWithMultilinears {
161 composite_claims: eval_point_composite_claims,
162 multilinears: eval_point_multilinears,
163 } = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
164
165 if !eval_point_composite_claims.is_empty() {
166 composite_claims.push(eval_point_composite_claims);
167 multilinears.push(eval_point_multilinears);
168 }
169 }
170 }
171
172 izip!(composite_claims, multilinears, eval_points)
173 .map(|(composite_claims, multilinears, eval_point)| {
174 GKRExpProverBuilder::<'a, P, _, Backend>::with_switchover(
175 multilinears,
176 immediate_switchover_heuristic,
177 backend,
178 )?
179 .build(
180 evaluation_order,
181 &eval_point,
182 composite_claims,
183 evaluation_domain_factory.clone(),
184 )
185 })
186 .collect::<Result<Vec<_>, _>>()
187 .map_err(Error::from)
188}
189
/// Composite sum claims for one group of exponentiation provers that share an evaluation point,
/// together with the multilinears those claims are built over.
struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
    /// One claim for each prover in the group that produced one at this layer.
    composite_claims: Vec<CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>>,
    /// Every contributing prover's multilinears, concatenated in prover order.
    multilinears: Vec<MultilinearWitness<'a, P>>,
}
194
195fn build_eval_point_claims<'a, P>(
197 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
198 layer_no: usize,
199) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
200where
201 P: PackedField,
202{
203 let (composite_claims_n_multilinears, n_claims) =
204 provers
205 .iter()
206 .fold((0, 0), |(n_multilinears, n_claims), prover| {
207 let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
208 let layer_n_claims = prover.layer_n_claims(layer_no);
209
210 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
211 });
212
213 let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
214
215 let mut composite_claims = Vec::with_capacity(n_claims);
216
217 for prover in provers {
218 let multilinears_index = multilinears.len();
219
220 let meta = prover.layer_composite_sum_claim(
221 layer_no,
222 composite_claims_n_multilinears,
223 multilinears_index,
224 )?;
225
226 if let Some(meta) = meta {
227 let CompositeSumClaimWithMultilinears {
228 claim,
229 multilinears: this_layer_multilinears,
230 } = meta;
231
232 composite_claims.push(claim);
233
234 multilinears.extend(this_layer_multilinears);
235 }
236 }
237 Ok(CompositeSumClaimsWithMultilinears {
238 composite_claims,
239 multilinears,
240 })
241}
242
243fn build_layer_exponent_bit_claims<'a, P>(
245 evaluation_order: EvaluationOrder,
246 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
247 mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
248 layer_no: usize,
249) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
250where
251 P: PackedField,
252{
253 let mut eval_claims_on_exponent_bit_columns = Vec::new();
254
255 for multilinear_evals in &mut sumcheck_output.multilinear_evals {
257 multilinear_evals.pop();
258 }
259
260 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
261
262 for prover in provers {
263 let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
264
265 let this_prover_multilinear_evals = multilinear_evals
266 .by_ref()
267 .take(this_prover_n_multilinears)
268 .collect::<Vec<_>>();
269
270 let exponent_bit_claims = prover.finish_layer(
271 evaluation_order,
272 layer_no,
273 &this_prover_multilinear_evals,
274 &sumcheck_output.challenges,
275 );
276
277 eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
278 }
279
280 Ok(eval_claims_on_exponent_bit_columns)
281}
282
283fn make_provers<'a, P>(
285 witnesses: Vec<BaseExpWitness<'a, P>>,
286 claims: &[ExpClaim<P::Scalar>],
287) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
288where
289 P: PackedField,
290 P::Scalar: BinaryField,
291{
292 witnesses
293 .into_iter()
294 .zip(claims)
295 .map(|(witness, claim)| {
296 if witness.uses_dynamic_base() {
297 DynamicBaseExpProver::new(witness, claim)
298 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
299 } else {
300 StaticExpProver::<'a, P>::new(witness, claim)
301 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
302 }
303 })
304 .collect::<Result<Vec<_>, Error>>()
305}