1use std::{
5 array,
6 iter::Sum,
7 marker::PhantomData,
8 ops::{Add, AddAssign, Sub, SubAssign},
9};
10
11use bytemuck::TransparentWrapper;
12
13use crate::{
14 BinaryField,
15 arch::PackedPrimitiveType,
16 arithmetic_traits::{InvertOrZero, Square, WideMul},
17 underlier::{Divisible, UnderlierType},
18};
19
/// Field-less marker type.
///
/// NOTE(review): nothing in this section of the file references this type;
/// presumably it names the lane-by-lane ("pairwise") arithmetic strategy that
/// `Divide` implements below — confirm against its users elsewhere in the crate.
pub struct PairwiseStrategy;
22
/// Transparent wrapper around a value `T` that selects lane-wise arithmetic.
///
/// For `T = PackedPrimitiveType<U, F>` the trait impls in this file (`Square`,
/// `InvertOrZero`, `WideMul`) split the `U` underlier into `SubU` lanes through
/// `Divisible<SubU>`, run the operation on each lane as a
/// `PackedPrimitiveType<SubU, F>`, and reassemble the lanes into a `U`.
///
/// * `SubU` — the lane underlier type; only recorded via `PhantomData`, never stored.
/// * `T` — the wrapped value; the sole non-zero-sized field, so `wrap`/`peel`
///   from `TransparentWrapper` convert to and from it.
/// * `N` — the lane count; `WideMul::wide_mul` debug-asserts that it equals
///   `<U as Divisible<SubU>>::N`.
#[repr(transparent)]
#[derive(TransparentWrapper)]
#[transparent(T)]
pub struct Divide<SubU, T, const N: usize>(T, PhantomData<SubU>);
39
40impl<U, SubU, F, const N: usize> Square for Divide<SubU, PackedPrimitiveType<U, F>, N>
41where
42 U: UnderlierType + Divisible<SubU>,
43 SubU: UnderlierType,
44 F: BinaryField,
45 PackedPrimitiveType<SubU, F>: Square,
46{
47 #[inline]
48 fn square(self) -> Self {
49 let val = Self::peel(self);
50 let squared = Divisible::<SubU>::value_iter(val.to_underlier()).map(|lane| {
51 PackedPrimitiveType::<SubU, F>::from_underlier(lane)
52 .square()
53 .to_underlier()
54 });
55 Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(squared)))
56 }
57}
58
59impl<U, SubU, F, const N: usize> InvertOrZero for Divide<SubU, PackedPrimitiveType<U, F>, N>
60where
61 U: UnderlierType + Divisible<SubU>,
62 SubU: UnderlierType,
63 F: BinaryField,
64 PackedPrimitiveType<SubU, F>: InvertOrZero,
65{
66 #[inline]
67 fn invert_or_zero(self) -> Self {
68 let val = Self::peel(self);
69 let inverted = Divisible::<SubU>::value_iter(val.to_underlier()).map(|lane| {
70 PackedPrimitiveType::<SubU, F>::from_underlier(lane)
71 .invert_or_zero()
72 .to_underlier()
73 });
74 Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(inverted)))
75 }
76}
77
/// Fixed-size bundle of `N` per-lane values of type `O`.
///
/// In this file it is the `WideMul::Output` of `Divide`, carrying one wide
/// product per sub-underlier lane. Addition, subtraction and summation all act
/// element-wise on the inner array.
#[derive(Clone, Copy, Debug)]
pub struct LaneWideProduct<O, const N: usize>(pub [O; N]);

impl<O: Copy + Default, const N: usize> Default for LaneWideProduct<O, N> {
    /// Returns a bundle in which every lane holds `O::default()`.
    #[inline]
    fn default() -> Self {
        let lane = O::default();
        Self([lane; N])
    }
}

impl<O: Copy + Add<Output = O>, const N: usize> Add for LaneWideProduct<O, N> {
    type Output = Self;

    /// Adds `rhs` to `self` lane by lane, visiting lanes in index order.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let Self(mut lanes) = self;
        for (lane, other) in lanes.iter_mut().zip(rhs.0) {
            *lane = *lane + other;
        }
        Self(lanes)
    }
}

impl<O: Copy + Add<Output = O>, const N: usize> AddAssign for LaneWideProduct<O, N> {
    /// In-place form of lane-wise addition; delegates to the `Add` impl.
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        let sum = Add::add(*self, rhs);
        *self = sum;
    }
}

impl<O: Copy + Sub<Output = O>, const N: usize> Sub for LaneWideProduct<O, N> {
    type Output = Self;

    /// Subtracts `rhs` from `self` lane by lane, visiting lanes in index order.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let (Self(minuend), Self(subtrahend)) = (self, rhs);
        Self(array::from_fn(|lane| minuend[lane] - subtrahend[lane]))
    }
}

impl<O: Copy + Sub<Output = O>, const N: usize> SubAssign for LaneWideProduct<O, N> {
    /// In-place form of lane-wise subtraction; delegates to the `Sub` impl.
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        let difference = Sub::sub(*self, rhs);
        *self = difference;
    }
}

impl<O: Copy + Default + Add<Output = O>, const N: usize> Sum for LaneWideProduct<O, N> {
    /// Accumulates all items lane-wise, starting from `Self::default()`.
    ///
    /// An empty iterator therefore yields the default bundle.
    #[inline]
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = Self::default();
        for term in iter {
            total = total + term;
        }
        total
    }
}
129
130impl<U, SubU, F, const N: usize> WideMul for Divide<SubU, PackedPrimitiveType<U, F>, N>
131where
132 U: UnderlierType + Divisible<SubU>,
133 SubU: UnderlierType,
134 F: BinaryField,
135 PackedPrimitiveType<SubU, F>: WideMul,
136 <PackedPrimitiveType<SubU, F> as WideMul>::Output: Copy + Default,
137{
138 type Output = LaneWideProduct<<PackedPrimitiveType<SubU, F> as WideMul>::Output, N>;
139
140 #[inline]
141 fn wide_mul(a: Self, b: Self) -> Self::Output {
142 debug_assert_eq!(N, <U as Divisible<SubU>>::N, "N must equal Divisible<SubU>::N");
143
144 let a = Self::peel(a).to_underlier();
145 let b = Self::peel(b).to_underlier();
146
147 let mut lanes = [<PackedPrimitiveType<SubU, F> as WideMul>::Output::default(); N];
148 for (slot, (lhs, rhs)) in lanes
149 .iter_mut()
150 .zip(Divisible::<SubU>::value_iter(a).zip(Divisible::<SubU>::value_iter(b)))
151 {
152 *slot = <PackedPrimitiveType<SubU, F> as WideMul>::wide_mul(
153 PackedPrimitiveType::from_underlier(lhs),
154 PackedPrimitiveType::from_underlier(rhs),
155 );
156 }
157 LaneWideProduct(lanes)
158 }
159
160 #[inline]
161 fn reduce(wide: Self::Output) -> Self {
162 let lanes = wide.0.into_iter().map(|product| {
163 <PackedPrimitiveType<SubU, F> as WideMul>::reduce(product).to_underlier()
164 });
165 Self::wrap(PackedPrimitiveType::from_underlier(Divisible::<SubU>::from_iter(lanes)))
166 }
167}
168
/// Transparent single-field wrapper around a value `T`.
///
/// When `T` implements `WideMul`, the `Mul` impl in this file multiplies two
/// wrapped values by calling `T::wide_mul` and then `T::reduce` on the result.
#[repr(transparent)]
#[derive(TransparentWrapper)]
pub struct MulFromWideMul<T>(T);
175
176impl<P: WideMul> std::ops::Mul for MulFromWideMul<P> {
177 type Output = Self;
178
179 #[inline]
180 fn mul(self, rhs: Self) -> Self {
181 Self::wrap(P::reduce(P::wide_mul(Self::peel(self), Self::peel(rhs))))
182 }
183}