Skip to main content

binius_field/
linear_transformation.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, marker::PhantomData, ops::BitXor};
4
5use rand::prelude::*;
6
7use crate::{
8	BinaryField, BinaryField1b, ExtensionField, UnderlierType, WithUnderlier,
9	packed::PackedBinaryField, underlier::Divisible,
10};
11
/// A map from `Input` values to `Output` values.
///
/// This one trait serves both scalar fields and packed fields. The `Sync` supertrait means a
/// shared reference to a transformation may be used from several threads at once.
pub trait Transformation<Input, Output>: Sync {
	/// Applies the transformation to `data` and returns the image.
	fn transform(&self, data: &Input) -> Output;
}
16
17/// An $\mathbb{F}_2$-linear transformation on binary fields.
18///
19/// Stores transposed transformation matrix as a collection of field elements. `Data` is a generic
20/// parameter because we want to be able both to have const instances that reference static arrays
21/// and owning vector elements.
22#[derive(Debug, Clone)]
23pub struct FieldLinearTransformation<OF: BinaryField, Data: AsRef<[OF]> + Sync = &'static [OF]> {
24	bases: Data,
25	_pd: PhantomData<OF>,
26}
27
28impl<OF: BinaryField> FieldLinearTransformation<OF, &'static [OF]> {
29	pub const fn new_const(bases: &'static [OF]) -> Self {
30		assert!(bases.len() == OF::DEGREE);
31
32		Self {
33			bases,
34			_pd: PhantomData,
35		}
36	}
37}
38
39impl<OF: BinaryField, Data: AsRef<[OF]> + Sync> FieldLinearTransformation<OF, Data> {
40	pub fn new(bases: Data) -> Self {
41		debug_assert_eq!(bases.as_ref().len(), OF::DEGREE);
42
43		Self {
44			bases,
45			_pd: PhantomData,
46		}
47	}
48
49	pub fn bases(&self) -> &[OF] {
50		self.bases.as_ref()
51	}
52}
53
54impl<IF: BinaryField, OF: BinaryField, Data: AsRef<[OF]> + Sync> Transformation<IF, OF>
55	for FieldLinearTransformation<OF, Data>
56{
57	fn transform(&self, data: &IF) -> OF {
58		assert_eq!(IF::DEGREE, OF::DEGREE);
59
60		ExtensionField::<BinaryField1b>::iter_bases(data)
61			.zip(self.bases.as_ref().iter())
62			.fold(OF::ZERO, |acc, (scalar, &basis_elem)| acc + basis_elem * scalar)
63	}
64}
65
66impl<OF: BinaryField> FieldLinearTransformation<OF, Vec<OF>> {
67	pub fn random(mut rng: impl Rng) -> Self {
68		Self {
69			bases: (0..OF::DEGREE).map(|_| OF::random(&mut rng)).collect(),
70			_pd: PhantomData,
71		}
72	}
73}
74
/// Base-2 logarithm of the number of bits in a byte.
const LOG_BITS_PER_BYTE: usize = 3;
/// Number of bits in a byte, derived as `2^LOG_BITS_PER_BYTE`.
const BITS_PER_BYTE: usize = 2usize.pow(LOG_BITS_PER_BYTE as u32);
77
78/// Linear transformation using precomputed byte-indexed lookup tables.
79///
80/// This implementation uses the [Method of Four Russians] to optimize the computation by
81/// precomputing lookup tables for each byte position and using bitwise chunks of the words.
82///
83/// [Method of Four Russians]: <https://en.wikipedia.org/wiki/Method_of_Four_Russians>
84#[derive(Debug)]
85pub struct BytewiseLookupTransformation<UIn, UOut> {
86	lookup: Vec<[UOut; 1 << BITS_PER_BYTE]>,
87	_uin_marker: PhantomData<UIn>,
88}
89
90impl<UIn, UOut> BytewiseLookupTransformation<UIn, UOut>
91where
92	UIn: UnderlierType + Divisible<u8>,
93	UOut: UnderlierType,
94{
95	pub fn new(cols: &[UOut]) -> Self {
96		assert!(LOG_BITS_PER_BYTE <= UIn::LOG_BITS);
97		assert_eq!(cols.len(), UIn::BITS);
98
99		let lookup = cols
100			.chunks(BITS_PER_BYTE)
101			.map(|cols| {
102				let cols: [_; BITS_PER_BYTE] = cols.try_into().expect(
103					"chunk size is BITS_PER_BYTE; \
104					cols.len() is a multiple of BITS_PER_BYTE",
105				);
106				expand_subset_xors(cols)
107			})
108			.collect();
109
110		Self {
111			lookup,
112			_uin_marker: PhantomData,
113		}
114	}
115}
116
117impl<UIn, UOut> Transformation<UIn, UOut> for BytewiseLookupTransformation<UIn, UOut>
118where
119	UIn: UnderlierType + Divisible<u8>,
120	UOut: UnderlierType,
121{
122	fn transform(&self, data: &UIn) -> UOut {
123		Divisible::<u8>::ref_iter(data)
124			.enumerate()
125			.take(1 << (UIn::LOG_BITS - LOG_BITS_PER_BYTE))
126			.map(|(i, byte)| {
127				// Safety:
128				// - lookup.len() == 2^(UIn::LOG_BITS - LOG_BITS_PER_BYTE) by struct invariant
129				// - take limits iteration calls to 2^(UIn::LOG_BITS - LOG_BITS_PER_BYTE)
130				let lookup = unsafe { self.lookup.get_unchecked(i) };
131				lookup[byte as usize]
132			})
133			.reduce(BitXor::bitxor)
134			.unwrap_or(UOut::ZERO)
135	}
136}
137
138fn expand_subset_xors<U: UnderlierType, const N: usize, const N_EXP2: usize>(
139	elems: [U; N],
140) -> [U; N_EXP2] {
141	assert_eq!(N_EXP2, 1 << N);
142
143	let mut expanded = [U::ZERO; N_EXP2];
144	for (i, elem_i) in elems.into_iter().enumerate() {
145		let span = &mut expanded[..1 << (i + 1)];
146		let (lo_half, hi_half) = span.split_at_mut(1 << i);
147		for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
148			*hi_half_i = *lo_half_i ^ elem_i;
149		}
150	}
151	expanded
152}
153
/// Factory producing [`BytewiseLookupTransformation`]s from column data.
///
/// The factory holds no state. It derives `Clone`, `Copy` and `Default` so it can be duplicated
/// freely and constructed in generic code that needs a factory value.
#[derive(Debug, Clone, Copy, Default)]
pub struct BytewiseLookupTransformationFactory;
157
158/// Factory trait for creating linear transformations from column data.
159pub trait LinearTransformationFactory<Input, Output> {
160	type Transform: Transformation<Input, Output>;
161
162	fn create(&self, cols: &[Output]) -> Self::Transform;
163}
164
165impl<UIn, UOut> LinearTransformationFactory<UIn, UOut> for BytewiseLookupTransformationFactory
166where
167	UIn: UnderlierType + Divisible<u8>,
168	UOut: UnderlierType,
169{
170	type Transform = BytewiseLookupTransformation<UIn, UOut>;
171
172	fn create(&self, cols: &[UOut]) -> Self::Transform {
173		BytewiseLookupTransformation::new(cols)
174	}
175}
176
/// Adapter that lets a transformation producing underliers produce `Output` values instead.
///
/// `inner` yields an `Output::Underlier`, which the `Transformation` impl converts into
/// `Output` with `from_underlier`.
#[derive(Debug)]
pub struct OutputWrappingTransformation<Inner, Input, Output> {
	/// The transformation that works on underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
183
184impl<Inner, Input, Output> Transformation<Input, Output>
185	for OutputWrappingTransformation<Inner, Input, Output>
186where
187	Inner: Transformation<Input, Output::Underlier>,
188	Input: Sync,
189	Output: WithUnderlier,
190{
191	#[inline]
192	fn transform(&self, data: &Input) -> Output {
193		Output::from_underlier(self.inner.transform(data))
194	}
195}
196
/// Factory adapter that builds [`OutputWrappingTransformation`]s.
///
/// Columns given as `Output` values are reinterpreted as underliers and handed to the wrapped
/// underlier-level factory `inner`.
#[derive(Debug)]
pub struct OutputWrappingTransformationFactory<Inner, Input, Output> {
	/// The factory that works on underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
203
204impl<Inner, Input, Output> OutputWrappingTransformationFactory<Inner, Input, Output>
205where
206	Inner: LinearTransformationFactory<Input, Output::Underlier>,
207	Input: Sync,
208	Output: WithUnderlier,
209{
210	pub fn new(inner: Inner) -> Self {
211		Self {
212			inner,
213			_marker: PhantomData,
214		}
215	}
216}
217
218impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
219	for OutputWrappingTransformationFactory<Inner, Input, Output>
220where
221	Inner: LinearTransformationFactory<Input, Output::Underlier>,
222	Input: Sync,
223	Output: WithUnderlier,
224{
225	type Transform = OutputWrappingTransformation<Inner::Transform, Input, Output>;
226
227	#[inline]
228	fn create(&self, cols: &[Output]) -> Self::Transform {
229		OutputWrappingTransformation {
230			inner: self.inner.create(Output::to_underliers_ref(cols)),
231			_marker: PhantomData,
232		}
233	}
234}
235
/// Adapter that lets a transformation consuming underliers accept `Input` values instead.
///
/// The `Transformation` impl converts the input with `to_underlier` before calling `inner`.
#[derive(Debug)]
pub struct InputWrappingTransformation<Inner, Input, Output> {
	/// The transformation that works on underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
242
243impl<Inner, Input, Output> Transformation<Input, Output>
244	for InputWrappingTransformation<Inner, Input, Output>
245where
246	Inner: Transformation<Input::Underlier, Output>,
247	Input: WithUnderlier,
248	Output: Sync,
249{
250	#[inline]
251	fn transform(&self, data: &Input) -> Output {
252		self.inner.transform(&data.to_underlier())
253	}
254}
255
/// Factory adapter that builds [`InputWrappingTransformation`]s.
///
/// Column data is forwarded unchanged to the wrapped underlier-level factory `inner`.
#[derive(Debug)]
pub struct InputWrappingTransformationFactory<Inner, Input, Output> {
	/// The factory that works on underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
262
263impl<Inner, Input, Output> InputWrappingTransformationFactory<Inner, Input, Output>
264where
265	Inner: LinearTransformationFactory<Input::Underlier, Output>,
266	Input: WithUnderlier,
267	Output: Sync,
268{
269	pub fn new(inner: Inner) -> Self {
270		Self {
271			inner,
272			_marker: PhantomData,
273		}
274	}
275}
276
277impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
278	for InputWrappingTransformationFactory<Inner, Input, Output>
279where
280	Inner: LinearTransformationFactory<Input::Underlier, Output>,
281	Input: WithUnderlier,
282	Output: Sync,
283{
284	type Transform = InputWrappingTransformation<Inner::Transform, Input, Output>;
285
286	#[inline]
287	fn create(&self, cols: &[Output]) -> Self::Transform {
288		InputWrappingTransformation {
289			inner: self.inner.create(cols),
290			_marker: PhantomData,
291		}
292	}
293}
294
295/// Transformation that wraps both input and output, converting between types with underliers.
296pub type WrappingTransformation<Inner, Input, Output> = OutputWrappingTransformation<
297	InputWrappingTransformation<Inner, Input, <Output as WithUnderlier>::Underlier>,
298	Input,
299	Output,
300>;
301
302pub struct IDTransformation;
303
304impl<OP: PackedBinaryField> Transformation<OP, OP> for IDTransformation {
305	fn transform(&self, data: &OP) -> OP {
306		*data
307	}
308}