use binius_maybe_rayon::prelude::{
    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};

use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};

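/// Multiplies every packed extension field element of `lhs` in place by the corresponding
/// base field (subfield) scalars packed in `rhs`.
///
/// The operation is element-wise over scalars: the `j`-th extension scalar of `lhs` is
/// multiplied by the `j`-th subfield scalar of `rhs`. Returns [`Error::MismatchedLengths`]
/// unless `lhs.len() == rhs.len() * PE::Scalar::DEGREE`, i.e. both slices pack the same
/// number of scalars.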
pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
    lhs: &mut [PE],
    rhs: &[PE::PackedSubfield],
) -> Result<(), Error> {
    ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| {
        PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
    })
}

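/// Parallel version of [`ext_base_mul`]: the same element-wise multiplication, but iterating
/// over `lhs` with `par_iter_mut` from the `binius_maybe_rayon` prelude.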
pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
    lhs: &mut [PE],
    rhs: &[PE::PackedSubfield],
) -> Result<(), Error> {
    ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
        PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
    })
}

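/// Returns the broadcast packed subfield element that aligns with the `i`-th packed extension
/// element of a parallel `&[PE]` slice.
///
/// Base field scalars correspond 1:1 with extension scalars, so the result holds the
/// `PE::WIDTH` subfield scalars at scalar indices `i * PE::WIDTH..(i + 1) * PE::WIDTH`, each
/// repeated `PE::Scalar::DEGREE` times so that it lines up component-wise with
/// `PE::cast_base` of the `i`-th extension element.
///
/// For example, with `PE = PackedBinaryField2x128b` over `F = BinaryField8b`
/// (`PE::WIDTH = 2`, `PE::Scalar::DEGREE = 16`, `PE::PackedSubfield::WIDTH = 32`) and
/// `i = 5`, the bottom-most scalar index is `10`; that falls in subfield element `0` at
/// offset `10`, so block `5` (subfield scalars `10` and `11`) is spread, each scalar
/// repeated 16 times.
///
/// # Safety
/// The caller must guarantee that scalar index `i * PE::WIDTH` is in bounds, i.e. that
/// `i * PE::WIDTH < packed_subfields.len() * PE::PackedSubfield::WIDTH`; no bounds check is
/// performed.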
pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
    packed_subfields: &[PE::PackedSubfield],
    i: usize,
) -> PE::PackedSubfield {
    // Base field scalars correspond 1:1 with extension scalars, so the first scalar covered
    // by the `i`-th packed extension element has index `i * PE::WIDTH`.
    let bottom_most_scalar_idx = i * PE::WIDTH;
    let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
    let bottom_most_scalar_idx_within_packed_subfield =
        bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
    // Block of `PE::WIDTH` subfield scalars that will be spread across the result.
    let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;

    unsafe {
        packed_subfields
            .get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
            .spread_unchecked(PE::LOG_WIDTH, block_idx)
    }
}

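/// Applies `op` to every packed extension element of `lhs` in place.
///
/// `op` receives the index of the element within `lhs`, the element itself, and the subfield
/// scalars of `rhs` broadcast to align with that element (see
/// [`get_packed_subfields_at_pe_idx`]). Returns [`Error::MismatchedLengths`] unless
/// `lhs.len() == rhs.len() * PE::Scalar::DEGREE`, i.e. both slices pack the same number of
/// scalars.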
pub fn ext_base_op<PE, F, Func>(
    lhs: &mut [PE],
    rhs: &[PE::PackedSubfield],
    op: Func,
) -> Result<(), Error>
where
    PE: PackedExtension<F>,
    F: Field,
    Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
{
    if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
        return Err(Error::MismatchedLengths);
    }

    lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
        // SAFETY: the length check above guarantees that every scalar index covered by `lhs`
        // is also covered by `rhs`.
        let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };

        *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
    });
    Ok(())
}

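/// Parallel version of [`ext_base_op`]: applies `op` to every packed extension element of
/// `lhs` in place, iterating with `par_iter_mut` from the `binius_maybe_rayon` prelude.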
pub fn ext_base_op_par<PE, F, Func>(
    lhs: &mut [PE],
    rhs: &[PE::PackedSubfield],
    op: Func,
) -> Result<(), Error>
where
    PE: PackedExtension<F>,
    F: Field,
    Func: Fn(usize, PE, PE::PackedSubfield) -> PE + Sync,
{
    if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
        return Err(Error::MismatchedLengths);
    }

    lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
        // SAFETY: the length check above guarantees that every scalar index covered by `lhs`
        // is also covered by `rhs`.
        let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };

        *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
    });

    Ok(())
}

#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use crate::{
        BinaryField8b, BinaryField16b, BinaryField128b, PackedBinaryField2x128b,
        PackedBinaryField16x16b, PackedBinaryField32x8b, ext_base_mul, ext_base_mul_par,
        packed::{get_packed_slice, pack_slice},
        underlier::WithUnderlier,
    };

    fn strategy_8b_scalars() -> impl Strategy<Value = [BinaryField8b; 32]> {
        any::<[<BinaryField8b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|arr| arr.map(<BinaryField8b>::from_underlier))
    }

    fn strategy_16b_scalars() -> impl Strategy<Value = [BinaryField16b; 32]> {
        any::<[<BinaryField16b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|arr| arr.map(<BinaryField16b>::from_underlier))
    }

    fn strategy_128b_scalars() -> impl Strategy<Value = [BinaryField128b; 32]> {
        any::<[<BinaryField128b as WithUnderlier>::Underlier; 32]>()
            .prop_map(|arr| arr.map(<BinaryField128b>::from_underlier))
    }

    proptest! {
        #[test]
        fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()) {
            let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
            let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

            ext_base_mul(&mut ext_packed, &base_packed).unwrap();

            for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
                assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
            }
        }

        #[test]
        fn test_base_ext_mul_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()) {
            let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
            let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

            ext_base_mul(&mut ext_packed, &base_packed).unwrap();

            for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
                assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
            }
        }

        #[test]
        fn test_base_ext_mul_par_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()) {
            let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
            let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

            ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

            for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
                assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
            }
        }

        #[test]
        fn test_base_ext_mul_par_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()) {
            let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
            let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

            ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

            for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
                assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
            }
        }
    }
}