Skip to main content

binius_field/arch/portable/
scaled_arithmetic.rs

1// Copyright 2025-2026 The Binius Developers
2
3use std::array;
4
5use bytemuck::{Pod, TransparentWrapper};
6
7use super::packed::PackedPrimitiveType;
8use crate::{
9	BinaryField,
10	arch::LaneWideProduct,
11	arithmetic_traits::{InvertOrZero, Square, WideMul},
12	underlier::{ScaledUnderlier, UnderlierType},
13};
14
15/// Wrapper for `ScaledUnderlier` multiplication that delegates to sub-underlier operations.
16#[repr(transparent)]
17#[derive(TransparentWrapper)]
18pub struct Scaled<T>(T);
19
20impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> std::ops::Mul
21	for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
22where
23	PackedPrimitiveType<U, Scalar>: std::ops::Mul<Output = PackedPrimitiveType<U, Scalar>>,
24{
25	type Output = Self;
26
27	fn mul(self, rhs: Self) -> Self {
28		let (a, b) = (Self::peel(self), Self::peel(rhs));
29		Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(array::from_fn(|i| {
30			let lhs_i = a.0.0[i];
31			let rhs_i = b.0.0[i];
32			PackedPrimitiveType::peel(
33				PackedPrimitiveType::wrap(lhs_i) * PackedPrimitiveType::wrap(rhs_i),
34			)
35		}))))
36	}
37}
38
39impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> Square
40	for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
41where
42	PackedPrimitiveType<U, Scalar>: Square,
43{
44	fn square(self) -> Self {
45		let val = Self::peel(self);
46		Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(val.0.0.map(|sub_underlier| {
47			PackedPrimitiveType::peel(Square::square(PackedPrimitiveType::wrap(sub_underlier)))
48		}))))
49	}
50}
51
52impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> InvertOrZero
53	for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
54where
55	PackedPrimitiveType<U, Scalar>: InvertOrZero,
56{
57	fn invert_or_zero(self) -> Self {
58		let val = Self::peel(self);
59		Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(val.0.0.map(|sub_underlier| {
60			PackedPrimitiveType::peel(InvertOrZero::invert_or_zero(PackedPrimitiveType::wrap(
61				sub_underlier,
62			)))
63		}))))
64	}
65}
66
67/// Widening multiply for a `ScaledUnderlier` packing: apply the sub-underlier packing's [`WideMul`]
68/// to each of the `N` lanes independently, deferring reduction per lane via [`LaneWideProduct`].
69/// The `Scaled` analogue of [`Divide`](crate::arch::Divide)'s `WideMul`, but addressing the inner
70/// sub-underliers of `ScaledUnderlier` directly instead of splitting an underlier with `Divisible`.
71impl<U: UnderlierType + Pod, Scalar: BinaryField, const N: usize> WideMul
72	for Scaled<PackedPrimitiveType<ScaledUnderlier<U, N>, Scalar>>
73where
74	PackedPrimitiveType<U, Scalar>: WideMul,
75	<PackedPrimitiveType<U, Scalar> as WideMul>::Output: Copy + Default,
76{
77	type Output = LaneWideProduct<<PackedPrimitiveType<U, Scalar> as WideMul>::Output, N>;
78
79	#[inline]
80	fn wide_mul(a: Self, b: Self) -> Self::Output {
81		let (a, b) = (Self::peel(a), Self::peel(b));
82		LaneWideProduct(array::from_fn(|i| {
83			<PackedPrimitiveType<U, Scalar> as WideMul>::wide_mul(
84				PackedPrimitiveType::wrap(a.0.0[i]),
85				PackedPrimitiveType::wrap(b.0.0[i]),
86			)
87		}))
88	}
89
90	#[inline]
91	fn reduce(wide: Self::Output) -> Self {
92		Self::wrap(PackedPrimitiveType::wrap(ScaledUnderlier(array::from_fn(|i| {
93			PackedPrimitiveType::peel(<PackedPrimitiveType<U, Scalar> as WideMul>::reduce(
94				wide.0[i],
95			))
96		}))))
97	}
98}
99
#[cfg(test)]
mod tests {
	use proptest::prelude::*;

	use super::*;
	use crate::{aes_field::AESTowerField8b, arch::M128};

	// A two-lane `ScaledUnderlier` AES packing whose `M128` lanes carry their own `WideMul`.
	type Inner = PackedPrimitiveType<ScaledUnderlier<M128, 2>, AESTowerField8b>;
	type P = Scaled<Inner>;

	/// Builds a `P` whose lane 0 holds `lo` and lane 1 holds `hi`.
	fn packing(lo: u128, hi: u128) -> P {
		P::wrap(Inner::from_underlier(ScaledUnderlier([M128::from_u128(lo), M128::from_u128(hi)])))
	}

	proptest! {
		// `reduce(wide_mul(a, b))` must agree with the `Scaled` multiply.
		#[test]
		fn wide_mul_reduce_matches_mul(
			a_lo in any::<u128>(), a_hi in any::<u128>(),
			b_lo in any::<u128>(), b_hi in any::<u128>(),
		) {
			let (a, b) = (packing(a_lo, a_hi), packing(b_lo, b_hi));
			let via_wide = P::peel(P::reduce(P::wide_mul(a, b)));
			let via_mul = P::peel(packing(a_lo, a_hi) * packing(b_lo, b_hi));
			prop_assert_eq!(via_wide, via_mul);
		}

		// Deferred per-lane accumulation: summing two wide products then reducing once must equal
		// the sum of the two reduced products.
		#[test]
		fn wide_mul_accumulates(
			a_lo in any::<u128>(), a_hi in any::<u128>(),
			b_lo in any::<u128>(), b_hi in any::<u128>(),
			c_lo in any::<u128>(), c_hi in any::<u128>(),
			d_lo in any::<u128>(), d_hi in any::<u128>(),
		) {
			let acc = P::wide_mul(packing(a_lo, a_hi), packing(b_lo, b_hi))
				+ P::wide_mul(packing(c_lo, c_hi), packing(d_lo, d_hi));
			let via_wide = P::peel(P::reduce(acc));
			let via_mul = P::peel(packing(a_lo, a_hi) * packing(b_lo, b_hi))
				+ P::peel(packing(c_lo, c_hi) * packing(d_lo, d_hi));
			prop_assert_eq!(via_wide, via_mul);
		}

		// The lane-wise `Square` must agree with the lane-wise multiply: `square(a) == a * a`.
		#[test]
		fn square_matches_mul(a_lo in any::<u128>(), a_hi in any::<u128>()) {
			let via_square = P::peel(Square::square(packing(a_lo, a_hi)));
			let via_mul = P::peel(packing(a_lo, a_hi) * packing(a_lo, a_hi));
			prop_assert_eq!(via_square, via_mul);
		}

		// `invert_or_zero` is a per-scalar pseudo-inverse. `a * inv(a) * a == a` holds for zero
		// scalars (0 maps to 0) and for nonzero scalars (`a * a^-1 == 1`) alike, so this checks
		// every lane without special-casing zeros.
		#[test]
		fn invert_or_zero_is_pseudo_inverse(a_lo in any::<u128>(), a_hi in any::<u128>()) {
			let inv = InvertOrZero::invert_or_zero(packing(a_lo, a_hi));
			let round_trip = P::peel(packing(a_lo, a_hi) * inv * packing(a_lo, a_hi));
			prop_assert_eq!(round_trip, P::peel(packing(a_lo, a_hi)));
		}
	}
}