binius_field/
linear_transformation.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, marker::PhantomData, ops::BitXor};
4
5use rand::RngCore;
6
7use crate::{
8	BinaryField, BinaryField1b, ExtensionField, UnderlierWithBitOps, WithUnderlier,
9	packed::PackedBinaryField,
10	underlier::{DivisIterable, UnderlierType},
11};
12
/// A map from `Input` values to `Output` values.
///
/// The same trait is used for transformations on scalar fields and on packed fields.
/// Implementors must be `Sync`, so a single transformation object can be shared between threads.
pub trait Transformation<Input, Output>: Sync {
	/// Applies the transformation to `data` and returns the image.
	fn transform(&self, data: &Input) -> Output;
}
17
18/// An $\mathbb{F}_2$-linear transformation on binary fields.
19///
20/// Stores transposed transformation matrix as a collection of field elements. `Data` is a generic
21/// parameter because we want to be able both to have const instances that reference static arrays
22/// and owning vector elements.
23#[derive(Debug, Clone)]
24pub struct FieldLinearTransformation<OF: BinaryField, Data: AsRef<[OF]> + Sync = &'static [OF]> {
25	bases: Data,
26	_pd: PhantomData<OF>,
27}
28
29impl<OF: BinaryField> FieldLinearTransformation<OF, &'static [OF]> {
30	pub const fn new_const(bases: &'static [OF]) -> Self {
31		assert!(bases.len() == OF::DEGREE);
32
33		Self {
34			bases,
35			_pd: PhantomData,
36		}
37	}
38}
39
40impl<OF: BinaryField, Data: AsRef<[OF]> + Sync> FieldLinearTransformation<OF, Data> {
41	pub fn new(bases: Data) -> Self {
42		debug_assert_eq!(bases.as_ref().len(), OF::DEGREE);
43
44		Self {
45			bases,
46			_pd: PhantomData,
47		}
48	}
49
50	pub fn bases(&self) -> &[OF] {
51		self.bases.as_ref()
52	}
53}
54
55impl<IF: BinaryField, OF: BinaryField, Data: AsRef<[OF]> + Sync> Transformation<IF, OF>
56	for FieldLinearTransformation<OF, Data>
57{
58	fn transform(&self, data: &IF) -> OF {
59		assert_eq!(IF::DEGREE, OF::DEGREE);
60
61		ExtensionField::<BinaryField1b>::iter_bases(data)
62			.zip(self.bases.as_ref().iter())
63			.fold(OF::ZERO, |acc, (scalar, &basis_elem)| acc + basis_elem * scalar)
64	}
65}
66
67impl<OF: BinaryField> FieldLinearTransformation<OF, Vec<OF>> {
68	pub fn random(mut rng: impl RngCore) -> Self {
69		Self {
70			bases: (0..OF::DEGREE).map(|_| OF::random(&mut rng)).collect(),
71			_pd: PhantomData,
72		}
73	}
74}
75
/// Base-2 logarithm of [`BITS_PER_BYTE`].
const LOG_BITS_PER_BYTE: usize = 3;
/// Bits per byte; also the number of columns combined into each byte lookup table.
const BITS_PER_BYTE: usize = 1 << LOG_BITS_PER_BYTE;
78
79/// Linear transformation using precomputed byte-indexed lookup tables.
80///
81/// This implementation uses the [Method of Four Russians] to optimize the computation by
82/// precomputing lookup tables for each byte position and using bitwise chunks of the words.
83///
84/// [Method of Four Russians]: <https://en.wikipedia.org/wiki/Method_of_Four_Russians>
85#[derive(Debug)]
86pub struct BytewiseLookupTransformation<UIn, UOut> {
87	lookup: Vec<[UOut; 1 << BITS_PER_BYTE]>,
88	_uin_marker: PhantomData<UIn>,
89}
90
91impl<UIn, UOut> BytewiseLookupTransformation<UIn, UOut>
92where
93	UIn: UnderlierType + DivisIterable<u8>,
94	UOut: UnderlierWithBitOps,
95{
96	pub fn new(cols: &[UOut]) -> Self {
97		assert!(LOG_BITS_PER_BYTE <= UIn::LOG_BITS);
98		assert_eq!(cols.len(), UIn::BITS);
99
100		let lookup = cols
101			.chunks(BITS_PER_BYTE)
102			.map(|cols| {
103				let cols: [_; BITS_PER_BYTE] = cols.try_into().expect(
104					"chunk size is BITS_PER_BYTE; \
105					cols.len() is a multiple of BITS_PER_BYTE",
106				);
107				expand_subset_xors(cols)
108			})
109			.collect();
110
111		Self {
112			lookup,
113			_uin_marker: PhantomData,
114		}
115	}
116}
117
118impl<UIn, UOut> Transformation<UIn, UOut> for BytewiseLookupTransformation<UIn, UOut>
119where
120	UIn: UnderlierType + DivisIterable<u8>,
121	UOut: UnderlierWithBitOps,
122{
123	fn transform(&self, data: &UIn) -> UOut {
124		data.divide()
125			.enumerate()
126			.take(1 << (UIn::LOG_BITS - LOG_BITS_PER_BYTE))
127			.map(|(i, byte)| {
128				// Safety:
129				// - lookup.len() == 2^(UIn::LOG_BITS - LOG_BITS_PER_BYTE) by struct invariant
130				// - take limits iteration calls to 2^(UIn::LOG_BITS - LOG_BITS_PER_BYTE)
131				let lookup = unsafe { self.lookup.get_unchecked(i) };
132				lookup[*byte as usize]
133			})
134			.reduce(BitXor::bitxor)
135			.unwrap_or(UOut::ZERO)
136	}
137}
138
139fn expand_subset_xors<U: UnderlierWithBitOps, const N: usize, const N_EXP2: usize>(
140	elems: [U; N],
141) -> [U; N_EXP2] {
142	assert_eq!(N_EXP2, 1 << N);
143
144	let mut expanded = [U::ZERO; N_EXP2];
145	for (i, elem_i) in elems.into_iter().enumerate() {
146		let span = &mut expanded[..1 << (i + 1)];
147		let (lo_half, hi_half) = span.split_at_mut(1 << i);
148		for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
149			*hi_half_i = *lo_half_i ^ elem_i;
150		}
151	}
152	expanded
153}
154
/// Stateless factory that builds [`BytewiseLookupTransformation`]s from column data.
///
/// It carries no data, so it is `Copy` and `Default`. This lets generic code duplicate it, or
/// construct it, without naming the unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct BytewiseLookupTransformationFactory;
158
159/// Factory trait for creating linear transformations from column data.
160pub trait LinearTransformationFactory<Input, Output> {
161	type Transform: Transformation<Input, Output>;
162
163	fn create(&self, cols: &[Output]) -> Self::Transform;
164}
165
166impl<UIn, UOut> LinearTransformationFactory<UIn, UOut> for BytewiseLookupTransformationFactory
167where
168	UIn: UnderlierType + DivisIterable<u8>,
169	UOut: UnderlierWithBitOps,
170{
171	type Transform = BytewiseLookupTransformation<UIn, UOut>;
172
173	fn create(&self, cols: &[UOut]) -> Self::Transform {
174		BytewiseLookupTransformation::new(cols)
175	}
176}
177
/// Adapts a transformation that produces underliers so that it produces `Output` values.
///
/// The wrapped `inner` transformation maps `Input` to `Output::Underlier`. Its result is
/// converted with `WithUnderlier::from_underlier`.
#[derive(Debug)]
pub struct OutputWrappingTransformation<Inner, Input, Output> {
	/// Transformation whose underlier output is being wrapped.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
184
185impl<Inner, Input, Output> Transformation<Input, Output>
186	for OutputWrappingTransformation<Inner, Input, Output>
187where
188	Inner: Transformation<Input, Output::Underlier>,
189	Input: Sync,
190	Output: WithUnderlier,
191{
192	#[inline]
193	fn transform(&self, data: &Input) -> Output {
194		Output::from_underlier(self.inner.transform(data))
195	}
196}
197
/// Adapts a factory of underlier-producing transformations to types that have underliers.
///
/// Columns given as `Output` values are converted to underliers before they reach `inner`. The
/// transformations it creates yield `Output` values.
#[derive(Debug)]
pub struct OutputWrappingTransformationFactory<Inner, Input, Output> {
	/// Factory operating on underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
204
205impl<Inner, Input, Output> OutputWrappingTransformationFactory<Inner, Input, Output>
206where
207	Inner: LinearTransformationFactory<Input, Output::Underlier>,
208	Input: Sync,
209	Output: WithUnderlier,
210{
211	pub fn new(inner: Inner) -> Self {
212		Self {
213			inner,
214			_marker: PhantomData,
215		}
216	}
217}
218
219impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
220	for OutputWrappingTransformationFactory<Inner, Input, Output>
221where
222	Inner: LinearTransformationFactory<Input, Output::Underlier>,
223	Input: Sync,
224	Output: WithUnderlier,
225{
226	type Transform = OutputWrappingTransformation<Inner::Transform, Input, Output>;
227
228	#[inline]
229	fn create(&self, cols: &[Output]) -> Self::Transform {
230		OutputWrappingTransformation {
231			inner: self.inner.create(Output::to_underliers_ref(cols)),
232			_marker: PhantomData,
233		}
234	}
235}
236
/// Adapts a transformation on underliers so that it accepts `Input` values.
///
/// Each input is converted with `WithUnderlier::to_underlier` before it is passed to `inner`.
#[derive(Debug)]
pub struct InputWrappingTransformation<Inner, Input, Output> {
	/// Transformation that consumes underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
243
244impl<Inner, Input, Output> Transformation<Input, Output>
245	for InputWrappingTransformation<Inner, Input, Output>
246where
247	Inner: Transformation<Input::Underlier, Output>,
248	Input: WithUnderlier,
249	Output: Sync,
250{
251	#[inline]
252	fn transform(&self, data: &Input) -> Output {
253		self.inner.transform(&data.to_underlier())
254	}
255}
256
/// Adapts a factory of underlier-consuming transformations so the transformations it creates
/// accept `Input` values.
#[derive(Debug)]
pub struct InputWrappingTransformationFactory<Inner, Input, Output> {
	/// Factory whose transformations consume underliers.
	inner: Inner,
	_marker: PhantomData<(Input, Output)>,
}
263
264impl<Inner, Input, Output> InputWrappingTransformationFactory<Inner, Input, Output>
265where
266	Inner: LinearTransformationFactory<Input::Underlier, Output>,
267	Input: WithUnderlier,
268	Output: Sync,
269{
270	pub fn new(inner: Inner) -> Self {
271		Self {
272			inner,
273			_marker: PhantomData,
274		}
275	}
276}
277
278impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
279	for InputWrappingTransformationFactory<Inner, Input, Output>
280where
281	Inner: LinearTransformationFactory<Input::Underlier, Output>,
282	Input: WithUnderlier,
283	Output: Sync,
284{
285	type Transform = InputWrappingTransformation<Inner::Transform, Input, Output>;
286
287	#[inline]
288	fn create(&self, cols: &[Output]) -> Self::Transform {
289		InputWrappingTransformation {
290			inner: self.inner.create(cols),
291			_marker: PhantomData,
292		}
293	}
294}
295
296/// Transformation that wraps both input and output, converting between types with underliers.
297pub type WrappingTransformation<Inner, Input, Output> = OutputWrappingTransformation<
298	InputWrappingTransformation<Inner, Input, <Output as WithUnderlier>::Underlier>,
299	Input,
300	Output,
301>;
302
303/// This crates represents a type that creates a packed transformation from `Self` to a packed
304/// field based on the scalar field transformation.
305pub trait PackedTransformationFactory<OP>: PackedBinaryField
306where
307	OP: PackedBinaryField,
308{
309	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync>: Transformation<Self, OP>;
310
311	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
312		transformation: FieldLinearTransformation<OP::Scalar, Data>,
313	) -> Self::PackedTransformation<Data>;
314}
315
316pub struct IDTransformation;
317
318impl<OP: PackedBinaryField> Transformation<OP, OP> for IDTransformation {
319	fn transform(&self, data: &OP) -> OP {
320		*data
321	}
322}