1use std::{iter, marker::PhantomData, ops::BitXor};
4
5use rand::prelude::*;
6
7use crate::{
8 BinaryField, BinaryField1b, ExtensionField, UnderlierType, WithUnderlier,
9 packed::PackedBinaryField, underlier::Divisible,
10};
11
/// A function-like object that maps a borrowed `Input` to an owned `Output`.
///
/// Every implementor must be [`Sync`], so one transformation value can be shared by reference
/// between threads.
pub trait Transformation<Input, Output>
where
    Self: Sync,
{
    /// Applies the transformation to `data` and returns the result.
    fn transform(&self, data: &Input) -> Output;
}
16
/// Linear transformation on binary field elements, described by a list of `OF` values
/// (`bases`), one per base-field coordinate of the input.
///
/// `Data` is the container holding those values. It defaults to `&'static [OF]`, which is the
/// storage accepted by the `const` constructor `new_const`.
#[derive(Debug, Clone)]
pub struct FieldLinearTransformation<OF: BinaryField, Data: AsRef<[OF]> + Sync = &'static [OF]> {
    /// Output values that `transform` scales by the input's coordinates and sums.
    bases: Data,
    /// Ties the otherwise unused `OF` parameter to the struct.
    _pd: PhantomData<OF>,
}
27
28impl<OF: BinaryField> FieldLinearTransformation<OF, &'static [OF]> {
29 pub const fn new_const(bases: &'static [OF]) -> Self {
30 assert!(bases.len() == OF::DEGREE);
31
32 Self {
33 bases,
34 _pd: PhantomData,
35 }
36 }
37}
38
39impl<OF: BinaryField, Data: AsRef<[OF]> + Sync> FieldLinearTransformation<OF, Data> {
40 pub fn new(bases: Data) -> Self {
41 debug_assert_eq!(bases.as_ref().len(), OF::DEGREE);
42
43 Self {
44 bases,
45 _pd: PhantomData,
46 }
47 }
48
49 pub fn bases(&self) -> &[OF] {
50 self.bases.as_ref()
51 }
52}
53
54impl<IF: BinaryField, OF: BinaryField, Data: AsRef<[OF]> + Sync> Transformation<IF, OF>
55 for FieldLinearTransformation<OF, Data>
56{
57 fn transform(&self, data: &IF) -> OF {
58 assert_eq!(IF::DEGREE, OF::DEGREE);
59
60 ExtensionField::<BinaryField1b>::iter_bases(data)
61 .zip(self.bases.as_ref().iter())
62 .fold(OF::ZERO, |acc, (scalar, &basis_elem)| acc + basis_elem * scalar)
63 }
64}
65
66impl<OF: BinaryField> FieldLinearTransformation<OF, Vec<OF>> {
67 pub fn random(mut rng: impl Rng) -> Self {
68 Self {
69 bases: (0..OF::DEGREE).map(|_| OF::random(&mut rng)).collect(),
70 _pd: PhantomData,
71 }
72 }
73}
74
/// Base-2 logarithm of the number of bits in a byte.
const LOG_BITS_PER_BYTE: usize = 3;
/// Number of bits in a byte (8), derived from `LOG_BITS_PER_BYTE`.
const BITS_PER_BYTE: usize = 1 << LOG_BITS_PER_BYTE;
77
/// Transformation from `UIn` to `UOut` evaluated with one precomputed lookup table per input
/// byte; the per-byte results are combined with XOR.
#[derive(Debug)]
pub struct BytewiseLookupTransformation<UIn, UOut> {
    /// One table per input byte. Each table has `1 << BITS_PER_BYTE` entries, and entry `b`
    /// holds the XOR of that byte's columns selected by the set bits of `b`.
    lookup: Vec<[UOut; 1 << BITS_PER_BYTE]>,
    /// Records the input type, which is not stored in any field.
    _uin_marker: PhantomData<UIn>,
}
89
90impl<UIn, UOut> BytewiseLookupTransformation<UIn, UOut>
91where
92 UIn: UnderlierType + Divisible<u8>,
93 UOut: UnderlierType,
94{
95 pub fn new(cols: &[UOut]) -> Self {
96 assert!(LOG_BITS_PER_BYTE <= UIn::LOG_BITS);
97 assert_eq!(cols.len(), UIn::BITS);
98
99 let lookup = cols
100 .chunks(BITS_PER_BYTE)
101 .map(|cols| {
102 let cols: [_; BITS_PER_BYTE] = cols.try_into().expect(
103 "chunk size is BITS_PER_BYTE; \
104 cols.len() is a multiple of BITS_PER_BYTE",
105 );
106 expand_subset_xors(cols)
107 })
108 .collect();
109
110 Self {
111 lookup,
112 _uin_marker: PhantomData,
113 }
114 }
115}
116
117impl<UIn, UOut> Transformation<UIn, UOut> for BytewiseLookupTransformation<UIn, UOut>
118where
119 UIn: UnderlierType + Divisible<u8>,
120 UOut: UnderlierType,
121{
122 fn transform(&self, data: &UIn) -> UOut {
123 Divisible::<u8>::ref_iter(data)
124 .enumerate()
125 .take(1 << (UIn::LOG_BITS - LOG_BITS_PER_BYTE))
126 .map(|(i, byte)| {
127 let lookup = unsafe { self.lookup.get_unchecked(i) };
131 lookup[byte as usize]
132 })
133 .reduce(BitXor::bitxor)
134 .unwrap_or(UOut::ZERO)
135 }
136}
137
138fn expand_subset_xors<U: UnderlierType, const N: usize, const N_EXP2: usize>(
139 elems: [U; N],
140) -> [U; N_EXP2] {
141 assert_eq!(N_EXP2, 1 << N);
142
143 let mut expanded = [U::ZERO; N_EXP2];
144 for (i, elem_i) in elems.into_iter().enumerate() {
145 let span = &mut expanded[..1 << (i + 1)];
146 let (lo_half, hi_half) = span.split_at_mut(1 << i);
147 for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
148 *hi_half_i = *lo_half_i ^ elem_i;
149 }
150 }
151 expanded
152}
153
/// Stateless factory that produces `BytewiseLookupTransformation` values.
#[derive(Debug)]
pub struct BytewiseLookupTransformationFactory;
157
/// Builds a [`Transformation`] from `Input` to `Output` out of a slice of `Output` columns.
pub trait LinearTransformationFactory<Input, Output> {
    /// The concrete transformation type this factory produces.
    type Transform: Transformation<Input, Output>;

    /// Creates the transformation described by `cols`.
    ///
    /// The bytewise-lookup implementation in this file requires exactly one column per input
    /// bit (`cols.len() == UIn::BITS`).
    fn create(&self, cols: &[Output]) -> Self::Transform;
}
164
165impl<UIn, UOut> LinearTransformationFactory<UIn, UOut> for BytewiseLookupTransformationFactory
166where
167 UIn: UnderlierType + Divisible<u8>,
168 UOut: UnderlierType,
169{
170 type Transform = BytewiseLookupTransformation<UIn, UOut>;
171
172 fn create(&self, cols: &[UOut]) -> Self::Transform {
173 BytewiseLookupTransformation::new(cols)
174 }
175}
176
/// Adapter around an `Inner` transformation that produces `Output`'s underlier; the result
/// is converted to `Output` with `from_underlier`.
#[derive(Debug)]
pub struct OutputWrappingTransformation<Inner, Input, Output> {
    /// The wrapped transformation.
    inner: Inner,
    /// Records the `Input`/`Output` types, which are not stored in any field.
    _marker: PhantomData<(Input, Output)>,
}
183
184impl<Inner, Input, Output> Transformation<Input, Output>
185 for OutputWrappingTransformation<Inner, Input, Output>
186where
187 Inner: Transformation<Input, Output::Underlier>,
188 Input: Sync,
189 Output: WithUnderlier,
190{
191 #[inline]
192 fn transform(&self, data: &Input) -> Output {
193 Output::from_underlier(self.inner.transform(data))
194 }
195}
196
/// Factory adapter that creates `OutputWrappingTransformation`s from an inner factory working
/// on `Output`'s underlier type.
#[derive(Debug)]
pub struct OutputWrappingTransformationFactory<Inner, Input, Output> {
    /// The wrapped factory.
    inner: Inner,
    /// Records the `Input`/`Output` types, which are not stored in any field.
    _marker: PhantomData<(Input, Output)>,
}
203
204impl<Inner, Input, Output> OutputWrappingTransformationFactory<Inner, Input, Output>
205where
206 Inner: LinearTransformationFactory<Input, Output::Underlier>,
207 Input: Sync,
208 Output: WithUnderlier,
209{
210 pub fn new(inner: Inner) -> Self {
211 Self {
212 inner,
213 _marker: PhantomData,
214 }
215 }
216}
217
218impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
219 for OutputWrappingTransformationFactory<Inner, Input, Output>
220where
221 Inner: LinearTransformationFactory<Input, Output::Underlier>,
222 Input: Sync,
223 Output: WithUnderlier,
224{
225 type Transform = OutputWrappingTransformation<Inner::Transform, Input, Output>;
226
227 #[inline]
228 fn create(&self, cols: &[Output]) -> Self::Transform {
229 OutputWrappingTransformation {
230 inner: self.inner.create(Output::to_underliers_ref(cols)),
231 _marker: PhantomData,
232 }
233 }
234}
235
/// Adapter around an `Inner` transformation that consumes `Input`'s underlier; the input is
/// converted with `to_underlier` before being passed on.
#[derive(Debug)]
pub struct InputWrappingTransformation<Inner, Input, Output> {
    /// The wrapped transformation.
    inner: Inner,
    /// Records the `Input`/`Output` types, which are not stored in any field.
    _marker: PhantomData<(Input, Output)>,
}
242
243impl<Inner, Input, Output> Transformation<Input, Output>
244 for InputWrappingTransformation<Inner, Input, Output>
245where
246 Inner: Transformation<Input::Underlier, Output>,
247 Input: WithUnderlier,
248 Output: Sync,
249{
250 #[inline]
251 fn transform(&self, data: &Input) -> Output {
252 self.inner.transform(&data.to_underlier())
253 }
254}
255
/// Factory adapter that creates `InputWrappingTransformation`s from an inner factory working
/// on `Input`'s underlier type.
#[derive(Debug)]
pub struct InputWrappingTransformationFactory<Inner, Input, Output> {
    /// The wrapped factory.
    inner: Inner,
    /// Records the `Input`/`Output` types, which are not stored in any field.
    _marker: PhantomData<(Input, Output)>,
}
262
263impl<Inner, Input, Output> InputWrappingTransformationFactory<Inner, Input, Output>
264where
265 Inner: LinearTransformationFactory<Input::Underlier, Output>,
266 Input: WithUnderlier,
267 Output: Sync,
268{
269 pub fn new(inner: Inner) -> Self {
270 Self {
271 inner,
272 _marker: PhantomData,
273 }
274 }
275}
276
277impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
278 for InputWrappingTransformationFactory<Inner, Input, Output>
279where
280 Inner: LinearTransformationFactory<Input::Underlier, Output>,
281 Input: WithUnderlier,
282 Output: Sync,
283{
284 type Transform = InputWrappingTransformation<Inner::Transform, Input, Output>;
285
286 #[inline]
287 fn create(&self, cols: &[Output]) -> Self::Transform {
288 InputWrappingTransformation {
289 inner: self.inner.create(cols),
290 _marker: PhantomData,
291 }
292 }
293}
294
/// An `Inner` transformation wrapped on both sides: an `InputWrappingTransformation` feeds it
/// `Input`'s underlier, and an `OutputWrappingTransformation` converts its
/// `Output::Underlier` result into `Output`.
pub type WrappingTransformation<Inner, Input, Output> = OutputWrappingTransformation<
    InputWrappingTransformation<Inner, Input, <Output as WithUnderlier>::Underlier>,
    Input,
    Output,
>;
301
/// The identity transformation: `transform` returns its input unchanged.
///
/// Derives `Debug` like every other public type in this module.
#[derive(Debug)]
pub struct IDTransformation;
303
304impl<OP: PackedBinaryField> Transformation<OP, OP> for IDTransformation {
305 fn transform(&self, data: &OP) -> OP {
306 *data
307 }
308}