Skip to main content

binius_field/arch/
strategies.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{
5	array,
6	iter::Sum,
7	marker::PhantomData,
8	ops::{Add, AddAssign, Sub, SubAssign},
9};
10
11use bytemuck::TransparentWrapper;
12
13use crate::{
14	BinaryField,
15	arch::PackedPrimitiveType,
16	arithmetic_traits::{InvertOrZero, Square, WideMul},
17	underlier::{Divisible, UnderlierType},
18};
19
/// Pairwise strategy: apply the operation to each packed element independently.
///
/// This is a data-less marker type; it carries no state of its own.
pub struct PairwiseStrategy;
22
/// Strategy that splits the underlier into `SubU`-sized lanes, applies the sub-packing
/// `PackedPrimitiveType<SubU, F>`'s op to each lane, and recombines the lanes. It is a generic
/// fallback for packings that lack a specialized full-width [`Square`], [`InvertOrZero`], or
/// [`WideMul`]. The sub-underlier `SubU` is carried as a `PhantomData` parameter so that the
/// packing type `T` can stay last, matching the macro's `Divide<SubU, $name, N>` form.
///
/// `N` is the lane count: callers always pass `N = <U as Divisible<SubU>>::N` (or the literal it
/// works out to). `Square`/`InvertOrZero` stream through [`Divisible`] and ignore `N`, but it is
/// still required so that every `Divide` instantiation names its lane count explicitly. `WideMul`
/// must defer reduction, so it materializes one unreduced product per lane in an `N`-element
/// [`LaneWideProduct`]. An associated const can't be used as an array length without
/// `generic_const_exprs`, which is why `N` is a const generic rather than read from `Divisible`.
///
/// The struct is `#[repr(transparent)]` over `T` and derives [`TransparentWrapper`] for `T`.
/// The operation impls use this to move between `T` and `Divide` via `peel`/`wrap`.
#[repr(transparent)]
#[derive(TransparentWrapper)]
#[transparent(T)]
pub struct Divide<SubU, T, const N: usize>(T, PhantomData<SubU>);
39
40impl<U, SubU, F, const N: usize> Square for Divide<SubU, PackedPrimitiveType<U, F>, N>
41where
42	U: UnderlierType + Divisible<SubU>,
43	SubU: UnderlierType,
44	F: BinaryField,
45	PackedPrimitiveType<SubU, F>: Square,
46{
47	#[inline]
48	fn square(self) -> Self {
49		let val = Self::peel(self);
50		let squared = Divisible::<SubU>::value_iter(val.to_underlier()).map(|lane| {
51			PackedPrimitiveType::<SubU, F>::from_underlier(lane)
52				.square()
53				.to_underlier()
54		});
55		Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(squared)))
56	}
57}
58
59impl<U, SubU, F, const N: usize> InvertOrZero for Divide<SubU, PackedPrimitiveType<U, F>, N>
60where
61	U: UnderlierType + Divisible<SubU>,
62	SubU: UnderlierType,
63	F: BinaryField,
64	PackedPrimitiveType<SubU, F>: InvertOrZero,
65{
66	#[inline]
67	fn invert_or_zero(self) -> Self {
68		let val = Self::peel(self);
69		let inverted = Divisible::<SubU>::value_iter(val.to_underlier()).map(|lane| {
70			PackedPrimitiveType::<SubU, F>::from_underlier(lane)
71				.invert_or_zero()
72				.to_underlier()
73		});
74		Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(inverted)))
75	}
76}
77
/// One independent deferred wide product per `SubU` lane of a [`Divide`] widening multiply.
///
/// Lanes accumulate (`Add`/`Sub`/`Sum`) and reduce independently, mirroring the packing
/// structure, so a sum of products is reduced only once per lane. `N` is the lane count.
///
/// The public array holds the per-lane unreduced values, with slot `i` holding lane `i`. Every
/// arithmetic impl on this type operates slot by slot and never mixes lanes.
#[derive(Clone, Copy, Debug)]
pub struct LaneWideProduct<O, const N: usize>(pub [O; N]);
83
84impl<O: Copy + Default, const N: usize> Default for LaneWideProduct<O, N> {
85	#[inline]
86	fn default() -> Self {
87		Self([O::default(); N])
88	}
89}
90
91impl<O: Copy + Add<Output = O>, const N: usize> Add for LaneWideProduct<O, N> {
92	type Output = Self;
93
94	#[inline]
95	fn add(self, rhs: Self) -> Self {
96		Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
97	}
98}
99
100impl<O: Copy + Add<Output = O>, const N: usize> AddAssign for LaneWideProduct<O, N> {
101	#[inline]
102	fn add_assign(&mut self, rhs: Self) {
103		*self = *self + rhs;
104	}
105}
106
107impl<O: Copy + Sub<Output = O>, const N: usize> Sub for LaneWideProduct<O, N> {
108	type Output = Self;
109
110	#[inline]
111	fn sub(self, rhs: Self) -> Self {
112		Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
113	}
114}
115
116impl<O: Copy + Sub<Output = O>, const N: usize> SubAssign for LaneWideProduct<O, N> {
117	#[inline]
118	fn sub_assign(&mut self, rhs: Self) {
119		*self = *self - rhs;
120	}
121}
122
123impl<O: Copy + Default + Add<Output = O>, const N: usize> Sum for LaneWideProduct<O, N> {
124	#[inline]
125	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
126		iter.fold(Self::default(), |acc, x| acc + x)
127	}
128}
129
130impl<U, SubU, F, const N: usize> WideMul for Divide<SubU, PackedPrimitiveType<U, F>, N>
131where
132	U: UnderlierType + Divisible<SubU>,
133	SubU: UnderlierType,
134	F: BinaryField,
135	PackedPrimitiveType<SubU, F>: WideMul,
136	<PackedPrimitiveType<SubU, F> as WideMul>::Output: Copy + Default,
137{
138	type Output = LaneWideProduct<<PackedPrimitiveType<SubU, F> as WideMul>::Output, N>;
139
140	#[inline]
141	fn wide_mul(a: Self, b: Self) -> Self::Output {
142		debug_assert_eq!(N, <U as Divisible<SubU>>::N, "N must equal Divisible<SubU>::N");
143
144		let a = Self::peel(a).to_underlier();
145		let b = Self::peel(b).to_underlier();
146
147		let mut lanes = [<PackedPrimitiveType<SubU, F> as WideMul>::Output::default(); N];
148		for (slot, (lhs, rhs)) in lanes
149			.iter_mut()
150			.zip(Divisible::<SubU>::value_iter(a).zip(Divisible::<SubU>::value_iter(b)))
151		{
152			*slot = <PackedPrimitiveType<SubU, F> as WideMul>::wide_mul(
153				PackedPrimitiveType::from_underlier(lhs),
154				PackedPrimitiveType::from_underlier(rhs),
155			);
156		}
157		LaneWideProduct(lanes)
158	}
159
160	#[inline]
161	fn reduce(wide: Self::Output) -> Self {
162		let lanes = wide.0.into_iter().map(|product| {
163			<PackedPrimitiveType<SubU, F> as WideMul>::reduce(product).to_underlier()
164		});
165		Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(lanes)))
166	}
167}
168
/// Wrapper that defines multiplication as `reduce(wide_mul(a, b))`, deferring to the wrapped
/// type's own [`WideMul`] impl.
///
/// This makes the widening multiply the single source of truth for both `Mul` and `WideMul`.
/// Used by every GHASH and AES packing.
///
/// `#[repr(transparent)]` over `T`, deriving [`TransparentWrapper`] so the `Mul` impl can move
/// between `T` and the wrapper via `peel`/`wrap`.
#[repr(transparent)]
#[derive(TransparentWrapper)]
pub struct MulFromWideMul<T>(T);
175
176impl<P: WideMul> std::ops::Mul for MulFromWideMul<P> {
177	type Output = Self;
178
179	#[inline]
180	fn mul(self, rhs: Self) -> Self {
181		Self::wrap(P::reduce(P::wide_mul(Self::peel(self), Self::peel(rhs))))
182	}
183}