binius_field/
packed_extension_ops.rs1use binius_utils::rayon::prelude::{
4 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) {
10 ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
11}
12
13pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
14 lhs: &mut [PE],
15 rhs: &[PE::PackedSubfield],
16) {
17 ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
18 PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
19 })
20}
21
22pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
26 packed_subfields: &[PE::PackedSubfield],
27 i: usize,
28) -> PE::PackedSubfield {
29 let bottom_most_scalar_idx = i * PE::WIDTH;
30 let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
31 let bottom_most_scalar_idx_within_packed_subfield =
32 bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
33 let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
34
35 unsafe {
36 packed_subfields
37 .get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
38 .spread_unchecked(PE::LOG_WIDTH, block_idx)
39 }
40}
41
42pub fn ext_base_op<PE, F, Func>(lhs: &mut [PE], rhs: &[PE::PackedSubfield], op: Func)
58where
59 PE: PackedExtension<F>,
60 F: Field,
61 Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
62{
63 assert!(
64 lhs.len() == rhs.len() * PE::Scalar::DEGREE,
65 "lhs.len() ({}) must equal rhs.len() * PE::Scalar::DEGREE ({})",
66 lhs.len(),
67 rhs.len() * PE::Scalar::DEGREE
68 );
69
70 lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
71 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
74
75 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
76 });
77}
78
79pub fn ext_base_op_par<PE, F, Func>(lhs: &mut [PE], rhs: &[PE::PackedSubfield], op: Func)
86where
87 PE: PackedExtension<F>,
88 F: Field,
89 Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
90{
91 assert!(
92 lhs.len() == rhs.len() * PE::Scalar::DEGREE,
93 "lhs.len() ({}) must equal rhs.len() * PE::Scalar::DEGREE ({})",
94 lhs.len(),
95 rhs.len() * PE::Scalar::DEGREE
96 );
97
98 lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
99 let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
102
103 *lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
104 });
105}