use std::iter;
use binius_utils::checked_arithmetics::checked_int_div;
use rayon::prelude::*;
use crate::{packed::get_packed_slice_unchecked, ExtensionField, Field, PackedField};
pub fn inner_product_unchecked<F, FE>(
a: impl IntoIterator<Item = FE>,
b: impl IntoIterator<Item = F>,
) -> FE
where
F: Field,
FE: ExtensionField<F>,
{
iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum::<FE>()
}
/// Computes the inner product `Σ x_i * y_i` over the scalars of two packed slices.
/// Large inputs are split across rayon worker threads.
///
/// `xs` is packed over `FX`, and `ys` is packed over a field that `FX` extends. The number
/// of `xs` scalars (`PX::WIDTH * xs.len()`) must not exceed the number of `ys` scalars.
/// When it is strictly smaller, the computation runs sequentially and the surplus `ys`
/// scalars are ignored.
///
/// # Panics
///
/// Panics if `xs` holds more scalars than `ys`.
pub fn inner_product_par<FX, PX, PY>(xs: &[PX], ys: &[PY]) -> FX
where
    PX: PackedField<Scalar = FX>,
    PY: PackedField,
    FX: ExtensionField<PY::Scalar>,
{
    // Greatest common divisor, used to size the parallel chunks below.
    const fn gcd(mut a: usize, mut b: usize) -> usize {
        while b != 0 {
            let r = a % b;
            a = b;
            b = r;
        }
        a
    }

    assert!(
        PX::WIDTH * xs.len() <= PY::WIDTH * ys.len(),
        "Y elements has to be at least as wide as X elements"
    );

    // With fewer x scalars than y scalars, fall back to the sequential version; its zip
    // stops at the end of `xs`, so the surplus `ys` scalars do not contribute.
    if PX::WIDTH * xs.len() < PY::WIDTH * ys.len() {
        return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys));
    }

    // Inner product of the scalars of `ys` against the scalars of `xs[x_offset..]`.
    // `x_offset` is measured in PX elements.
    let calc_product_by_ys = |x_offset: usize, ys: &[PY]| {
        let mut result = FX::ZERO;
        let xs = &xs[x_offset..];
        for (j, y) in ys.iter().enumerate() {
            for (k, y) in y.iter().enumerate() {
                // SAFETY: this point is reached only when `xs` and `ys` hold the same number
                // of scalars. Each caller passes an `x_offset` whose scalar count equals the
                // number of `ys` scalars preceding this chunk. Therefore every index
                // `j * PY::WIDTH + k` of the chunk lies within the remaining `xs` scalars.
                result += unsafe { get_packed_slice_unchecked(xs, j * PY::WIDTH + k) } * y
            }
        }
        result
    };

    // Inputs with fewer than `16 * CHUNK_SIZE` PY elements are computed on the current
    // thread. Larger inputs are split into chunks of (at least) CHUNK_SIZE PY elements.
    const CHUNK_SIZE: usize = 64;
    if ys.len() < 16 * CHUNK_SIZE {
        calc_product_by_ys(0, ys)
    } else {
        // Every chunk must start on a PX element boundary, i.e. `chunk_size * PY::WIDTH`
        // must be a multiple of `PX::WIDTH`. The smallest chunk length with that property
        // is `PX::WIDTH / gcd(PX::WIDTH, PY::WIDTH)`. Rounding CHUNK_SIZE up to the least
        // common multiple with it keeps the division below exact. The result equals
        // CHUNK_SIZE whenever `CHUNK_SIZE * PY::WIDTH` is already divisible by `PX::WIDTH`.
        // Previously the non-divisible case (e.g. PX::WIDTH > 64 * PY::WIDTH) panicked.
        let min_aligned_chunk = PX::WIDTH / gcd(PX::WIDTH, PY::WIDTH);
        let chunk_size = CHUNK_SIZE / gcd(CHUNK_SIZE, min_aligned_chunk) * min_aligned_chunk;
        // Number of PX elements covered by one chunk of `chunk_size` PY elements.
        let x_stride = checked_int_div(chunk_size * PY::WIDTH, PX::WIDTH);
        ys.par_chunks(chunk_size)
            .enumerate()
            .map(|(i, ys)| calc_product_by_ys(i * x_stride, ys))
            .sum()
    }
}
/// Evaluates `x*y + (1 - x)*(1 - y)`.
///
/// For inputs restricted to zero and one, the result is one when `x == y` and zero
/// otherwise.
#[inline(always)]
pub fn eq<F: Field>(x: F, y: F) -> F {
    let one = F::ONE;
    let both_set = x * y;
    let both_clear = (one - x) * (one - y);
    both_set + both_clear
}
/// Returns the unbounded sequence `1, val, val^2, val^3, ...`.
///
/// Each item is the previous one multiplied by `val`.
pub fn powers<F: Field>(val: F) -> impl Iterator<Item = F> {
    let step = move |prev: &F| Some(*prev * val);
    iter::successors(Some(F::ONE), step)
}