binius_field/
packed_extension_ops.rs1use binius_utils::rayon::prelude::{
4 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
10 lhs: &mut [PE],
11 rhs: &[PE::PackedSubfield],
12) -> Result<(), Error> {
13 ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
14}
15
16pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
17 lhs: &mut [PE],
18 rhs: &[PE::PackedSubfield],
19) -> Result<(), Error> {
20 ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
21 PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
22 })
23}
24
25pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
29 packed_subfields: &[PE::PackedSubfield],
30 i: usize,
31) -> PE::PackedSubfield {
32 let bottom_most_scalar_idx = i * PE::WIDTH;
33 let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
34 let bottom_most_scalar_idx_within_packed_subfield =
35 bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
36 let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
37
38 unsafe {
39 packed_subfields
40 .get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
41 .spread_unchecked(PE::LOG_WIDTH, block_idx)
42 }
43}
44
45pub fn ext_base_op<PE, F, Func>(
57 lhs: &mut [PE],
58 rhs: &[PE::PackedSubfield],
59 op: Func,
60) -> Result<(), Error>
61where
62 PE: PackedExtension<F>,
63 F: Field,
64 Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
65{
66 if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
67 return Err(Error::MismatchedLengths);
68 }
69
70 lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
71 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
74
75 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
76 });
77 Ok(())
78}
79
80pub fn ext_base_op_par<PE, F, Func>(
83 lhs: &mut [PE],
84 rhs: &[PE::PackedSubfield],
85 op: Func,
86) -> Result<(), Error>
87where
88 PE: PackedExtension<F>,
89 F: Field,
90 Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
91{
92 if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
93 return Err(Error::MismatchedLengths);
94 }
95
96 lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
97 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
100
101 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
102 });
103
104 Ok(())
105}