1use std::{iter, marker::PhantomData, ops::BitXor};
4
5use rand::RngCore;
6
7use crate::{
8 BinaryField, BinaryField1b, ExtensionField, UnderlierWithBitOps, WithUnderlier,
9 packed::PackedBinaryField,
10 underlier::{DivisIterable, UnderlierType},
11};
12
/// A map from values of type `Input` to values of type `Output`.
///
/// Implementors are required to be [`Sync`], so a transformation can be shared by reference
/// between threads.
pub trait Transformation<Input, Output>
where
    Self: Sync,
{
    /// Applies the map to `data` and returns the resulting value.
    fn transform(&self, data: &Input) -> Output;
}
17
18#[derive(Debug, Clone)]
24pub struct FieldLinearTransformation<OF: BinaryField, Data: AsRef<[OF]> + Sync = &'static [OF]> {
25 bases: Data,
26 _pd: PhantomData<OF>,
27}
28
29impl<OF: BinaryField> FieldLinearTransformation<OF, &'static [OF]> {
30 pub const fn new_const(bases: &'static [OF]) -> Self {
31 assert!(bases.len() == OF::DEGREE);
32
33 Self {
34 bases,
35 _pd: PhantomData,
36 }
37 }
38}
39
40impl<OF: BinaryField, Data: AsRef<[OF]> + Sync> FieldLinearTransformation<OF, Data> {
41 pub fn new(bases: Data) -> Self {
42 debug_assert_eq!(bases.as_ref().len(), OF::DEGREE);
43
44 Self {
45 bases,
46 _pd: PhantomData,
47 }
48 }
49
50 pub fn bases(&self) -> &[OF] {
51 self.bases.as_ref()
52 }
53}
54
55impl<IF: BinaryField, OF: BinaryField, Data: AsRef<[OF]> + Sync> Transformation<IF, OF>
56 for FieldLinearTransformation<OF, Data>
57{
58 fn transform(&self, data: &IF) -> OF {
59 assert_eq!(IF::DEGREE, OF::DEGREE);
60
61 ExtensionField::<BinaryField1b>::iter_bases(data)
62 .zip(self.bases.as_ref().iter())
63 .fold(OF::ZERO, |acc, (scalar, &basis_elem)| acc + basis_elem * scalar)
64 }
65}
66
67impl<OF: BinaryField> FieldLinearTransformation<OF, Vec<OF>> {
68 pub fn random(mut rng: impl RngCore) -> Self {
69 Self {
70 bases: (0..OF::DEGREE).map(|_| OF::random(&mut rng)).collect(),
71 _pd: PhantomData,
72 }
73 }
74}
75
/// Base-2 logarithm of the number of bits in a byte.
const LOG_BITS_PER_BYTE: usize = 3;
/// Number of bits in a byte; `BytewiseLookupTransformation` builds one table per this many
/// columns.
const BITS_PER_BYTE: usize = 1usize << LOG_BITS_PER_BYTE;
78
#[derive(Debug)]
/// Evaluates a GF(2)-linear map from `UIn` to `UOut` by table lookup, one input byte at a time.
///
/// There is one 256-entry table per input byte position. The output is the XOR of the
/// entries selected by the input's bytes.
pub struct BytewiseLookupTransformation<UIn, UOut> {
    /// `lookup[i][b]` is the XOR of the columns from the `i`-th group of `BITS_PER_BYTE`
    /// columns whose in-group index is a set bit of `b` (built by `new`).
    lookup: Vec<[UOut; 1 << BITS_PER_BYTE]>,
    /// Records the input type without storing a value of it.
    _uin_marker: PhantomData<UIn>,
}
90
91impl<UIn, UOut> BytewiseLookupTransformation<UIn, UOut>
92where
93 UIn: UnderlierType + DivisIterable<u8>,
94 UOut: UnderlierWithBitOps,
95{
96 pub fn new(cols: &[UOut]) -> Self {
97 assert!(LOG_BITS_PER_BYTE <= UIn::LOG_BITS);
98 assert_eq!(cols.len(), UIn::BITS);
99
100 let lookup = cols
101 .chunks(BITS_PER_BYTE)
102 .map(|cols| {
103 let cols: [_; BITS_PER_BYTE] = cols.try_into().expect(
104 "chunk size is BITS_PER_BYTE; \
105 cols.len() is a multiple of BITS_PER_BYTE",
106 );
107 expand_subset_xors(cols)
108 })
109 .collect();
110
111 Self {
112 lookup,
113 _uin_marker: PhantomData,
114 }
115 }
116}
117
118impl<UIn, UOut> Transformation<UIn, UOut> for BytewiseLookupTransformation<UIn, UOut>
119where
120 UIn: UnderlierType + DivisIterable<u8>,
121 UOut: UnderlierWithBitOps,
122{
123 fn transform(&self, data: &UIn) -> UOut {
124 data.divide()
125 .enumerate()
126 .take(1 << (UIn::LOG_BITS - LOG_BITS_PER_BYTE))
127 .map(|(i, byte)| {
128 let lookup = unsafe { self.lookup.get_unchecked(i) };
132 lookup[*byte as usize]
133 })
134 .reduce(BitXor::bitxor)
135 .unwrap_or(UOut::ZERO)
136 }
137}
138
139fn expand_subset_xors<U: UnderlierWithBitOps, const N: usize, const N_EXP2: usize>(
140 elems: [U; N],
141) -> [U; N_EXP2] {
142 assert_eq!(N_EXP2, 1 << N);
143
144 let mut expanded = [U::ZERO; N_EXP2];
145 for (i, elem_i) in elems.into_iter().enumerate() {
146 let span = &mut expanded[..1 << (i + 1)];
147 let (lo_half, hi_half) = span.split_at_mut(1 << i);
148 for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
149 *hi_half_i = *lo_half_i ^ elem_i;
150 }
151 }
152 expanded
153}
154
/// Zero-sized factory whose `create` builds a [`BytewiseLookupTransformation`] from a list of
/// output columns.
///
/// It carries no state, so it is `Copy` and `Default` like other unit marker types.
#[derive(Debug, Clone, Copy, Default)]
pub struct BytewiseLookupTransformationFactory;
158
159pub trait LinearTransformationFactory<Input, Output> {
161 type Transform: Transformation<Input, Output>;
162
163 fn create(&self, cols: &[Output]) -> Self::Transform;
164}
165
166impl<UIn, UOut> LinearTransformationFactory<UIn, UOut> for BytewiseLookupTransformationFactory
167where
168 UIn: UnderlierType + DivisIterable<u8>,
169 UOut: UnderlierWithBitOps,
170{
171 type Transform = BytewiseLookupTransformation<UIn, UOut>;
172
173 fn create(&self, cols: &[UOut]) -> Self::Transform {
174 BytewiseLookupTransformation::new(cols)
175 }
176}
177
#[derive(Debug)]
/// Adapts a transformation that produces `Output::Underlier` values into one that produces
/// `Output`, wrapping each result with `WithUnderlier::from_underlier`.
pub struct OutputWrappingTransformation<Inner, Input, Output> {
    /// Transformation that computes the raw underlier value.
    inner: Inner,
    /// Records `Input` and `Output` without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
184
185impl<Inner, Input, Output> Transformation<Input, Output>
186 for OutputWrappingTransformation<Inner, Input, Output>
187where
188 Inner: Transformation<Input, Output::Underlier>,
189 Input: Sync,
190 Output: WithUnderlier,
191{
192 #[inline]
193 fn transform(&self, data: &Input) -> Output {
194 Output::from_underlier(self.inner.transform(data))
195 }
196}
197
#[derive(Debug)]
/// Adapts a factory that works on `Output::Underlier` columns into one that accepts `Output`
/// columns and yields an [`OutputWrappingTransformation`].
pub struct OutputWrappingTransformationFactory<Inner, Input, Output> {
    /// Factory that builds the underlier-level transformation.
    inner: Inner,
    /// Records `Input` and `Output` without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
204
205impl<Inner, Input, Output> OutputWrappingTransformationFactory<Inner, Input, Output>
206where
207 Inner: LinearTransformationFactory<Input, Output::Underlier>,
208 Input: Sync,
209 Output: WithUnderlier,
210{
211 pub fn new(inner: Inner) -> Self {
212 Self {
213 inner,
214 _marker: PhantomData,
215 }
216 }
217}
218
219impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
220 for OutputWrappingTransformationFactory<Inner, Input, Output>
221where
222 Inner: LinearTransformationFactory<Input, Output::Underlier>,
223 Input: Sync,
224 Output: WithUnderlier,
225{
226 type Transform = OutputWrappingTransformation<Inner::Transform, Input, Output>;
227
228 #[inline]
229 fn create(&self, cols: &[Output]) -> Self::Transform {
230 OutputWrappingTransformation {
231 inner: self.inner.create(Output::to_underliers_ref(cols)),
232 _marker: PhantomData,
233 }
234 }
235}
236
#[derive(Debug)]
/// Adapts a transformation that consumes `Input::Underlier` values into one that consumes
/// `Input`, unwrapping each input with `WithUnderlier::to_underlier`.
pub struct InputWrappingTransformation<Inner, Input, Output> {
    /// Transformation applied to the raw underlier value.
    inner: Inner,
    /// Records `Input` and `Output` without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
243
244impl<Inner, Input, Output> Transformation<Input, Output>
245 for InputWrappingTransformation<Inner, Input, Output>
246where
247 Inner: Transformation<Input::Underlier, Output>,
248 Input: WithUnderlier,
249 Output: Sync,
250{
251 #[inline]
252 fn transform(&self, data: &Input) -> Output {
253 self.inner.transform(&data.to_underlier())
254 }
255}
256
#[derive(Debug)]
/// Adapts a factory whose transformations consume `Input::Underlier` into one yielding
/// [`InputWrappingTransformation`]s that consume `Input`.
pub struct InputWrappingTransformationFactory<Inner, Input, Output> {
    /// Factory that builds the underlier-level transformation.
    inner: Inner,
    /// Records `Input` and `Output` without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
263
264impl<Inner, Input, Output> InputWrappingTransformationFactory<Inner, Input, Output>
265where
266 Inner: LinearTransformationFactory<Input::Underlier, Output>,
267 Input: WithUnderlier,
268 Output: Sync,
269{
270 pub fn new(inner: Inner) -> Self {
271 Self {
272 inner,
273 _marker: PhantomData,
274 }
275 }
276}
277
278impl<Inner, Input, Output> LinearTransformationFactory<Input, Output>
279 for InputWrappingTransformationFactory<Inner, Input, Output>
280where
281 Inner: LinearTransformationFactory<Input::Underlier, Output>,
282 Input: WithUnderlier,
283 Output: Sync,
284{
285 type Transform = InputWrappingTransformation<Inner::Transform, Input, Output>;
286
287 #[inline]
288 fn create(&self, cols: &[Output]) -> Self::Transform {
289 InputWrappingTransformation {
290 inner: self.inner.create(cols),
291 _marker: PhantomData,
292 }
293 }
294}
295
/// Adapts `Inner`, a transformation from `Input::Underlier` to `Output::Underlier`, into a
/// transformation from `Input` to `Output`. The input is unwrapped with `to_underlier`, and
/// the result is rewrapped with `from_underlier`.
pub type WrappingTransformation<Inner, Input, Output> = OutputWrappingTransformation<
    InputWrappingTransformation<Inner, Input, <Output as WithUnderlier>::Underlier>,
    Input,
    Output,
>;
302
303pub trait PackedTransformationFactory<OP>: PackedBinaryField
306where
307 OP: PackedBinaryField,
308{
309 type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync>: Transformation<Self, OP>;
310
311 fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
312 transformation: FieldLinearTransformation<OP::Scalar, Data>,
313 ) -> Self::PackedTransformation<Data>;
314}
315
/// The identity transformation: `transform` returns its input unchanged.
///
/// Derives `Debug` like the file's other public types; being stateless, it is also `Copy`
/// and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct IDTransformation;
317
318impl<OP: PackedBinaryField> Transformation<OP, OP> for IDTransformation {
319 fn transform(&self, data: &OP) -> OP {
320 *data
321 }
322}