1use std::{iter, sync::Arc};
4
5use binius_field::{PackedField, PackedFieldIndexable, TowerField};
6use binius_hal::ComputationBackend;
7use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use tracing::instrument;
11
12use super::{
13 common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
14 eq_ind::RowBatchCoeffs,
15 error::Error,
16 tower_tensor_algebra::TowerTensorAlgebra,
17};
18use crate::{
19 fiat_shamir::{CanSample, Challenger},
20 piop::PIOPSumcheckClaim,
21 protocols::evalcheck::subclaims::MemoizedData,
22 ring_switch::{common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd},
23 tower::{PackedTop, TowerFamily},
24 transcript::ProverTranscript,
25 witness::MultilinearWitness,
26};
27
/// Shorthand for the `B128` field of the tower family `Tower`; it is the scalar field of the
/// ring-switch equality indicators built by `make_ring_switch_eq_ind`.
type FExt<Tower> = <Tower as TowerFamily>::B128;
29
/// Prover-side output of the ring-switching reduction.
///
/// For the claim at position `i` in `sumcheck_claims`, `prove` sets `transparent: i`. That
/// value indexes the matching entry of `transparents`.
#[derive(Debug)]
pub struct ReducedWitness<P: PackedField> {
    /// Ring-switch equality-indicator multilinears, one per sumcheck claim, in claim order.
    pub transparents: Vec<MultilinearWitness<'static, P>>,
    /// Sumcheck claims produced by the reduction. Each one records the number of variables,
    /// the index of a committed multilinear, the index of a transparent multilinear, and the
    /// claimed sum.
    pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
}
35
/// Prover side of the ring-switching reduction.
///
/// Reduces the evaluation claims described by `system` over the committed `witnesses` to one
/// sumcheck claim per entry of `system.sumcheck_claim_descs`. The transparent multilinear of
/// claim `i` is entry `i` of the returned `transparents`.
///
/// Transcript operations happen in this order. The verifier presumably replays the same order,
/// so it must not be changed:
/// 1. sample `log2_ceil(#sumcheck claims)` mixing challenges;
/// 2. write the vertical elements of each prefix descriptor's mixed tensor-algebra element;
/// 3. sample `system.max_claim_kappa()` row-batching challenges;
/// 4. write the row-batched evaluation of every sumcheck claim.
///
/// # Errors
///
/// Returns [`Error::InvalidWitness`] if `witnesses.len()` differs from
/// `system.commit_meta.total_multilins()`, and propagates errors from the helper steps.
#[tracing::instrument("ring_switch::prove", skip_all)]
pub fn prove<F, P, M, Tower, Challenger_, Backend>(
    system: &EvalClaimSystem<F>,
    witnesses: &[M],
    transcript: &mut ProverTranscript<Challenger_>,
    memoized_data: MemoizedData<P, Backend>,
    backend: &Backend,
) -> Result<ReducedWitness<P>, Error>
where
    F: TowerField,
    P: PackedFieldIndexable<Scalar = F>,
    M: MultilinearPoly<P> + Sync,
    Tower: TowerFamily<B128 = F>,
    F: PackedTop<Tower>,
    Challenger_: Challenger,
    Backend: ComputationBackend,
{
    if witnesses.len() != system.commit_meta.total_multilins() {
        return Err(Error::InvalidWitness(
            "witness length does not match the number of multilinears".into(),
        ));
    }

    // Sample enough challenges that their tensor expansion yields at least one mixing
    // coefficient per sumcheck claim (2^ceil(log2(n)) >= n).
    let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
    let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
    let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();

    // Partially evaluate each committed multilinear at its claim's suffix, scale each result by
    // its mixing coefficient, then sum the scaled elements per prefix descriptor.
    let tensor_elems =
        compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, memoized_data, backend)?;
    let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
    let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
        &scaled_tensor_elems,
        &system.prefix_descs,
        &system.eval_claim_to_prefix_desc_index,
    )?;
    // Write the 2^kappa vertical elements of every prefix's mixed element to the transcript.
    let mut writer = transcript.message();
    for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
        debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
        writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
    }

    // Row-batching challenges are sampled only after the mixed elements are in the transcript.
    let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
    let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
        MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
    ));

    // One row-batched evaluation per sumcheck claim; these become the claimed sums below.
    let row_batched_evals =
        compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
    transcript.message().write_scalar_slice(&row_batched_evals);

    let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
        &system.sumcheck_claim_descs,
        &system.suffix_descs,
        row_batch_coeffs,
        &mixing_coeffs,
    )?;
    // Claim `idx` pairs committed multilinear `committed_idx` with transparent `idx`, i.e. the
    // `idx`-th ring-switch equality indicator built just above.
    let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
        .enumerate()
        .map(|(idx, (claim_desc, eval))| {
            let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
            PIOPSumcheckClaim {
                n_vars: suffix_desc.suffix.len(),
                committed: claim_desc.committed_idx,
                transparent: idx,
                sum: eval,
            }
        })
        .collect::<Vec<_>>();

    Ok(ReducedWitness {
        transparents: ring_switch_eq_inds,
        sumcheck_claims,
    })
}
115
116#[instrument(skip_all)]
117fn compute_partial_evals<F, P, M, Tower, Backend>(
118 system: &EvalClaimSystem<F>,
119 witnesses: &[M],
120 mut memoized_data: MemoizedData<P, Backend>,
121 backend: &Backend,
122) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
123where
124 F: TowerField,
125 P: PackedField<Scalar = F>,
126 M: MultilinearPoly<P> + Sync,
127 Tower: TowerFamily<B128 = F>,
128 Backend: ComputationBackend,
129{
130 let suffixes = system
131 .suffix_descs
132 .iter()
133 .map(|desc| Arc::as_ref(&desc.suffix))
134 .collect::<Vec<_>>();
135
136 memoized_data.memoize_query_par(&suffixes, backend)?;
137
138 let tensor_elems = system
139 .sumcheck_claim_descs
140 .par_iter()
141 .map(
142 |PIOPSumcheckClaimDesc {
143 committed_idx,
144 suffix_desc_idx,
145 eval_claim,
146 }| {
147 let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
148
149 let elems = if let Some(partial_eval) =
150 memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
151 {
152 PackedField::iter_slice(
153 partial_eval.packed_evals().expect("packed_evals exist"),
154 )
155 .take(1 << suffix_desc.kappa)
156 .collect()
157 } else {
158 let suffix_query = memoized_data
159 .full_query_readonly(&suffix_desc.suffix)
160 .expect("memoized above");
161 let partial_eval =
162 witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
163 PackedField::iter_slice(partial_eval.evals())
164 .take(1 << suffix_desc.kappa)
165 .collect()
166 };
167
168 TowerTensorAlgebra::new(suffix_desc.kappa, elems)
169 },
170 )
171 .collect::<Result<Vec<_>, _>>()?;
172
173 Ok(tensor_elems)
174}
175
176fn scale_tensor_elems<F, Tower>(
177 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
178 mixing_coeffs: &[F],
179) -> Vec<TowerTensorAlgebra<Tower>>
180where
181 F: TowerField,
182 Tower: TowerFamily<B128 = F>,
183{
184 assert!(tensor_elems.len() <= mixing_coeffs.len());
186
187 tensor_elems
188 .into_par_iter()
189 .zip(mixing_coeffs)
190 .map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
191 .collect()
192}
193
194fn mix_tensor_elems_for_prefixes<F, Tower>(
195 scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
196 prefix_descs: &[EvalClaimPrefixDesc<F>],
197 eval_claim_to_prefix_desc_index: &[usize],
198) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
199where
200 F: TowerField,
201 Tower: TowerFamily<B128 = F>,
202{
203 assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
205
206 let mut batched_tensor_elems = prefix_descs
207 .iter()
208 .map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
209 .collect::<Result<Vec<_>, _>>()?;
210 for (tensor_elem, &desc_index) in
211 iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
212 {
213 let mixed_val = &mut batched_tensor_elems[desc_index];
214 debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
215 mixed_val.add_assign(tensor_elem)?;
216 }
217 Ok(batched_tensor_elems)
218}
219
220#[instrument(skip_all)]
221fn compute_row_batched_sumcheck_evals<F, Tower>(
222 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
223 row_batch_coeffs: &[F],
224) -> Vec<F>
225where
226 F: TowerField,
227 Tower: TowerFamily<B128 = F>,
228 F: PackedTop<Tower>,
229{
230 tensor_elems
231 .into_par_iter()
232 .map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
233 .collect()
234}
235
236#[instrument(skip_all)]
237fn make_ring_switch_eq_inds<F, P, Tower>(
238 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
239 suffix_descs: &[EvalClaimSuffixDesc<F>],
240 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
241 mixing_coeffs: &[F],
242) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
243where
244 F: TowerField,
245 P: PackedFieldIndexable<Scalar = F>,
246 Tower: TowerFamily<B128 = F>,
247 F: PackedTop<Tower>,
248{
249 sumcheck_claim_descs
250 .par_iter()
251 .zip(mixing_coeffs)
252 .map(|(claim_desc, &mixing_coeff)| {
253 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
254 make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
255 })
256 .collect()
257}
258
259fn make_ring_switch_eq_ind<P, Tower>(
260 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
261 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
262 mixing_coeff: FExt<Tower>,
263) -> Result<MultilinearWitness<'static, P>, Error>
264where
265 P: PackedFieldIndexable<Scalar = FExt<Tower>>,
266 Tower: TowerFamily,
267{
268 let eq_ind = match suffix_desc.kappa {
269 7 => RingSwitchEqInd::<Tower::B1, _>::new(
270 suffix_desc.suffix.clone(),
271 row_batch_coeffs,
272 mixing_coeff,
273 )?
274 .multilinear_extension::<P>(),
275 4 => RingSwitchEqInd::<Tower::B8, _>::new(
276 suffix_desc.suffix.clone(),
277 row_batch_coeffs,
278 mixing_coeff,
279 )?
280 .multilinear_extension(),
281 3 => RingSwitchEqInd::<Tower::B16, _>::new(
282 suffix_desc.suffix.clone(),
283 row_batch_coeffs,
284 mixing_coeff,
285 )?
286 .multilinear_extension(),
287 2 => RingSwitchEqInd::<Tower::B32, _>::new(
288 suffix_desc.suffix.clone(),
289 row_batch_coeffs,
290 mixing_coeff,
291 )?
292 .multilinear_extension(),
293 1 => RingSwitchEqInd::<Tower::B64, _>::new(
294 suffix_desc.suffix.clone(),
295 row_batch_coeffs,
296 mixing_coeff,
297 )?
298 .multilinear_extension(),
299 0 => RingSwitchEqInd::<Tower::B128, _>::new(
300 suffix_desc.suffix.clone(),
301 row_batch_coeffs,
302 mixing_coeff,
303 )?
304 .multilinear_extension(),
305 _ => Err(Error::PackingDegreeNotSupported {
306 kappa: suffix_desc.kappa,
307 }),
308 }?;
309 Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
310}