binius_core/ring_switch/eq_ind.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{iter, marker::PhantomData, sync::Arc};

use binius_compute::{
	ComputeLayer, ComputeLayerExecutor, ComputeMemory, SizedSlice, SubfieldSlice,
	alloc::ComputeAllocator, cpu::CpuMemory,
};
use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
use binius_utils::bail;

use super::error::Error;
use crate::{
	polynomial::{Error as PolynomialError, MultivariatePoly},
	tensor_algebra::TensorAlgebra,
};

/// The row-batching coefficients used to fold the vertical (subfield) dimension of the
/// ring-switched evaluation.
#[derive(Debug)]
pub struct RowBatchCoeffs<F> {
	coeffs: Vec<F>,
}

impl<F: Field> RowBatchCoeffs<F> {
	pub fn new(coeffs: Vec<F>) -> Self {
		Self { coeffs }
	}

	pub fn coeffs(&self) -> &[F] {
		&self.coeffs
	}
}

/// The multilinear function $A$ from [DP24] Section 5.
///
/// The function $A$ is $\ell' := \ell - \kappa$-variate and depends on the last $\ell'$
/// coordinates of the evaluation point as well as the $\kappa$ mixing challenges.
///
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[derive(Debug, Clone)]
pub struct RingSwitchEqInd<FSub, F> {
	/// $z_{\kappa}, \ldots, z_{\ell-1}$
	z_vals: Arc<[F]>,
	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
	mixing_coeff: F,
	_marker: PhantomData<FSub>,
}

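/// Device buffers prepared by [`RingSwitchEqInd::precompute_values`]: the eq-indicator
/// expansion scratch (`evals`), the row-batching coefficients copied to device memory, and the
/// output buffer for the folded MLE (`mle`).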
pub struct RingSwitchEqIndPrecompute<'a, F: Field, Mem: ComputeMemory<F>> {
	evals: Mem::FSliceMut<'a>,
	row_batching_query_expansion: Mem::FSlice<'a>,
	mle: Mem::FSliceMut<'a>,
}

impl<FSub, F> RingSwitchEqInd<FSub, F>
where
	FSub: Field,
	F: ExtensionField<FSub>,
{
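	/// Constructs the indicator from the evaluation point suffix `z_vals`, the row-batching
	/// coefficients, and the mixing coefficient, checking that at least `F::DEGREE`
	/// row-batching coefficients were provided.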
	pub fn new(
		z_vals: Arc<[F]>,
		row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
		mixing_coeff: F,
	) -> Result<Self, Error> {
		if row_batch_coeffs.coeffs.len() < F::DEGREE {
			bail!(Error::InvalidArgs(
				"RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \
				the extension degree"
					.into()
			));
		}

		Ok(Self {
			z_vals,
			row_batch_coeffs,
			mixing_coeff,
			_marker: PhantomData,
		})
	}

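	/// Allocates and initializes the device buffers used by [`Self::multilinear_extension`]:
	/// copies the first `2^kappa` row-batching coefficients to the device and seeds the
	/// eq-indicator expansion buffer with the mixing coefficient.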
	pub fn precompute_values<'a, Hal: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>(
		z_vals: Arc<[F]>,
		row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
		mixing_coeff: F,
		kappa: usize,
		hal: &Hal,
		dev_alloc: &'a DeviceAllocatorType,
		host_alloc: &HostAllocatorType,
	) -> Result<RingSwitchEqIndPrecompute<'a, F, Hal::DevMem>, Error>
	where
		HostAllocatorType: ComputeAllocator<F, CpuMemory>,
		DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
	{
		let extension_degree = 1 << kappa;

		let mut row_batching_query_expansion = dev_alloc.alloc(extension_degree)?;

		hal.copy_h2d(
			&row_batch_coeffs.coeffs()[0..extension_degree],
			&mut row_batching_query_expansion,
		)?;

		let row_batching_query_expansion = Hal::DevMem::to_const(row_batching_query_expansion);

		let n_vars = z_vals.len();
		let mut evals = dev_alloc.alloc(1 << n_vars)?;

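		// Seed the first entry of `evals` with the mixing coefficient so that the tensor
		// expansion in `multilinear_extension` produces the eq indicator scaled by it.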
		{
			let host_val = host_alloc.alloc(1)?;
			host_val[0] = mixing_coeff;
			let mut dev_val = Hal::DevMem::slice_power_of_two_mut(&mut evals, 1);
			hal.copy_h2d(host_val, &mut dev_val)?;
		}

		let mle = dev_alloc.alloc(evals.len())?;

		Ok(RingSwitchEqIndPrecompute {
			evals,
			row_batching_query_expansion,
			mle,
		})
	}

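	/// Computes the multilinear extension of the indicator on the device: tensor-expands the
	/// evaluation point into `evals`, then folds the expansion (viewed as subfield elements at
	/// `tower_level`) with the row-batching coefficients to obtain the row-batched MLE.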
	pub fn multilinear_extension<
		'a,
		Mem: ComputeMemory<F>,
		Exec: ComputeLayerExecutor<F, DevMem = Mem>,
	>(
		&self,
		precompute: RingSwitchEqIndPrecompute<'a, F, Mem>,
		exec: &mut Exec,
		tower_level: usize,
	) -> Result<Mem::FSlice<'a>, Error> {
		let RingSwitchEqIndPrecompute {
			mut evals,
			row_batching_query_expansion,
			mut mle,
		} = precompute;

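		// Expand eq(z_vals, ·) into `evals`; its first entry already holds the mixing
		// coefficient, so the whole expansion is scaled by it.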
		exec.tensor_expand(0, &self.z_vals, &mut evals)?;

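		// View the expansion as subfield elements and fold the 2^kappa components of each
		// extension field element with the row-batching coefficients.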
		let subfield_vector = SubfieldSlice::new(Mem::as_const(&evals), tower_level);

		exec.fold_right(subfield_vector, row_batching_query_expansion, &mut mle)?;

		Ok(Mem::to_const(mle))
	}
}

impl<FSub, F> MultivariatePoly<F> for RingSwitchEqInd<FSub, F>
where
	FSub: TowerField,
	F: TowerField + PackedField<Scalar = F> + PackedExtension<FSub>,
{
	fn n_vars(&self) -> usize {
		self.z_vals.len()
	}

	fn degree(&self) -> usize {
		self.n_vars()
	}

	fn evaluate(&self, query: &[F]) -> Result<F, PolynomialError> {
		if query.len() != self.n_vars() {
			bail!(PolynomialError::IncorrectQuerySize {
				expected: self.n_vars(),
				actual: query.len(),
			});
		}

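		// Evaluate the indicator in the tensor algebra: each z_i scales the vertical
		// (subfield column) direction and each query coordinate scales the horizontal one,
		// starting from the mixing coefficient embedded vertically.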
		let tensor_eval = iter::zip(&*self.z_vals, query).fold(
			<TensorAlgebra<FSub, F>>::from_vertical(self.mixing_coeff),
			|eval, (&vert_i, &hztl_i)| {
				// This formula is specific to characteristic 2 fields, where
				// $h v + (1 - h)(1 - v) = 1 + h + v$.
				let vert_scaled = eval.clone().scale_vertical(vert_i);
				let hztl_scaled = eval.clone().scale_horizontal(hztl_i);
				eval + &vert_scaled + &hztl_scaled
			},
		);

		let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs);
		Ok(folded_eval)
	}

	fn binary_tower_level(&self) -> usize {
		F::TOWER_LEVEL
	}
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_compute::{ComputeData, ComputeHolder, cpu::layer::CpuLayerHolder};
	use binius_field::{BinaryField8b, BinaryField128b};
	use binius_math::{MultilinearQuery, eq_ind_partial_eval};
	use rand::{SeedableRng, prelude::StdRng};

	use super::*;

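	// Checks that the direct tensor-algebra evaluation of the indicator agrees with the inner
	// product of its device-computed MLE against eq(eval_point, ·).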
	#[test]
	fn test_evaluation_consistency() {
		type FS = BinaryField8b;
		type F = BinaryField128b;
		let kappa = <TensorAlgebra<FS, F>>::kappa();
		let ell = 10;
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = ell - kappa;
		let z_vals = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Arc<[_]>>();

		let row_batch_challenges = repeat_with(|| <F as Field>::random(&mut rng))
			.take(kappa)
			.collect::<Vec<_>>();

		let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
			MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
		));

		let eval_point = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<_>>();

		let mixing_coeff = <F as Field>::random(&mut rng);

		let mut compute_holder = CpuLayerHolder::new(1 << 10, 1 << 10);

		let compute_data = compute_holder.to_data();

		let ComputeData {
			hal,
			dev_alloc,
			host_alloc,
			..
		} = compute_data;

		let precompute = RingSwitchEqInd::<FS, _>::precompute_values(
			z_vals.clone(),
			row_batch_coeffs.clone(),
			mixing_coeff,
			kappa,
			hal,
			&dev_alloc,
			&host_alloc,
		)
		.unwrap();

		let rs_eq = RingSwitchEqInd::<FS, _>::new(z_vals, row_batch_coeffs, mixing_coeff).unwrap();

		let val1 = rs_eq.evaluate(&eval_point).unwrap();

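		// The MLE computed on the compute layer, inner-producted with eq(eval_point, ·), must
		// match the direct evaluation above.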
		hal.execute(|exec| {
			let mle = rs_eq
				.multilinear_extension(precompute, exec, BinaryField8b::TOWER_LEVEL)
				.unwrap();

			let mle = SubfieldSlice::new(mle, BinaryField128b::TOWER_LEVEL);

			let query = eq_ind_partial_eval(&eval_point);

			let val2 = exec.inner_product(mle, &query).unwrap();

			assert_eq!(val1, val2);
			Ok(vec![])
		})
		.unwrap();
	}
}