1use std::iter;
4
5use binius_utils::{checked_arithmetics::checked_int_div, rayon::prelude::*};
6
7use crate::{ExtensionField, Field, PackedField, packed::get_packed_slice_unchecked};
8
9#[inline]
11pub fn inner_product_unchecked<F, FE>(
12 a: impl IntoIterator<Item = FE>,
13 b: impl IntoIterator<Item = F>,
14) -> FE
15where
16 F: Field,
17 FE: ExtensionField<F>,
18{
19 iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum()
20}
21
22pub fn inner_product_par<FX, PX, PY>(xs: &[PX], ys: &[PY]) -> FX
25where
26 PX: PackedField<Scalar = FX>,
27 PY: PackedField,
28 FX: ExtensionField<PY::Scalar>,
29{
30 assert!(
31 PX::WIDTH * xs.len() <= PY::WIDTH * ys.len(),
32 "Y elements has to be at least as wide as X elements"
33 );
34
35 if PX::WIDTH * xs.len() < PY::WIDTH * ys.len() {
38 return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys));
39 }
40
41 let calc_product_by_ys = |xs: &[PX], ys: &[PY]| {
42 let mut result = FX::ZERO;
43
44 for (j, y) in ys.iter().enumerate() {
45 for (k, y) in y.iter().enumerate() {
46 result += unsafe { get_packed_slice_unchecked(xs, j * PY::WIDTH + k) } * y
47 }
48 }
49
50 result
51 };
52
53 const CHUNK_SIZE: usize = 64;
57 if ys.len() < 16 * CHUNK_SIZE {
58 calc_product_by_ys(xs, ys)
59 } else {
60 ys.par_chunks(CHUNK_SIZE)
63 .enumerate()
64 .map(|(i, ys)| {
65 let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH);
66 calc_product_by_ys(&xs[offset..], ys)
67 })
68 .sum()
69 }
70}
71
72pub fn powers<F: Field>(val: F) -> impl Iterator<Item = F> {
74 iter::successors(Some(F::ONE), move |&power| Some(power * val))
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BinaryField128bGhash, PackedBinaryGhash4x128b};

    type P = PackedBinaryGhash4x128b;
    type F = BinaryField128bGhash;

    /// Builds one packed element per scalar using `P::set_single`.
    fn pack_each(scalars: &[F]) -> Vec<P> {
        scalars.iter().map(|&s| P::set_single(s)).collect()
    }

    /// Runs `inner_product_par` on `size` packed elements built from the scalars
    /// `0..size` (for `xs`) and `1..=size` (for `ys`) and compares the result with
    /// the scalar-by-scalar sum of products.
    fn check_consecutive_inputs(size: usize) {
        let xs: Vec<P> = (0..size)
            .map(|i| P::set_single(F::new(i as u128)))
            .collect();
        let ys: Vec<P> = (0..size)
            .map(|i| P::set_single(F::new((i + 1) as u128)))
            .collect();

        let expected: F = (0..size)
            .map(|i| F::new(i as u128) * F::new((i + 1) as u128))
            .sum();

        assert_eq!(inner_product_par::<F, P, P>(&xs, &ys), expected);
    }

    #[test]
    fn test_inner_product_par_equal_length() {
        let (xs1, xs2) = (F::new(1), F::new(2));
        let (ys1, ys2) = (F::new(3), F::new(4));
        let xs = pack_each(&[xs1, xs2]);
        let ys = pack_each(&[ys1, ys2]);

        let actual = inner_product_par::<F, P, P>(&xs, &ys);

        assert_eq!(actual, xs1 * ys1 + xs2 * ys2);
    }

    #[test]
    fn test_inner_product_par_unequal_length() {
        let xs1 = F::new(1);
        let (ys1, ys2) = (F::new(2), F::new(3));
        // `xs` has one packed element, `ys` has two.
        let xs = pack_each(&[xs1]);
        let ys = pack_each(&[ys1, ys2]);

        let actual = inner_product_par::<F, P, P>(&xs, &ys);

        // Only the first pair contributes to the expected value.
        assert_eq!(actual, xs1 * ys1);
    }

    #[test]
    fn test_inner_product_par_large_input_single_threaded() {
        // 256 packed elements is below the `16 * CHUNK_SIZE` (1024) parallel threshold.
        check_consecutive_inputs(256);
    }

    #[test]
    fn test_inner_product_par_large_input_par() {
        // 2000 packed elements is above the `16 * CHUNK_SIZE` (1024) parallel threshold.
        check_consecutive_inputs(2000);
    }

    #[test]
    fn test_inner_product_par_empty() {
        let xs: Vec<P> = Vec::new();
        let ys: Vec<P> = Vec::new();

        let actual = inner_product_par::<F, P, P>(&xs, &ys);

        assert_eq!(actual, F::ZERO);
    }
}