binius_core/ring_switch/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{
6	PackedExtension, PackedField, PackedFieldIndexable, TowerField,
7	tower::{PackedTop, TowerFamily},
8};
9use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
10use binius_maybe_rayon::prelude::*;
11use binius_utils::checked_arithmetics::log2_ceil_usize;
12use tracing::instrument;
13
14use super::{
15	common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
16	eq_ind::RowBatchCoeffs,
17	error::Error,
18	logging::MLEFoldHisgDimensionsData,
19	tower_tensor_algebra::TowerTensorAlgebra,
20};
21use crate::{
22	fiat_shamir::{CanSample, Challenger},
23	piop::PIOPSumcheckClaim,
24	protocols::evalcheck::subclaims::MemoizedData,
25	ring_switch::{
26		common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd, logging::CalculateRingSwitchEqIndData,
27	},
28	transcript::ProverTranscript,
29	witness::MultilinearWitness,
30};
31
32type FExt<Tower> = <Tower as TowerFamily>::B128;
33
34#[derive(Debug)]
35pub struct ReducedWitness<P: PackedField> {
36	pub transparents: Vec<MultilinearWitness<'static, P>>,
37	pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
38}
39
40pub fn prove<F, P, M, Tower, Challenger_>(
41	system: &EvalClaimSystem<F>,
42	witnesses: &[M],
43	transcript: &mut ProverTranscript<Challenger_>,
44	memoized_data: MemoizedData<P>,
45) -> Result<ReducedWitness<P>, Error>
46where
47	F: TowerField + PackedTop<Tower>,
48	P: PackedFieldIndexable<Scalar = F>
49		+ PackedExtension<Tower::B1>
50		+ PackedExtension<Tower::B8>
51		+ PackedExtension<Tower::B16>
52		+ PackedExtension<Tower::B32>
53		+ PackedExtension<Tower::B64>
54		+ PackedExtension<Tower::B128>,
55	M: MultilinearPoly<P> + Sync,
56	Tower: TowerFamily<B128 = F>,
57	Challenger_: Challenger,
58{
59	if witnesses.len() != system.commit_meta.total_multilins() {
60		return Err(Error::InvalidWitness(
61			"witness length does not match the number of multilinears".into(),
62		));
63	}
64
65	// Sample enough randomness to batch tensor elements corresponding to claims that share an
66	// evaluation point prefix.
67	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
68	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
69	let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
70	let mle_fold_high_span = tracing::debug_span!(
71		"[task] (Ring Switch) MLE Fold High",
72		phase = "ring_switch",
73		perfetto_category = "task.main",
74		?dimensions_data,
75	)
76	.entered();
77
78	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
79
80	// For each evaluation point prefix, send one batched partial evaluation.
81	let tensor_elems = compute_partial_evals::<_, _, _, Tower>(system, witnesses, memoized_data)?;
82	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
83	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
84		&scaled_tensor_elems,
85		&system.prefix_descs,
86		&system.eval_claim_to_prefix_desc_index,
87	)?;
88	drop(mle_fold_high_span);
89	let mut writer = transcript.message();
90	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
91		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
92		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
93	}
94
95	// Sample the row-batching randomness.
96	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
97	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
98		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
99	));
100
101	let row_batched_evals =
102		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
103	transcript.message().write_scalar_slice(&row_batched_evals);
104
105	// Create the reduced PIOP sumcheck witnesses.
106	let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
107	let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
108		"[task] Calculate Ring Switch Eq Ind",
109		phase = "ring_switch",
110		perfetto_category = "task.main",
111		?dimensions_data,
112	)
113	.entered();
114
115	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
116		&system.sumcheck_claim_descs,
117		&system.suffix_descs,
118		row_batch_coeffs,
119		&mixing_coeffs,
120	)?;
121	drop(calculate_ring_switch_eq_ind_span);
122
123	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
124		.enumerate()
125		.map(|(idx, (claim_desc, eval))| {
126			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
127			PIOPSumcheckClaim {
128				n_vars: suffix_desc.suffix.len(),
129				committed: claim_desc.committed_idx,
130				transparent: idx,
131				sum: eval,
132			}
133		})
134		.collect::<Vec<_>>();
135
136	Ok(ReducedWitness {
137		transparents: ring_switch_eq_inds,
138		sumcheck_claims,
139	})
140}
141
142#[instrument(skip_all)]
143fn compute_partial_evals<F, P, M, Tower>(
144	system: &EvalClaimSystem<F>,
145	witnesses: &[M],
146	mut memoized_data: MemoizedData<P>,
147) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
148where
149	F: TowerField,
150	P: PackedField<Scalar = F>,
151	M: MultilinearPoly<P> + Sync,
152	Tower: TowerFamily<B128 = F>,
153{
154	let suffixes = system
155		.suffix_descs
156		.iter()
157		.map(|desc| Arc::as_ref(&desc.suffix))
158		.collect::<Vec<_>>();
159
160	memoized_data.memoize_query_par(suffixes)?;
161
162	let tensor_elems = system
163		.sumcheck_claim_descs
164		.par_iter()
165		.map(
166			|PIOPSumcheckClaimDesc {
167			     committed_idx,
168			     suffix_desc_idx,
169			     eval_claim,
170			 }| {
171				let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
172
173				let mut elems = if let Some(partial_eval) =
174					memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
175				{
176					PackedField::iter_slice(
177						partial_eval.packed_evals().expect("packed_evals exist"),
178					)
179					.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
180					.collect::<Vec<_>>()
181				} else {
182					let suffix_query = memoized_data
183						.full_query_readonly(&suffix_desc.suffix)
184						.expect("memoized above");
185					let partial_eval =
186						witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
187					PackedField::iter_slice(partial_eval.evals())
188						.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
189						.collect::<Vec<_>>()
190				};
191
192				if elems.len() < (1 << suffix_desc.kappa) {
193					elems = elems
194						.into_iter()
195						.cycle()
196						.take(1 << suffix_desc.kappa)
197						.collect();
198				}
199				TowerTensorAlgebra::new(suffix_desc.kappa, elems)
200			},
201		)
202		.collect::<Result<Vec<_>, _>>()?;
203
204	Ok(tensor_elems)
205}
206
207fn scale_tensor_elems<F, Tower>(
208	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
209	mixing_coeffs: &[F],
210) -> Vec<TowerTensorAlgebra<Tower>>
211where
212	F: TowerField,
213	Tower: TowerFamily<B128 = F>,
214{
215	// Precondition
216	assert!(tensor_elems.len() <= mixing_coeffs.len());
217
218	tensor_elems
219		.into_par_iter()
220		.zip(mixing_coeffs)
221		.map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
222		.collect()
223}
224
225fn mix_tensor_elems_for_prefixes<F, Tower>(
226	scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
227	prefix_descs: &[EvalClaimPrefixDesc<F>],
228	eval_claim_to_prefix_desc_index: &[usize],
229) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
230where
231	F: TowerField,
232	Tower: TowerFamily<B128 = F>,
233{
234	// Precondition
235	assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
236
237	let mut batched_tensor_elems = prefix_descs
238		.iter()
239		.map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
240		.collect::<Result<Vec<_>, _>>()?;
241	for (tensor_elem, &desc_index) in
242		iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
243	{
244		let mixed_val = &mut batched_tensor_elems[desc_index];
245		debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
246		mixed_val.add_assign(tensor_elem)?;
247	}
248	Ok(batched_tensor_elems)
249}
250
251#[instrument(skip_all)]
252fn compute_row_batched_sumcheck_evals<F, Tower>(
253	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
254	row_batch_coeffs: &[F],
255) -> Vec<F>
256where
257	F: TowerField,
258	Tower: TowerFamily<B128 = F>,
259	F: PackedTop<Tower>,
260{
261	tensor_elems
262		.into_par_iter()
263		.map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
264		.collect()
265}
266
267fn make_ring_switch_eq_inds<F, P, Tower>(
268	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
269	suffix_descs: &[EvalClaimSuffixDesc<F>],
270	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
271	mixing_coeffs: &[F],
272) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
273where
274	F: TowerField + PackedTop<Tower>,
275	P: PackedFieldIndexable<Scalar = F>
276		+ PackedExtension<Tower::B1>
277		+ PackedExtension<Tower::B8>
278		+ PackedExtension<Tower::B16>
279		+ PackedExtension<Tower::B32>
280		+ PackedExtension<Tower::B64>
281		+ PackedExtension<Tower::B128>,
282	Tower: TowerFamily<B128 = F>,
283{
284	sumcheck_claim_descs
285		.par_iter()
286		.zip(mixing_coeffs)
287		.map(|(claim_desc, &mixing_coeff)| {
288			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
289			make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
290		})
291		.collect()
292}
293
294fn make_ring_switch_eq_ind<P, Tower>(
295	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
296	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
297	mixing_coeff: FExt<Tower>,
298) -> Result<MultilinearWitness<'static, P>, Error>
299where
300	P: PackedFieldIndexable<Scalar = FExt<Tower>> + PackedTop<Tower>,
301	Tower: TowerFamily,
302{
303	let eq_ind = match suffix_desc.kappa {
304		7 => RingSwitchEqInd::<Tower::B1, _>::new(
305			suffix_desc.suffix.clone(),
306			row_batch_coeffs,
307			mixing_coeff,
308		)?
309		.multilinear_extension::<P>(),
310		4 => RingSwitchEqInd::<Tower::B8, _>::new(
311			suffix_desc.suffix.clone(),
312			row_batch_coeffs,
313			mixing_coeff,
314		)?
315		.multilinear_extension(),
316		3 => RingSwitchEqInd::<Tower::B16, _>::new(
317			suffix_desc.suffix.clone(),
318			row_batch_coeffs,
319			mixing_coeff,
320		)?
321		.multilinear_extension(),
322		2 => RingSwitchEqInd::<Tower::B32, _>::new(
323			suffix_desc.suffix.clone(),
324			row_batch_coeffs,
325			mixing_coeff,
326		)?
327		.multilinear_extension(),
328		1 => RingSwitchEqInd::<Tower::B64, _>::new(
329			suffix_desc.suffix.clone(),
330			row_batch_coeffs,
331			mixing_coeff,
332		)?
333		.multilinear_extension(),
334		0 => RingSwitchEqInd::<Tower::B128, _>::new(
335			suffix_desc.suffix.clone(),
336			row_batch_coeffs,
337			mixing_coeff,
338		)?
339		.multilinear_extension(),
340		_ => Err(Error::PackingDegreeNotSupported {
341			kappa: suffix_desc.kappa,
342		}),
343	}?;
344	Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
345}