Skip to main content

binius_field/arch/x86_64/simd/
simd_arithmetic.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{BinaryField, aes_field::AESTowerField8b, underlier::UnderlierWithBitOps};
6
7#[allow(dead_code)]
8pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
9	/// Blend odd and even elements
10	fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
11	/// Set alpha to even elements
12	fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
13	/// Apply `mask` to `a` (set zeros at positions where high bit of the `mask` is 0).
14	fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
15
16	/// Bit xor operation
17	fn xor(a: Self, b: Self) -> Self;
18
19	/// Shuffle 8-bit elements within 128-bit lanes
20	fn shuffle_epi8(a: Self, b: Self) -> Self;
21
22	/// Byte shifts within 128-bit lanes
23	fn bslli_epi128<const IMM8: i32>(self) -> Self;
24	fn bsrli_epi128<const IMM8: i32>(self) -> Self;
25
26	/// Initialize value with a single element
27	fn set1_epi128(val: __m128i) -> Self;
28	fn set_epi_64(val: i64) -> Self;
29
30	#[inline(always)]
31	fn dup_shuffle<Scalar: BinaryField>() -> Self {
32		let shuffle_mask_128 = unsafe {
33			match Scalar::N_BITS.ilog2() {
34				3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
35				4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
36				5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
37				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
38				_ => panic!("unsupported bit count"),
39			}
40		};
41
42		Self::set1_epi128(shuffle_mask_128)
43	}
44
45	#[inline(always)]
46	fn flip_shuffle<Scalar: BinaryField>() -> Self {
47		let flip_mask_128 = unsafe {
48			match Scalar::N_BITS.ilog2() {
49				3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
50				4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
51				5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
52				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
53				_ => panic!("unsupported bit count"),
54			}
55		};
56
57		Self::set1_epi128(flip_mask_128)
58	}
59
60	/// Creates mask to propagate the highest bit form mask to other element bytes
61	#[inline(always)]
62	fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
63		let epi8_mask_128 = unsafe {
64			match Scalar::N_BITS.ilog2() {
65				4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
66				5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
67				6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
68				7 => _mm_set1_epi8(15),
69				_ => panic!("unsupported bit count"),
70			}
71		};
72
73		Self::set1_epi128(epi8_mask_128)
74	}
75
76	#[inline(always)]
77	fn alpha<Scalar: BinaryField>() -> Self {
78		let alpha_128 = {
79			match Scalar::N_BITS.ilog2() {
80				3 => {
81					// Compiler will optimize this if out for each instantiation
82					let type_id = TypeId::of::<Scalar>();
83					let value = if type_id == TypeId::of::<AESTowerField8b>() {
84						0xd3u8 as i8
85					} else {
86						panic!("tower field not supported")
87					};
88					unsafe { _mm_set1_epi8(value) }
89				}
90				4 => unsafe { _mm_set1_epi16(0x0100) },
91				5 => unsafe { _mm_set1_epi32(0x00010000) },
92				6 => unsafe { _mm_set1_epi64x(0x0000000100000000) },
93				_ => panic!("unsupported bit count"),
94			}
95		};
96
97		Self::set1_epi128(alpha_128)
98	}
99
100	#[inline(always)]
101	fn even_mask<Scalar: BinaryField>() -> Self {
102		let mask_128 = {
103			match Scalar::N_BITS.ilog2() {
104				3 => unsafe { _mm_set1_epi16(0x00FF) },
105				4 => unsafe { _mm_set1_epi32(0x0000FFFF) },
106				5 => unsafe { _mm_set1_epi64x(0x00000000FFFFFFFF) },
107				6 => unsafe { _mm_set_epi64x(0, -1) },
108				_ => panic!("unsupported bit count"),
109			}
110		};
111
112		Self::set1_epi128(mask_128)
113	}
114}