binius_field/arch/portable/
scaled_arithmetic.rs1use std::array;
4
5use bytemuck::{Pod, TransparentWrapper};
6
7use super::packed::PackedPrimitiveType;
8use crate::{
9 BinaryField,
10 arch::LaneWideProduct,
11 arithmetic_traits::{InvertOrZero, Square, WideMul},
12 underlier::{ScaledUnderlier, UnderlierType},
13};
14
15#[repr(transparent)]
17#[derive(TransparentWrapper)]
18pub struct Scaled<T>(T);
19
20impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> std::ops::Mul
21 for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
22where
23 PackedPrimitiveType<U, Scalar>: std::ops::Mul<Output = PackedPrimitiveType<U, Scalar>>,
24{
25 type Output = Self;
26
27 fn mul(self, rhs: Self) -> Self {
28 let (a, b) = (Self::peel(self), Self::peel(rhs));
29 Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(array::from_fn(|i| {
30 let lhs_i = a.0.0[i];
31 let rhs_i = b.0.0[i];
32 PackedPrimitiveType::peel(
33 PackedPrimitiveType::wrap(lhs_i) * PackedPrimitiveType::wrap(rhs_i),
34 )
35 }))))
36 }
37}
38
39impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> Square
40 for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
41where
42 PackedPrimitiveType<U, Scalar>: Square,
43{
44 fn square(self) -> Self {
45 let val = Self::peel(self);
46 Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(val.0.0.map(|sub_underlier| {
47 PackedPrimitiveType::peel(Square::square(PackedPrimitiveType::wrap(sub_underlier)))
48 }))))
49 }
50}
51
52impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> InvertOrZero
53 for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
54where
55 PackedPrimitiveType<U, Scalar>: InvertOrZero,
56{
57 fn invert_or_zero(self) -> Self {
58 let val = Self::peel(self);
59 Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(val.0.0.map(|sub_underlier| {
60 PackedPrimitiveType::peel(InvertOrZero::invert_or_zero(PackedPrimitiveType::wrap(
61 sub_underlier,
62 )))
63 }))))
64 }
65}
66
67impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> WideMul
72 for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
73where
74 PackedPrimitiveType<U, Scalar>: WideMul,
75 <PackedPrimitiveType<U, Scalar> as WideMul>::Output: Copy + Default,
76{
77 type Output = LaneWideProduct<<PackedPrimitiveType<U, Scalar> as WideMul>::Output, N>;
78
79 #[inline]
80 fn wide_mul(a: Self, b: Self) -> Self::Output {
81 let (a, b) = (Self::peel(a), Self::peel(b));
82 LaneWideProduct(array::from_fn(|i| {
83 <PackedPrimitiveType<U, Scalar> as WideMul>::wide_mul(
84 PackedPrimitiveType::wrap(a.0.0[i]),
85 PackedPrimitiveType::wrap(b.0.0[i]),
86 )
87 }))
88 }
89
90 #[inline]
91 fn reduce(wide: Self::Output) -> Self {
92 Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(array::from_fn(|i| {
93 PackedPrimitiveType::peel(<PackedPrimitiveType<U, Scalar> as WideMul>::reduce(
94 wide.0[i],
95 ))
96 }))))
97 }
98}
99
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;
    use crate::{aes_field::AESTowerField8b, arch::M128};

    /// Packed `AESTowerField8b` values stored in two `M128` sub-underliers.
    type Inner = PackedPrimitiveType<ScaledUnderlier<M128, 2>, AESTowerField8b>;
    /// The wrapper under test.
    type P = Scaled<Inner>;

    /// Builds a `P` whose two sub-underliers are `lo` and `hi`, in that order.
    fn packing(lo: u128, hi: u128) -> P {
        let halves = [lo, hi].map(|half| M128::from_u128(half));
        P::wrap(Inner::from_underlier(ScaledUnderlier(halves)))
    }

    /// Multiplies two `(lo, hi)` packings through the plain `Mul` impl and
    /// unwraps the result.
    fn mul_peeled(lhs: (u128, u128), rhs: (u128, u128)) -> Inner {
        P::peel(packing(lhs.0, lhs.1) * packing(rhs.0, rhs.1))
    }

    proptest! {
        // `reduce(wide_mul(a, b))` must agree with `a * b`.
        #[test]
        fn wide_mul_reduce_matches_mul(
            a_lo in any::<u128>(), a_hi in any::<u128>(),
            b_lo in any::<u128>(), b_hi in any::<u128>(),
        ) {
            let wide = P::wide_mul(packing(a_lo, a_hi), packing(b_lo, b_hi));
            let via_wide = P::peel(P::reduce(wide));
            let via_mul = mul_peeled((a_lo, a_hi), (b_lo, b_hi));
            prop_assert_eq!(via_wide, via_mul);
        }

        // Summing two wide products before a single `reduce` must agree with
        // summing the two ordinary products.
        #[test]
        fn wide_mul_accumulates(
            a_lo in any::<u128>(), a_hi in any::<u128>(),
            b_lo in any::<u128>(), b_hi in any::<u128>(),
            c_lo in any::<u128>(), c_hi in any::<u128>(),
            d_lo in any::<u128>(), d_hi in any::<u128>(),
        ) {
            let wide_ab = P::wide_mul(packing(a_lo, a_hi), packing(b_lo, b_hi));
            let wide_cd = P::wide_mul(packing(c_lo, c_hi), packing(d_lo, d_hi));
            let via_wide = P::peel(P::reduce(wide_ab + wide_cd));

            let via_mul = mul_peeled((a_lo, a_hi), (b_lo, b_hi))
                + mul_peeled((c_lo, c_hi), (d_lo, d_hi));
            prop_assert_eq!(via_wide, via_mul);
        }
    }
}