// binius_core/ring_switch/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_compute::{
6	ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, alloc::ComputeAllocator,
7	cpu::CpuMemory, layer,
8};
9use binius_field::{Field, PackedField, PackedFieldIndexable};
10use binius_math::{
11	B1, B8, B16, B32, B64, B128, MultilinearPoly, MultilinearQuery, PackedTop, TowerTop,
12};
13use binius_maybe_rayon::prelude::*;
14use binius_utils::checked_arithmetics::log2_ceil_usize;
15use itertools::izip;
16use tracing::instrument;
17
18use super::{
19	common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
20	eq_ind::RowBatchCoeffs,
21	error::Error,
22	logging::MLEFoldHisgDimensionsData,
23	tower_tensor_algebra::TowerTensorAlgebra,
24};
25use crate::{
26	fiat_shamir::{CanSample, Challenger},
27	piop::PIOPSumcheckClaim,
28	protocols::evalcheck::subclaims::MemoizedData,
29	ring_switch::{
30		common::EvalClaimSuffixDesc,
31		eq_ind::{RingSwitchEqInd, RingSwitchEqIndPrecompute},
32		logging::CalculateRingSwitchEqIndData,
33	},
34	transcript::ProverTranscript,
35};
36
/// Output of the ring-switching reduction produced by [`prove`].
pub struct ReducedWitness<'a, F: Field, Hal: ComputeLayer<F>> {
	// Device slices holding the ring-switching eq-indicator ("transparent")
	// multilinears, one per reduced sumcheck claim.
	pub transparents: Vec<FSlice<'a, F, Hal>>,
	// The reduced PIOP sumcheck claims; each claim's `transparent` field indexes
	// into `transparents`.
	pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
}
41
/// Runs the prover side of the ring-switching reduction.
///
/// Reduces the evaluation claims described by `system` over the committed `witnesses`
/// to PIOP sumcheck claims: samples mixing challenges, writes the batched partial
/// evaluations and row-batched sums to `transcript`, and constructs the transparent
/// ring-switching eq-indicator multilinears on the compute layer `hal` (allocating
/// from `dev_alloc`/`host_alloc`).
///
/// NOTE: the order of `sample_vec` and `message().write_*` calls on the transcript is
/// part of the Fiat–Shamir protocol and must match the verifier exactly.
///
/// # Errors
/// Returns [`Error::InvalidWitness`] if `witnesses` does not have one entry per
/// committed multilinear; propagates errors from the partial-evaluation, mixing, and
/// eq-indicator construction steps.
pub fn prove<'a, F, P, M, Challenger_, Hal, HostAllocatorType, DeviceAllocatorType>(
	system: &EvalClaimSystem<F>,
	witnesses: &[M],
	transcript: &mut ProverTranscript<Challenger_>,
	memoized_data: MemoizedData<P>,
	hal: &Hal,
	dev_alloc: &'a DeviceAllocatorType,
	host_alloc: &HostAllocatorType,
) -> Result<ReducedWitness<'a, F, Hal>, Error>
where
	F: TowerTop + PackedTop<Scalar = F>,
	P: PackedFieldIndexable<Scalar = F> + PackedTop,
	M: MultilinearPoly<P> + Sync,
	Challenger_: Challenger,
	Hal: ComputeLayer<F>,
	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
{
	// One witness multilinear is required per committed multilinear.
	if witnesses.len() != system.commit_meta.total_multilins() {
		return Err(Error::InvalidWitness(
			"witness length does not match the number of multilinears".into(),
		));
	}

	// Sample enough randomness to batch tensor elements corresponding to claims that share an
	// evaluation point prefix.
	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
	let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
	let mle_fold_high_span = tracing::debug_span!(
		"[task] (Ring Switch) MLE Fold High",
		phase = "ring_switch",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	// Tensor-expand the mixing challenges into one coefficient per claim.
	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();

	// For each evaluation point prefix, send one batched partial evaluation.
	let tensor_elems = compute_partial_evals(system, witnesses, memoized_data)?;
	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
		&scaled_tensor_elems,
		&system.prefix_descs,
		&system.eval_claim_to_prefix_desc_index,
	)?;
	drop(mle_fold_high_span);
	let mut writer = transcript.message();
	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
	}

	// Sample the row-batching randomness.
	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
	));

	// Fold each scaled tensor element with the row-batching coefficients and send the sums.
	let row_batched_evals =
		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
	transcript.message().write_scalar_slice(&row_batched_evals);

	// Create the reduced PIOP sumcheck witnesses.
	let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
	let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
		"[task] Calculate Ring Switch Eq Ind",
		phase = "ring_switch",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	let ring_switch_eq_inds = make_ring_switch_eq_inds(
		&system.sumcheck_claim_descs,
		&system.suffix_descs,
		row_batch_coeffs,
		&mixing_coeffs,
		hal,
		dev_alloc,
		host_alloc,
	)?;
	drop(calculate_ring_switch_eq_ind_span);

	// Assemble one sumcheck claim per evaluation claim; `transparent` indexes the
	// eq-indicator created above for the same claim.
	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
		.enumerate()
		.map(|(idx, (claim_desc, eval))| {
			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
			PIOPSumcheckClaim {
				n_vars: suffix_desc.suffix.len(),
				committed: claim_desc.committed_idx,
				transparent: idx,
				sum: eval,
			}
		})
		.collect::<Vec<_>>();

	Ok(ReducedWitness {
		transparents: ring_switch_eq_inds,
		sumcheck_claims,
	})
}
145
146#[instrument(skip_all)]
147fn compute_partial_evals<F, P, M>(
148	system: &EvalClaimSystem<F>,
149	witnesses: &[M],
150	mut memoized_data: MemoizedData<P>,
151) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
152where
153	F: TowerTop,
154	P: PackedField<Scalar = F>,
155	M: MultilinearPoly<P> + Sync,
156{
157	let suffixes = system
158		.suffix_descs
159		.iter()
160		.map(|desc| Arc::as_ref(&desc.suffix))
161		.collect::<Vec<_>>();
162
163	memoized_data.memoize_query_par(suffixes)?;
164
165	let tensor_elems = system
166		.sumcheck_claim_descs
167		.par_iter()
168		.map(
169			|PIOPSumcheckClaimDesc {
170			     committed_idx,
171			     suffix_desc_idx,
172			     eval_claim,
173			 }| {
174				let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
175
176				let mut elems = if let Some(partial_eval) =
177					memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
178				{
179					PackedField::iter_slice(
180						partial_eval.packed_evals().expect("packed_evals exist"),
181					)
182					.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
183					.collect::<Vec<_>>()
184				} else {
185					let suffix_query = memoized_data
186						.full_query_readonly(&suffix_desc.suffix)
187						.expect("memoized above");
188					let partial_eval =
189						witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
190					PackedField::iter_slice(partial_eval.evals())
191						.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
192						.collect::<Vec<_>>()
193				};
194
195				if elems.len() < (1 << suffix_desc.kappa) {
196					elems = elems
197						.into_iter()
198						.cycle()
199						.take(1 << suffix_desc.kappa)
200						.collect();
201				}
202				TowerTensorAlgebra::new(suffix_desc.kappa, elems)
203			},
204		)
205		.collect::<Result<Vec<_>, _>>()?;
206
207	Ok(tensor_elems)
208}
209
210fn scale_tensor_elems<F>(
211	tensor_elems: Vec<TowerTensorAlgebra<F>>,
212	mixing_coeffs: &[F],
213) -> Vec<TowerTensorAlgebra<F>>
214where
215	F: TowerTop,
216{
217	// Precondition
218	assert!(tensor_elems.len() <= mixing_coeffs.len());
219
220	tensor_elems
221		.into_par_iter()
222		.zip(mixing_coeffs)
223		.map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
224		.collect()
225}
226
227fn mix_tensor_elems_for_prefixes<F>(
228	scaled_tensor_elems: &[TowerTensorAlgebra<F>],
229	prefix_descs: &[EvalClaimPrefixDesc<F>],
230	eval_claim_to_prefix_desc_index: &[usize],
231) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
232where
233	F: TowerTop,
234{
235	// Precondition
236	assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
237
238	let mut batched_tensor_elems = prefix_descs
239		.iter()
240		.map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
241		.collect::<Result<Vec<_>, _>>()?;
242	for (tensor_elem, &desc_index) in
243		iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
244	{
245		let mixed_val = &mut batched_tensor_elems[desc_index];
246		debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
247		mixed_val.add_assign(tensor_elem)?;
248	}
249	Ok(batched_tensor_elems)
250}
251
252#[instrument(skip_all)]
253fn compute_row_batched_sumcheck_evals<F>(
254	tensor_elems: Vec<TowerTensorAlgebra<F>>,
255	row_batch_coeffs: &[F],
256) -> Vec<F>
257where
258	F: TowerTop + PackedTop,
259{
260	tensor_elems
261		.into_par_iter()
262		.map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
263		.collect()
264}
265
266fn make_ring_switch_eq_inds<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>(
267	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
268	suffix_descs: &[EvalClaimSuffixDesc<F>],
269	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
270	mixing_coeffs: &[F],
271	hal: &Hal,
272	dev_alloc: &'a DeviceAllocatorType,
273	host_alloc: &HostAllocatorType,
274) -> Result<Vec<FSlice<'a, F, Hal>>, Error>
275where
276	F: TowerTop,
277	Hal: ComputeLayer<F>,
278	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
279	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
280{
281	let mut eq_inds = Vec::with_capacity(sumcheck_claim_descs.len());
282
283	let precompute = sumcheck_claim_descs
284		.iter()
285		.zip(mixing_coeffs)
286		.map(|(claim_desc, mixing_coeffs)| {
287			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
288			RingSwitchEqInd::<F, F>::precompute_values(
289				suffix_desc.suffix.clone(),
290				row_batch_coeffs.clone(),
291				*mixing_coeffs,
292				suffix_desc.kappa,
293				hal,
294				dev_alloc,
295				host_alloc,
296			)
297		})
298		.collect::<Result<Vec<_>, Error>>()?;
299
300	let _ = hal.execute(|exec| {
301		let res = exec
302			.map(
303				izip!(sumcheck_claim_descs, mixing_coeffs, precompute),
304				|exec, (claim_desc, &mixing_coeff, precompute)| {
305					let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
306
307					make_ring_switch_eq_ind(
308						suffix_desc,
309						row_batch_coeffs.clone(),
310						mixing_coeff,
311						exec,
312						precompute,
313					)
314					.map_err(|e| layer::Error::CoreLibError(Box::new(e)))
315				},
316			)
317			.map_err(|e| layer::Error::CoreLibError(Box::new(e)))?;
318
319		eq_inds.extend(res);
320		Ok(vec![])
321	})?;
322
323	Ok(eq_inds)
324}
325
/// Builds one ring-switching eq-indicator multilinear on the compute device.
///
/// Dispatches on `F::TOWER_LEVEL - suffix_desc.kappa` to pick the binary tower
/// subfield (`B1`, `B8`, `B16`, `B32`, `B64`, or `B128`) used to instantiate
/// [`RingSwitchEqInd`]; the same level value is passed through to
/// `multilinear_extension`.
///
/// # Errors
/// Returns [`Error::PackingDegreeNotSupported`] for level differences with no match
/// arm (e.g. 1 or 2), and propagates errors from `RingSwitchEqInd::new` and
/// `multilinear_extension`.
///
/// NOTE(review): the subtraction underflows (panicking in debug builds) if
/// `suffix_desc.kappa > F::TOWER_LEVEL` — presumably ruled out by construction of the
/// suffix descriptors; confirm at call sites.
fn make_ring_switch_eq_ind<'a, F, Mem, Exec>(
	suffix_desc: &EvalClaimSuffixDesc<F>,
	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
	mixing_coeff: F,
	exec: &mut Exec,
	precompute: RingSwitchEqIndPrecompute<'a, F, Mem>,
) -> Result<Mem::FSlice<'a>, Error>
where
	F: TowerTop,
	Mem: ComputeMemory<F>,
	Exec: ComputeLayerExecutor<F, DevMem = Mem>,
{
	// Each arm differs only in the subfield type parameter and the matching tower
	// level constant forwarded to `multilinear_extension`.
	let eq_ind = match F::TOWER_LEVEL - suffix_desc.kappa {
		0 => RingSwitchEqInd::<B1, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 0),
		3 => RingSwitchEqInd::<B8, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 3),
		4 => RingSwitchEqInd::<B16, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 4),
		5 => RingSwitchEqInd::<B32, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 5),
		6 => RingSwitchEqInd::<B64, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 6),
		7 => RingSwitchEqInd::<B128, _>::new(
			suffix_desc.suffix.clone(),
			row_batch_coeffs,
			mixing_coeff,
		)?
		.multilinear_extension(precompute, exec, 7),
		// Unsupported packing degrees (no corresponding subfield arm above).
		_ => Err(Error::PackingDegreeNotSupported {
			kappa: suffix_desc.kappa,
		}),
	}?;
	Ok(eq_ind)
}