1use std::{
5 fmt::{Debug, Display, Formatter},
6 iter::{Product, Sum},
7 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use binius_utils::{
11 DeserializeBytes, FixedSizeSerializeBytes, SerializationError, SerializeBytes,
12 bytes::{Buf, BufMut},
13};
14use bytemuck::Zeroable;
15
16use super::{UnderlierType, WithUnderlier, extension::ExtensionField};
17use crate::{Field, underlier::U1};
18
19pub trait BinaryField:
21 ExtensionField<BinaryField1b> + WithUnderlier<Underlier: UnderlierType>
22{
23 const N_BITS: usize = Self::ORDER_EXPONENT;
24}
25
/// Declares a binary field element type as a transparent newtype over an underlier type.
///
/// Usage: `binary_field!(<vis> Name(UnderlierTy), <generator>)`, where `<generator>` is an
/// expression of type `UnderlierTy` that is stored as the field's `MULTIPLICATIVE_GENERATOR`.
///
/// The expansion contains the newtype itself, XOR-based addition/subtraction, the by-reference
/// and assigning operator variants, `Sum`/`Product`, `Field`, `Divisible`, `Maskable`, random
/// sampling, `Display`/`Debug`, `BinaryField`, and conversions to and from the underlier.
///
/// Multiplication by value (`Mul<Self>`) is NOT generated here: `Mul<&Self>`, `MulAssign` and
/// `Product` all delegate to it, so it has to be implemented separately for the new type.
macro_rules! binary_field {
    ($vis:vis $name:ident($typ:ty), $gen:expr) => {
        /// Binary field element stored directly as its underlier value.
        #[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, bytemuck::TransparentWrapper)]
        #[repr(transparent)]
        $vis struct $name(pub(crate) $typ);

        impl $name {
            /// Returns the wrapped underlier value.
            pub const fn val(self) -> $typ {
                self.0
            }
        }

        // SAFETY: `$name` is `#[repr(transparent)]` over `$typ`, so both share one layout.
        // NOTE(review): presumably this is exactly what `WithUnderlier` requires — confirm
        // against the trait's documented contract.
        unsafe impl $crate::underlier::WithUnderlier for $name {
            type Underlier = $typ;
        }

        // The field has characteristic 2, so every element is its own additive inverse.
        impl Neg for $name {
            type Output = Self;

            fn neg(self) -> Self::Output {
                self
            }
        }

        impl Add<Self> for $name {
            type Output = Self;

            // Addition in characteristic 2 is bitwise XOR of the representations.
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn add(self, rhs: Self) -> Self::Output {
                $name(self.0 ^ rhs.0)
            }
        }

        impl Add<&Self> for $name {
            type Output = Self;

            // Addition in characteristic 2 is bitwise XOR of the representations.
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn add(self, rhs: &Self) -> Self::Output {
                $name(self.0 ^ rhs.0)
            }
        }

        impl Sub<Self> for $name {
            type Output = Self;

            // Subtraction coincides with addition in characteristic 2, hence XOR.
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn sub(self, rhs: Self) -> Self::Output {
                $name(self.0 ^ rhs.0)
            }
        }

        impl Sub<&Self> for $name {
            type Output = Self;

            // Subtraction coincides with addition in characteristic 2, hence XOR.
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn sub(self, rhs: &Self) -> Self::Output {
                $name(self.0 ^ rhs.0)
            }
        }

        // Delegates to `Mul<Self>`, which is implemented outside of this macro.
        impl Mul<&Self> for $name {
            type Output = Self;

            fn mul(self, rhs: &Self) -> Self::Output {
                self * *rhs
            }
        }

        impl AddAssign<Self> for $name {
            fn add_assign(&mut self, rhs: Self) {
                *self = *self + rhs;
            }
        }

        impl AddAssign<&Self> for $name {
            fn add_assign(&mut self, rhs: &Self) {
                *self = *self + *rhs;
            }
        }

        impl SubAssign<Self> for $name {
            fn sub_assign(&mut self, rhs: Self) {
                *self = *self - rhs;
            }
        }

        impl SubAssign<&Self> for $name {
            fn sub_assign(&mut self, rhs: &Self) {
                *self = *self - *rhs;
            }
        }

        impl MulAssign<Self> for $name {
            fn mul_assign(&mut self, rhs: Self) {
                *self = *self * rhs;
            }
        }

        impl MulAssign<&Self> for $name {
            fn mul_assign(&mut self, rhs: &Self) {
                *self = *self * rhs;
            }
        }

        // An empty sum is `ZERO`.
        impl Sum<Self> for $name {
            fn sum<I: Iterator<Item=Self>>(iter: I) -> Self {
                iter.fold(Self::ZERO, |acc, x| acc + x)
            }
        }

        impl<'a> Sum<&'a Self> for $name {
            fn sum<I: Iterator<Item=&'a Self>>(iter: I) -> Self {
                iter.fold(Self::ZERO, |acc, x| acc + x)
            }
        }

        // An empty product is `ONE`.
        impl Product<Self> for $name {
            fn product<I: Iterator<Item=Self>>(iter: I) -> Self {
                iter.fold(Self::ONE, |acc, x| acc * x)
            }
        }

        impl<'a> Product<&'a Self> for $name {
            fn product<I: Iterator<Item=&'a Self>>(iter: I) -> Self {
                iter.fold(Self::ONE, |acc, x| acc * x)
            }
        }


        impl Field for $name {
            const ZERO: Self = $name(<$typ as $crate::underlier::UnderlierType>::ZERO);
            const ONE: Self = $name(<$typ as $crate::underlier::UnderlierType>::ONE);
            const CHARACTERISTIC: usize = 2;
            const ORDER_EXPONENT: usize = <$typ as $crate::underlier::UnderlierType>::BITS;
            const MULTIPLICATIVE_GENERATOR: $name = $name($gen);

            // x + x = 0 for every x in characteristic 2.
            fn double(&self) -> Self {
                Self::ZERO
            }
        }

        // Trivial case: a value viewed as a sequence of `$name` holds exactly one element
        // (itself), so the index arguments below are ignored.
        impl $crate::Divisible<$name> for $name {
            const LOG_N: usize = 0;

            #[inline]
            fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $name> + Send + Clone {
                std::iter::once(value)
            }

            #[inline]
            fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $name> + Send + Clone + '_ {
                std::iter::once(*value)
            }

            #[inline]
            fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $name> + Send + Clone + '_ {
                slice.iter().copied()
            }

            #[inline]
            unsafe fn get_unchecked(&self, _index: usize) -> $name {
                *self
            }

            #[inline]
            unsafe fn set_unchecked(&mut self, _index: usize, val: $name) {
                *self = val;
            }

            #[inline]
            fn broadcast(val: $name) -> Self {
                val
            }

            // Takes the first item only; an empty iterator yields `ZERO`.
            #[inline]
            fn from_iter(mut iter: impl Iterator<Item = $name>) -> Self {
                iter.next().unwrap_or(Self::ZERO)
            }
        }

        // The mask is an underlier value built from a single selector bit.
        impl $crate::Maskable<$name> for $name {
            type Mask = $typ;

            // Only the first selector is consumed; a missing selector counts as `false`.
            #[inline]
            fn make_mask(mut selectors: impl Iterator<Item = bool>) -> $typ {
                <$typ as $crate::underlier::UnderlierType>::fill_with_bit(
                    u8::from(selectors.next().unwrap_or(false)),
                )
            }

            // Bitwise AND of the representation with the mask.
            #[inline]
            fn select(&self, mask: &$typ) -> Self {
                Self(self.0 & *mask)
            }
        }

        // Samples the underlier with `StandardUniform` and wraps the result.
        impl ::rand::distr::Distribution<$name> for ::rand::distr::StandardUniform {
            fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
                $name(::rand::distr::StandardUniform.sample(rng))
            }
        }

        // Zero-padded lowercase hex with a `0x` prefix; the pad width is one hex digit per
        // four bits of `N_BITS`, with a minimum of one digit.
        impl Display for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                write!(f, "0x{repr:0>width$x}", repr=self.val(), width=Self::N_BITS.max(4) / 4)
            }
        }

        // Formats as `TypeName(<Display output>)`, using the last path segment of the type name.
        impl Debug for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // `split` always yields at least one item, so `last()` cannot be `None`.
                let structure_name = std::any::type_name::<$name>().split("::").last().expect("exist");

                write!(f, "{}({})", structure_name, self)
            }
        }

        impl BinaryField for $name {}

        impl From<$typ> for $name {
            fn from(val: $typ) -> Self {
                Self(val)
            }
        }

        impl From<$name> for $typ {
            fn from(val: $name) -> Self {
                val.0
            }
        }
    }
}
266
267pub(crate) use binary_field;
268
/// Implements `Mul<BinaryField1b>` for the field type `$name`.
///
/// Multiplying by a one-bit element either keeps the value (bit is 1) or clears it (bit is 0);
/// this is done by AND-ing the representation with the bit broadcast over the whole underlier
/// via `fill_with_bit`.
macro_rules! mul_by_binary_field_1b {
    ($name:ident) => {
        impl Mul<BinaryField1b> for $name {
            type Output = Self;

            #[inline]
            // The product is intentionally computed with `&` rather than `*`.
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn mul(self, rhs: BinaryField1b) -> Self::Output {
                use $crate::underlier::{UnderlierType, WithUnderlier};

                // Record the actual left operand type instead of a hard-coded field name.
                $crate::tracing::trace_multiplication!($name, BinaryField1b);

                Self(self.0 & <$name as WithUnderlier>::Underlier::fill_with_bit(u8::from(rhs.0)))
            }
        }
    };
}
286
287pub(crate) use mul_by_binary_field_1b;
288
/// Wires up `$name` as an extension field of `$subfield_name`.
///
/// Usage: `impl_field_extension!(Sub(SubTy) < @LOG_DEGREE => Name(Ty))`, where `SubTy` and `Ty`
/// are the underlier types of the two fields and `LOG_DEGREE` becomes
/// `ExtensionField::LOG_DEGREE`.
///
/// The expansion provides conversions in both directions, mixed-operand `Add`/`Sub`/`Mul` and
/// their assigning forms, and the `ExtensionField<$subfield_name>` impl.
///
/// `Mul<$subfield_name> for $name` is NOT generated here: `MulAssign<$subfield_name>` and
/// `Mul<$name> for $subfield_name` both delegate to it, so it must be implemented separately.
macro_rules! impl_field_extension {
    ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty)) => {
        // Narrowing succeeds only when every bit above the subfield's width is zero.
        impl TryFrom<$name> for $subfield_name {
            type Error = ();

            #[inline]
            fn try_from(elem: $name) -> Result<Self, Self::Error> {
                use $crate::underlier::NumCast;

                if elem.0 >> $subfield_name::N_BITS
                    == <$typ as $crate::underlier::UnderlierType>::ZERO
                {
                    Ok($subfield_name(<$subfield_typ>::num_cast_from(elem.val())))
                } else {
                    Err(())
                }
            }
        }

        // Widening converts the subfield's underlier into the larger one with `From`.
        impl From<$subfield_name> for $name {
            #[inline]
            fn from(elem: $subfield_name) -> Self {
                $name(<$typ>::from(elem.val()))
            }
        }

        impl Add<$subfield_name> for $name {
            type Output = Self;

            #[inline]
            fn add(self, rhs: $subfield_name) -> Self::Output {
                self + Self::from(rhs)
            }
        }

        impl Sub<$subfield_name> for $name {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: $subfield_name) -> Self::Output {
                self - Self::from(rhs)
            }
        }

        impl AddAssign<$subfield_name> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: $subfield_name) {
                *self = *self + rhs;
            }
        }

        impl SubAssign<$subfield_name> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: $subfield_name) {
                *self = *self - rhs;
            }
        }

        impl MulAssign<$subfield_name> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: $subfield_name) {
                *self = *self * rhs;
            }
        }

        impl Add<$name> for $subfield_name {
            type Output = $name;

            #[inline]
            fn add(self, rhs: $name) -> Self::Output {
                rhs + self
            }
        }

        impl Sub<$name> for $subfield_name {
            type Output = $name;

            // Written as an addition with swapped operands; the binary field macro in this
            // file implements `Add` and `Sub` with the same XOR, so the result is the same.
            #[allow(clippy::suspicious_arithmetic_impl)]
            #[inline]
            fn sub(self, rhs: $name) -> Self::Output {
                rhs + self
            }
        }

        impl Mul<$name> for $subfield_name {
            type Output = $name;

            #[inline]
            fn mul(self, rhs: $name) -> Self::Output {
                rhs * self
            }
        }

        impl ExtensionField<$subfield_name> for $name {
            const LOG_DEGREE: usize = $log_degree;

            /// Returns the element with only bit `i * N_BITS(subfield)` set.
            ///
            /// # Panics
            ///
            /// Panics if `i` is not below `1 << LOG_DEGREE`.
            #[inline]
            fn basis(i: usize) -> Self {
                use $crate::underlier::UnderlierType;

                assert!(
                    i < 1 << $log_degree,
                    "index {} out of range for degree {}",
                    i,
                    1 << $log_degree
                );
                Self(<$typ>::ONE << (i * $subfield_name::N_BITS))
            }

            /// ORs the widened base elements into one value, placing each successive element
            /// `shift_step` bits above the previous one.
            ///
            /// # Panics
            ///
            /// Panics if an element would be placed at a bit offset of `N_BITS` or more.
            #[inline]
            fn from_bases_sparse(
                base_elems: impl IntoIterator<Item = $subfield_name>,
                log_stride: usize,
            ) -> Self {
                use $crate::underlier::UnderlierType;

                debug_assert!($name::N_BITS.is_power_of_two());
                // NOTE(review): the mask turns a step of exactly `N_BITS` into 0, in which
                // case every element lands at offset 0 and the assertion below never fires —
                // confirm this is the intended behavior for a full-width stride.
                let shift_step = ($subfield_name::N_BITS << log_stride) & ($name::N_BITS - 1);
                let mut value = <$typ>::ZERO;
                let mut shift = 0;

                for elem in base_elems.into_iter() {
                    assert!(shift < $name::N_BITS, "too many base elements for extension degree");
                    value |= <$typ>::from(elem.val()) << shift;
                    shift += shift_step;
                }

                Self(value)
            }

            /// Iterates over the subfield-sized pieces of the borrowed underlier value.
            #[inline]
            fn iter_bases(&self) -> impl Iterator<Item = $subfield_name> {
                use binius_utils::iter::IterExtensions;
                use $crate::underlier::{Divisible, WithUnderlier};

                Divisible::<<$subfield_name as WithUnderlier>::Underlier>::ref_iter(&self.0)
                    .map_skippable($subfield_name::from)
            }

            /// Same as `iter_bases`, but consumes `self`.
            #[inline]
            fn into_iter_bases(self) -> impl Iterator<Item = $subfield_name> {
                use binius_utils::iter::IterExtensions;
                use $crate::underlier::{Divisible, WithUnderlier};

                Divisible::<<$subfield_name as WithUnderlier>::Underlier>::value_iter(self.0)
                    .map_skippable($subfield_name::from)
            }

            /// Reads the `i`-th subfield-sized piece of the underlier without bounds checking.
            ///
            /// # Safety
            ///
            /// The index is forwarded unchanged to `Divisible::get_unchecked`.
            /// NOTE(review): presumably `i` must be below `1 << LOG_DEGREE` — confirm against
            /// the trait's documented contract.
            #[inline]
            unsafe fn get_base_unchecked(&self, i: usize) -> $subfield_name {
                use $crate::underlier::{Divisible, WithUnderlier};
                // SAFETY: the caller upholds this function's contract for `i`.
                unsafe {
                    $subfield_name::from_underlier(Divisible::<
                        <$subfield_name as WithUnderlier>::Underlier,
                    >::get_unchecked(&self.to_underlier(), i))
                }
            }

            #[inline]
            fn square_transpose(values: &mut [Self]) {
                // `$crate` keeps this path consistent with every other path in the macro.
                $crate::transpose::square_transforms_extension_field::<$subfield_name, Self>(values)
            }
        }
    };
}
455
456pub(crate) use impl_field_extension;
457
// Instantiates the one-bit field over the `U1` underlier; `U1::new(0x1)` is the value stored
// as its `MULTIPLICATIVE_GENERATOR`.
binary_field!(pub BinaryField1b(U1), U1::new(0x1));

// NOTE(review): presumably supplies the wide-multiplication impl for `BinaryField1b` in terms
// of its ordinary multiplication — confirm against the macro in `arithmetic_traits`.
crate::arithmetic_traits::impl_trivial_wide_mul!(BinaryField1b);
461
462macro_rules! serialize_deserialize {
463 ($bin_type:ty) => {
464 impl SerializeBytes for $bin_type {
465 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
466 self.0.serialize(write_buf)
467 }
468 }
469
470 impl DeserializeBytes for $bin_type {
471 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
472 Ok(Self(DeserializeBytes::deserialize(read_buf)?))
473 }
474 }
475 };
476}
477
478serialize_deserialize!(BinaryField1b);
479
// Declares the fixed serialized size of a `BinaryField1b` as one byte.
impl FixedSizeSerializeBytes for BinaryField1b {
    const BYTE_SIZE: usize = 1;
}
483
484impl BinaryField1b {
485 pub const fn new(value: U1) -> Self {
486 Self(value)
487 }
488
489 #[inline]
494 pub unsafe fn new_unchecked(val: u8) -> Self {
495 debug_assert!(val < 2, "val has to be less than 2, but it's {val}");
496
497 Self::new(U1::new_unchecked(val))
498 }
499}
500
501impl From<u8> for BinaryField1b {
502 #[inline]
503 fn from(val: u8) -> Self {
504 Self::new(U1::new(val))
505 }
506}
507
508impl From<BinaryField1b> for u8 {
509 #[inline]
510 fn from(value: BinaryField1b) -> Self {
511 value.val().into()
512 }
513}
514
515impl From<bool> for BinaryField1b {
516 #[inline]
517 fn from(value: bool) -> Self {
518 Self::from(U1::new_unchecked(value.into()))
519 }
520}
521
#[cfg(test)]
pub(crate) mod tests {
    use binius_utils::{DeserializeBytes, SerializeBytes, bytes::BytesMut};
    use proptest::prelude::*;

    use super::BinaryField1b as BF1;
    use crate::{
        AESTowerField8b, BinaryField, BinaryField1b, BinaryField128bGhash, Field,
        arithmetic_traits::InvertOrZero,
    };

    /// Full addition table of the one-bit field.
    #[test]
    fn test_gf2_add() {
        // (lhs, rhs, expected lhs + rhs)
        for (lhs, rhs, expected) in [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)] {
            assert_eq!(BF1::from(lhs) + BF1::from(rhs), BF1::from(expected));
        }
    }

    /// Full subtraction table of the one-bit field.
    #[test]
    fn test_gf2_sub() {
        // (lhs, rhs, expected lhs - rhs)
        for (lhs, rhs, expected) in [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)] {
            assert_eq!(BF1::from(lhs) - BF1::from(rhs), BF1::from(expected));
        }
    }

    /// Full multiplication table of the one-bit field.
    #[test]
    fn test_gf2_mul() {
        // (lhs, rhs, expected lhs * rhs)
        for (lhs, rhs, expected) in [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 1)] {
            assert_eq!(BF1::from(lhs) * BF1::from(rhs), BF1::from(expected));
        }
    }

    /// Checks that `F::MULTIPLICATIVE_GENERATOR` has order exactly `2^N_BITS - 1`.
    ///
    /// The group order is factored by trial division, and the generator is raised to every
    /// product of a subset of those prime factors. The result must be `F::ONE` for the full
    /// product (the group order) and must differ from `F::ONE` for every other subset.
    pub(crate) fn is_binary_field_valid_generator<F: BinaryField>() -> bool {
        // `1 << 128` does not fit in a u128, so the 128-bit case is spelled out directly.
        let mut remaining = match F::N_BITS {
            128 => u128::MAX,
            bits => (1 << bits) - 1,
        };

        // Prime factors with multiplicity, found by naive trial division.
        let mut prime_factors = Vec::new();
        let mut candidate = 2;
        while candidate * candidate <= remaining {
            while remaining.is_multiple_of(candidate) {
                remaining /= candidate;
                prime_factors.push(candidate);
            }

            // After 2, only odd candidates need to be tried.
            candidate += if candidate > 2 { 2 } else { 1 };
        }

        // Whatever is left above 1 is itself a prime factor.
        if remaining > 1 {
            prime_factors.push(remaining);
        }

        // Each bit mask selects a subset of the factors; repeated factors may cause the same
        // divisor to be visited more than once.
        let subset_count = 1 << prime_factors.len();
        for mask in 0..subset_count {
            let mut divisor = 1;
            for (bit_index, &factor) in prime_factors.iter().enumerate() {
                if mask & (1 << bit_index) != 0 {
                    divisor *= factor;
                }
            }

            // Square-and-multiply over the bit-reversed exponent, so the original bits are
            // consumed from the most significant one down. Every divisor of 2^n - 1 is odd,
            // so the loop does not stop before the lowest original bit has been processed.
            let mut exponent_bits = divisor.reverse_bits();
            let mut power = F::ONE;
            while exponent_bits > 0 {
                power *= power;

                if exponent_bits & 1 != 0 {
                    power *= F::MULTIPLICATIVE_GENERATOR;
                }

                exponent_bits >>= 1;
            }

            // The power may equal one for the full group order and for nothing else.
            let hits_identity = power == F::ONE;
            let covers_whole_group = mask + 1 == subset_count;
            if hits_identity != covers_whole_group {
                return false;
            }
        }

        true
    }

    #[test]
    fn test_multiplicative_generators() {
        assert!(is_binary_field_valid_generator::<BinaryField1b>());
        assert!(is_binary_field_valid_generator::<AESTowerField8b>());
        assert!(is_binary_field_valid_generator::<BinaryField128bGhash>());
    }

    #[test]
    fn test_field_degrees() {
        assert_eq!(<BinaryField1b as BinaryField>::N_BITS, 1);
        assert_eq!(<AESTowerField8b as BinaryField>::N_BITS, 8);
        assert_eq!(<BinaryField128bGhash as BinaryField>::N_BITS, 128);
    }

    /// `Display` output is `0x` followed by zero-padded hex.
    #[test]
    fn test_field_formatting() {
        assert_eq!(BinaryField1b::from(1).to_string(), "0x1");
        assert_eq!(AESTowerField8b::from(3).to_string(), "0x03");
        assert_eq!(
            BinaryField128bGhash::new(5).to_string(),
            "0x00000000000000000000000000000005"
        );
    }

    /// `invert_or_zero` maps zero to zero in every field.
    #[test]
    fn test_inverse_on_zero() {
        assert!(BinaryField1b::ZERO.invert_or_zero().is_zero());
        assert!(AESTowerField8b::ZERO.invert_or_zero().is_zero());
        assert!(BinaryField128bGhash::ZERO.invert_or_zero().is_zero());
    }

    proptest! {
        #[test]
        fn test_inverse_8b(val in 1u8..) {
            let elem = AESTowerField8b::new(val);
            // NOTE(review): `val` is drawn from `1..`, presumably to satisfy a non-zero
            // requirement of `invert` — confirm against its documented contract.
            let inverse = unsafe { elem.invert() };
            assert_eq!(elem * inverse, AESTowerField8b::ONE);
        }

        #[test]
        fn test_inverse_128b(val in 1u128..) {
            let elem = BinaryField128bGhash::from(val);
            // NOTE(review): `val` is drawn from `1..`, presumably to satisfy a non-zero
            // requirement of `invert` — confirm against its documented contract.
            let inverse = unsafe { elem.invert() };
            assert_eq!(elem * inverse, BinaryField128bGhash::ONE);
        }
    }

    /// Values written back-to-back into one buffer are read back in the same order.
    #[test]
    fn test_serialization() {
        let b1 = BinaryField1b::from(0x1);
        let b8 = AESTowerField8b::new(0x12);
        let b128 = BinaryField128bGhash::new(0x147AD0369CF258BE8899AABBCCDDEEFF);

        let mut write_buffer = BytesMut::new();
        b1.serialize(&mut write_buffer).unwrap();
        b8.serialize(&mut write_buffer).unwrap();
        b128.serialize(&mut write_buffer).unwrap();

        let mut read_buffer = write_buffer.freeze();

        let decoded_b1 = BinaryField1b::deserialize(&mut read_buffer).unwrap();
        assert_eq!(decoded_b1, b1);
        let decoded_b8 = AESTowerField8b::deserialize(&mut read_buffer).unwrap();
        assert_eq!(decoded_b8, b8);
        let decoded_b128 = BinaryField128bGhash::deserialize(&mut read_buffer).unwrap();
        assert_eq!(decoded_b128, b128);
    }

    #[test]
    fn test_gf2_new_unchecked() {
        for bit in 0..2 {
            // SAFETY: `bit` is 0 or 1, which satisfies the `val < 2` requirement.
            let unchecked = unsafe { BF1::new_unchecked(bit) };
            assert_eq!(unchecked, BF1::from(bit));
        }
    }
}