1use std::{iter, sync::Arc};
4
5use binius_field::{PackedField, PackedFieldIndexable, TowerField};
6use binius_hal::{ComputationBackend, ComputationBackendExt};
7use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use tracing::instrument;
11
12use super::{
13 common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
14 eq_ind::RowBatchCoeffs,
15 error::Error,
16 tower_tensor_algebra::TowerTensorAlgebra,
17};
18use crate::{
19 fiat_shamir::{CanSample, Challenger},
20 piop::PIOPSumcheckClaim,
21 ring_switch::{common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd},
22 tower::{PackedTop, TowerFamily},
23 transcript::ProverTranscript,
24 witness::MultilinearWitness,
25};
26
/// Shorthand for the `B128` field of the tower family `Tower`.
///
/// The functions in this module that take a field parameter `F` pin it to this type through the
/// bound `Tower: TowerFamily<B128 = F>`.
type FExt<Tower> = <Tower as TowerFamily>::B128;
28
29#[derive(Debug)]
30pub struct ReducedWitness<P: PackedField> {
31 pub transparents: Vec<MultilinearWitness<'static, P>>,
32 pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
33}
34
35#[tracing::instrument("ring_switch::prove", skip_all)]
36pub fn prove<F, P, M, Tower, Challenger_, Backend>(
37 system: &EvalClaimSystem<F>,
38 witnesses: &[M],
39 transcript: &mut ProverTranscript<Challenger_>,
40 backend: &Backend,
41) -> Result<ReducedWitness<P>, Error>
42where
43 F: TowerField,
44 P: PackedFieldIndexable<Scalar = F>,
45 M: MultilinearPoly<P> + Sync,
46 Tower: TowerFamily<B128 = F>,
47 F: PackedTop<Tower>,
48 Challenger_: Challenger,
49 Backend: ComputationBackend,
50{
51 if witnesses.len() != system.commit_meta.total_multilins() {
52 return Err(Error::InvalidWitness(
53 "witness length does not match the number of multilinears".into(),
54 ));
55 }
56
57 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
60 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
61 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
62
63 let tensor_elems = compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, backend)?;
65 let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
66 let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
67 &scaled_tensor_elems,
68 &system.prefix_descs,
69 &system.eval_claim_to_prefix_desc_index,
70 )?;
71 let mut writer = transcript.message();
72 for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
73 debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
74 writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
75 }
76
77 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
79 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
80 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
81 ));
82
83 let row_batched_evals =
84 compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
85 transcript.message().write_scalar_slice(&row_batched_evals);
86
87 let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
89 &system.sumcheck_claim_descs,
90 &system.suffix_descs,
91 row_batch_coeffs,
92 &mixing_coeffs,
93 )?;
94 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
95 .enumerate()
96 .map(|(idx, (claim_desc, eval))| {
97 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
98 PIOPSumcheckClaim {
99 n_vars: suffix_desc.suffix.len(),
100 committed: claim_desc.committed_idx,
101 transparent: idx,
102 sum: eval,
103 }
104 })
105 .collect::<Vec<_>>();
106
107 Ok(ReducedWitness {
108 transparents: ring_switch_eq_inds,
109 sumcheck_claims,
110 })
111}
112
113#[instrument(skip_all)]
114fn compute_partial_evals<F, P, M, Tower, Backend>(
115 system: &EvalClaimSystem<F>,
116 witnesses: &[M],
117 backend: &Backend,
118) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
119where
120 F: TowerField,
121 P: PackedField<Scalar = F>,
122 M: MultilinearPoly<P> + Sync,
123 Tower: TowerFamily<B128 = F>,
124 Backend: ComputationBackend,
125{
126 let suffix_queries = system
127 .suffix_descs
128 .par_iter()
129 .map(|desc| backend.multilinear_query(&desc.suffix))
130 .collect::<Result<Vec<_>, _>>()?;
131
132 let tensor_elems = system
133 .sumcheck_claim_descs
134 .par_iter()
135 .map(
136 |PIOPSumcheckClaimDesc {
137 committed_idx,
138 suffix_desc_idx,
139 eval_claim: _,
140 }| {
141 let suffix = &system.suffix_descs[*suffix_desc_idx];
142 let suffix_query = &suffix_queries[*suffix_desc_idx];
143 let partial_eval =
144 witnesses[*committed_idx].evaluate_partial_high(suffix_query.to_ref())?;
145 TowerTensorAlgebra::new(
146 suffix.kappa,
147 PackedField::iter_slice(partial_eval.evals())
148 .take(1 << suffix.kappa)
149 .collect(),
150 )
151 },
152 )
153 .collect::<Result<Vec<_>, _>>()?;
154
155 Ok(tensor_elems)
156}
157
158fn scale_tensor_elems<F, Tower>(
159 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
160 mixing_coeffs: &[F],
161) -> Vec<TowerTensorAlgebra<Tower>>
162where
163 F: TowerField,
164 Tower: TowerFamily<B128 = F>,
165{
166 assert!(tensor_elems.len() <= mixing_coeffs.len());
168
169 tensor_elems
170 .into_par_iter()
171 .zip(mixing_coeffs)
172 .map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
173 .collect()
174}
175
176fn mix_tensor_elems_for_prefixes<F, Tower>(
177 scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
178 prefix_descs: &[EvalClaimPrefixDesc<F>],
179 eval_claim_to_prefix_desc_index: &[usize],
180) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
181where
182 F: TowerField,
183 Tower: TowerFamily<B128 = F>,
184{
185 assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
187
188 let mut batched_tensor_elems = prefix_descs
189 .iter()
190 .map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
191 .collect::<Result<Vec<_>, _>>()?;
192 for (tensor_elem, &desc_index) in
193 iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
194 {
195 let mixed_val = &mut batched_tensor_elems[desc_index];
196 debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
197 mixed_val.add_assign(tensor_elem)?;
198 }
199 Ok(batched_tensor_elems)
200}
201
202#[instrument(skip_all)]
203fn compute_row_batched_sumcheck_evals<F, Tower>(
204 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
205 row_batch_coeffs: &[F],
206) -> Vec<F>
207where
208 F: TowerField,
209 Tower: TowerFamily<B128 = F>,
210 F: PackedTop<Tower>,
211{
212 tensor_elems
213 .into_par_iter()
214 .map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
215 .collect()
216}
217
218#[instrument(skip_all)]
219fn make_ring_switch_eq_inds<F, P, Tower>(
220 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
221 suffix_descs: &[EvalClaimSuffixDesc<F>],
222 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
223 mixing_coeffs: &[F],
224) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
225where
226 F: TowerField,
227 P: PackedFieldIndexable<Scalar = F>,
228 Tower: TowerFamily<B128 = F>,
229 F: PackedTop<Tower>,
230{
231 sumcheck_claim_descs
232 .par_iter()
233 .zip(mixing_coeffs)
234 .map(|(claim_desc, &mixing_coeff)| {
235 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
236 make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
237 })
238 .collect()
239}
240
241fn make_ring_switch_eq_ind<P, Tower>(
242 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
243 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
244 mixing_coeff: FExt<Tower>,
245) -> Result<MultilinearWitness<'static, P>, Error>
246where
247 P: PackedFieldIndexable<Scalar = FExt<Tower>>,
248 Tower: TowerFamily,
249{
250 let eq_ind = match suffix_desc.kappa {
251 7 => RingSwitchEqInd::<Tower::B1, _>::new(
252 suffix_desc.suffix.clone(),
253 row_batch_coeffs,
254 mixing_coeff,
255 )?
256 .multilinear_extension::<P>(),
257 4 => RingSwitchEqInd::<Tower::B8, _>::new(
258 suffix_desc.suffix.clone(),
259 row_batch_coeffs,
260 mixing_coeff,
261 )?
262 .multilinear_extension(),
263 3 => RingSwitchEqInd::<Tower::B16, _>::new(
264 suffix_desc.suffix.clone(),
265 row_batch_coeffs,
266 mixing_coeff,
267 )?
268 .multilinear_extension(),
269 2 => RingSwitchEqInd::<Tower::B32, _>::new(
270 suffix_desc.suffix.clone(),
271 row_batch_coeffs,
272 mixing_coeff,
273 )?
274 .multilinear_extension(),
275 1 => RingSwitchEqInd::<Tower::B64, _>::new(
276 suffix_desc.suffix.clone(),
277 row_batch_coeffs,
278 mixing_coeff,
279 )?
280 .multilinear_extension(),
281 0 => RingSwitchEqInd::<Tower::B128, _>::new(
282 suffix_desc.suffix.clone(),
283 row_batch_coeffs,
284 mixing_coeff,
285 )?
286 .multilinear_extension(),
287 _ => Err(Error::PackingDegreeNotSupported {
288 kappa: suffix_desc.kappa,
289 }),
290 }?;
291 Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
292}