1use binius_field::{BinaryField, ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use itertools::izip;
8use tracing::instrument;
9
10use super::{
11 common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, GKRExpProverBuilder, LayerClaim},
12 compositions::IndexedExpComposition,
13 error::Error,
14 provers::{
15 CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, StaticExpProver,
16 },
17 witness::BaseExpWitness,
18};
19use crate::{
20 fiat_shamir::Challenger,
21 protocols::sumcheck::{
22 self, BatchSumcheckOutput, CompositeSumClaim, immediate_switchover_heuristic,
23 },
24 transcript::ProverTranscript,
25 witness::MultilinearWitness,
26};
27
28#[instrument(skip_all, name = "gkr_exp::batch_prove")]
46pub fn batch_prove<'a, F, P, FDomain, Challenger_, Backend>(
47 evaluation_order: EvaluationOrder,
48 witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P>>,
49 claims: &[ExpClaim<F>],
50 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
51 transcript: &mut ProverTranscript<Challenger_>,
52 backend: &Backend,
53) -> Result<BaseExpReductionOutput<F>, Error>
54where
55 F: ExtensionField<FDomain> + TowerField,
56 FDomain: Field,
57 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
58 Backend: ComputationBackend,
59 Challenger_: Challenger,
60{
61 let witnesses = witnesses.into_iter().collect::<Vec<_>>();
62
63 if witnesses.len() != claims.len() {
64 bail!(Error::MismatchedWitnessClaimLength);
65 }
66
67 let mut layers_claims = Vec::new();
68
69 if witnesses.is_empty() {
70 return Ok(BaseExpReductionOutput { layers_claims });
71 }
72
73 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
75 bail!(Error::ClaimsOutOfOrder);
76 }
77
78 let mut provers = make_provers(witnesses, claims)?;
79
80 let max_exponent_bit_number = provers
81 .iter()
82 .map(|p| p.exponent_bit_width())
83 .max()
84 .unwrap_or(0);
85
86 for layer_no in 0..max_exponent_bit_number {
87 let _layer_span = tracing::info_span!(
88 "[task] GKR Exp Layer Sumcheck",
89 phase = "exp",
90 perfetto_category = "task.main"
91 )
92 .entered();
93 let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
94 evaluation_order,
95 &mut provers,
96 layer_no,
97 evaluation_domain_factory.clone(),
98 backend,
99 )?;
100
101 let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
102
103 let layer_exponent_claims = build_layer_exponent_bit_claims(
104 evaluation_order,
105 &mut provers,
106 sumcheck_proof_output,
107 layer_no,
108 )?;
109
110 layers_claims.push(layer_exponent_claims);
111
112 provers.retain(|prover| !prover.is_last_layer(layer_no));
113 }
114
115 Ok(BaseExpReductionOutput { layers_claims })
116}
117
118type GKRExpProvers<'a, F, P, FDomain, Backend> =
119 Vec<GKRExpProver<'a, FDomain, P, IndexedExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
120
121#[instrument(skip_all, level = "debug")]
123fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
124 evaluation_order: EvaluationOrder,
125 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
126 layer_no: usize,
127 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
128 backend: &'a Backend,
129) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
130where
131 FDomain: Field,
132 P: PackedField + PackedExtension<FDomain>,
133 P::Scalar: TowerField + ExtensionField<FDomain>,
134 Backend: ComputationBackend,
135{
136 assert!(!provers.is_empty());
137
138 let mut composite_claims = Vec::new();
139 let mut multilinears = Vec::new();
140
141 let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
142 let mut eval_points = vec![first_eval_point];
143
144 let mut active_index = 0;
145
146 for i in 0..provers.len() {
148 if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
149 let CompositeSumClaimsWithMultilinears {
150 composite_claims: eval_point_composite_claims,
151 multilinears: eval_point_multilinears,
152 } = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
153
154 if eval_point_composite_claims.is_empty() {
155 eval_points.pop();
158 } else {
159 composite_claims.push(eval_point_composite_claims);
160 multilinears.push(eval_point_multilinears);
161 }
162
163 eval_points.push(provers[i].layer_claim_eval_point().to_vec());
164 active_index = i;
165 }
166
167 if i == provers.len() - 1 {
168 let CompositeSumClaimsWithMultilinears {
169 composite_claims: eval_point_composite_claims,
170 multilinears: eval_point_multilinears,
171 } = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
172
173 if !eval_point_composite_claims.is_empty() {
174 composite_claims.push(eval_point_composite_claims);
175 multilinears.push(eval_point_multilinears);
176 }
177 }
178 }
179
180 izip!(composite_claims, multilinears, eval_points)
181 .map(|(composite_claims, multilinears, eval_point)| {
182 GKRExpProverBuilder::<'a, P, _, Backend>::with_switchover(
183 multilinears,
184 immediate_switchover_heuristic,
185 backend,
186 )?
187 .build(
188 evaluation_order,
189 &eval_point,
190 composite_claims,
191 evaluation_domain_factory.clone(),
192 )
193 })
194 .collect::<Result<Vec<_>, _>>()
195 .map_err(Error::from)
196}
197
198struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
199 composite_claims: Vec<CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>>,
200 multilinears: Vec<MultilinearWitness<'a, P>>,
201}
202
203fn build_eval_point_claims<'a, P>(
206 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
207 layer_no: usize,
208) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
209where
210 P: PackedField,
211{
212 let (composite_claims_n_multilinears, n_claims) =
213 provers
214 .iter()
215 .fold((0, 0), |(n_multilinears, n_claims), prover| {
216 let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
217 let layer_n_claims = prover.layer_n_claims(layer_no);
218
219 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
220 });
221
222 let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
223
224 let mut composite_claims = Vec::with_capacity(n_claims);
225
226 for prover in provers {
227 let multilinears_index = multilinears.len();
228
229 let meta = prover.layer_composite_sum_claim(
230 layer_no,
231 composite_claims_n_multilinears,
232 multilinears_index,
233 )?;
234
235 if let Some(meta) = meta {
236 let CompositeSumClaimWithMultilinears {
237 claim,
238 multilinears: this_layer_multilinears,
239 } = meta;
240
241 composite_claims.push(claim);
242
243 multilinears.extend(this_layer_multilinears);
244 }
245 }
246 Ok(CompositeSumClaimsWithMultilinears {
247 composite_claims,
248 multilinears,
249 })
250}
251
252fn build_layer_exponent_bit_claims<'a, P>(
255 evaluation_order: EvaluationOrder,
256 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
257 mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
258 layer_no: usize,
259) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
260where
261 P: PackedField,
262{
263 let mut eval_claims_on_exponent_bit_columns = Vec::new();
264
265 for multilinear_evals in &mut sumcheck_output.multilinear_evals {
267 multilinear_evals.pop();
268 }
269
270 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
271
272 for prover in provers {
273 let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
274
275 let this_prover_multilinear_evals = multilinear_evals
276 .by_ref()
277 .take(this_prover_n_multilinears)
278 .collect::<Vec<_>>();
279
280 let exponent_bit_claims = prover.finish_layer(
281 evaluation_order,
282 layer_no,
283 &this_prover_multilinear_evals,
284 &sumcheck_output.challenges,
285 );
286
287 eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
288 }
289
290 Ok(eval_claims_on_exponent_bit_columns)
291}
292
293fn make_provers<'a, P>(
295 witnesses: Vec<BaseExpWitness<'a, P>>,
296 claims: &[ExpClaim<P::Scalar>],
297) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
298where
299 P: PackedField,
300 P::Scalar: BinaryField,
301{
302 witnesses
303 .into_iter()
304 .zip(claims)
305 .map(|(witness, claim)| {
306 if witness.uses_dynamic_base() {
307 DynamicBaseExpProver::new(witness, claim)
308 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
309 } else {
310 StaticExpProver::<'a, P>::new(witness, claim)
311 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
312 }
313 })
314 .collect::<Result<Vec<_>, Error>>()
315}