1use binius_maybe_rayon::prelude::{
4 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
10 lhs: &mut [PE],
11 rhs: &[PE::PackedSubfield],
12) -> Result<(), Error> {
13 ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
14}
15
16pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
17 lhs: &mut [PE],
18 rhs: &[PE::PackedSubfield],
19) -> Result<(), Error> {
20 ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
21 PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
22 })
23}
24
25pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
29 packed_subfields: &[PE::PackedSubfield],
30 i: usize,
31) -> PE::PackedSubfield {
32 let bottom_most_scalar_idx = i * PE::WIDTH;
33 let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
34 let bottom_most_scalar_idx_within_packed_subfield =
35 bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
36 let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
37
38 packed_subfields
39 .get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
40 .spread_unchecked(PE::LOG_WIDTH, block_idx)
41}
42
43pub fn ext_base_op<PE, F, Func>(
55 lhs: &mut [PE],
56 rhs: &[PE::PackedSubfield],
57 op: Func,
58) -> Result<(), Error>
59where
60 PE: PackedExtension<F>,
61 F: Field,
62 Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
63{
64 if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
65 return Err(Error::MismatchedLengths);
66 }
67
68 lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
69 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
71
72 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
73 });
74 Ok(())
75}
76
77pub fn ext_base_op_par<PE, F, Func>(
80 lhs: &mut [PE],
81 rhs: &[PE::PackedSubfield],
82 op: Func,
83) -> Result<(), Error>
84where
85 PE: PackedExtension<F>,
86 F: Field,
87 Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
88{
89 if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
90 return Err(Error::MismatchedLengths);
91 }
92
93 lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
94 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
96
97 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
98 });
99
100 Ok(())
101}
102
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use crate::{
        ext_base_mul, ext_base_mul_par,
        packed::{get_packed_slice, pack_slice},
        underlier::WithUnderlier,
        BinaryField128b, BinaryField16b, BinaryField8b, PackedBinaryField16x16b,
        PackedBinaryField2x128b, PackedBinaryField32x8b,
    };

    /// Strategy yielding 32 arbitrary `BinaryField8b` scalars built from random underliers.
    fn strategy_8b_scalars() -> impl Strategy<Value = [BinaryField8b; 32]> {
        any::<[<BinaryField8b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|raw| raw.map(BinaryField8b::from_underlier))
    }

    /// Strategy yielding 32 arbitrary `BinaryField16b` scalars built from random underliers.
    fn strategy_16b_scalars() -> impl Strategy<Value = [BinaryField16b; 32]> {
        any::<[<BinaryField16b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|raw| raw.map(BinaryField16b::from_underlier))
    }

    /// Strategy yielding 32 arbitrary `BinaryField128b` scalars built from random underliers.
    fn strategy_128b_scalars() -> impl Strategy<Value = [BinaryField128b; 32]> {
        any::<[<BinaryField128b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|raw| raw.map(BinaryField128b::from_underlier))
    }

    /// Asserts that scalar `i` of `product` equals `ext[i] * base[i]` for every
    /// index `i` covered by both scalar slices.
    fn assert_scalarwise_products<B>(
        base: &[B],
        ext: &[BinaryField128b],
        product: &[PackedBinaryField2x128b],
    ) where
        B: Copy,
        BinaryField128b: std::ops::Mul<B>,
        <BinaryField128b as std::ops::Mul<B>>::Output:
            PartialEq<BinaryField128b> + std::fmt::Debug,
    {
        for (idx, (&b, &e)) in base.iter().zip(ext).enumerate() {
            assert_eq!(e * b, get_packed_slice(product, idx));
        }
    }

    proptest! {
        #[test]
        fn test_base_ext_mul_8(base in strategy_8b_scalars(), ext in strategy_128b_scalars()) {
            let packed_base = pack_slice::<PackedBinaryField32x8b>(&base);
            let mut packed_ext = pack_slice::<PackedBinaryField2x128b>(&ext);

            ext_base_mul(&mut packed_ext, &packed_base).unwrap();

            assert_scalarwise_products(&base, &ext, &packed_ext);
        }

        #[test]
        fn test_base_ext_mul_16(base in strategy_16b_scalars(), ext in strategy_128b_scalars()) {
            let packed_base = pack_slice::<PackedBinaryField16x16b>(&base);
            let mut packed_ext = pack_slice::<PackedBinaryField2x128b>(&ext);

            ext_base_mul(&mut packed_ext, &packed_base).unwrap();

            assert_scalarwise_products(&base, &ext, &packed_ext);
        }

        #[test]
        fn test_base_ext_mul_par_8(base in strategy_8b_scalars(), ext in strategy_128b_scalars()) {
            let packed_base = pack_slice::<PackedBinaryField32x8b>(&base);
            let mut packed_ext = pack_slice::<PackedBinaryField2x128b>(&ext);

            ext_base_mul_par(&mut packed_ext, &packed_base).unwrap();

            assert_scalarwise_products(&base, &ext, &packed_ext);
        }

        #[test]
        fn test_base_ext_mul_par_16(base in strategy_16b_scalars(), ext in strategy_128b_scalars()) {
            let packed_base = pack_slice::<PackedBinaryField16x16b>(&base);
            let mut packed_ext = pack_slice::<PackedBinaryField2x128b>(&ext);

            ext_base_mul_par(&mut packed_ext, &packed_base).unwrap();

            assert_scalarwise_products(&base, &ext, &packed_ext);
        }
    }
}