binius_math/
inner_product.rs1use std::{iter, ops::Deref};
4
5use binius_field::{ExtensionField, Field, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::FieldBuffer;
9
10#[inline]
11pub fn inner_product<F: Field>(
12 a: impl IntoIterator<Item = F>,
13 b: impl IntoIterator<Item = F>,
14) -> F {
15 inner_product_subfield(a, b)
16}
17
18#[inline]
19pub fn inner_product_subfield<F, FSub>(
20 a: impl IntoIterator<Item = FSub>,
21 b: impl IntoIterator<Item = F>,
22) -> F
23where
24 F: Field + ExtensionField<FSub>,
25 FSub: Field,
26{
27 itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
28}
29
30#[inline]
31pub fn inner_product_par<F, P, DataA, DataB>(
32 a: &FieldBuffer<P, DataA>,
33 b: &FieldBuffer<P, DataB>,
34) -> F
35where
36 F: Field,
37 P: PackedField<Scalar = F>,
38 DataA: Deref<Target = [P]>,
39 DataB: Deref<Target = [P]>,
40{
41 let n = a.len();
42 a.as_ref()
43 .par_iter()
44 .zip_eq(b.as_ref().par_iter())
45 .map(|(&a_i, &b_i)| a_i * b_i)
46 .sum::<P>()
47 .into_iter()
48 .take(n)
49 .sum()
50}
51
52#[inline]
53pub fn inner_product_buffers<F, P, DataA, DataB>(
54 a: &FieldBuffer<P, DataA>,
55 b: &FieldBuffer<P, DataB>,
56) -> F
57where
58 F: Field,
59 P: PackedField<Scalar = F>,
60 DataA: Deref<Target = [P]>,
61 DataB: Deref<Target = [P]>,
62{
63 let log_n = a.log_len();
64 inner_product_packed(log_n, a.as_ref().iter().copied(), b.as_ref().iter().copied())
65}
66
67#[inline]
73pub fn inner_product_packed<F, P>(
74 log_n: usize,
75 a: impl ExactSizeIterator<Item = P>,
76 b: impl ExactSizeIterator<Item = P>,
77) -> F
78where
79 F: Field,
80 P: PackedField<Scalar = F>,
81{
82 assert_eq!(a.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); assert_eq!(b.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); iter::zip(a, b)
86 .map(|(a_i, b_i)| a_i * b_i)
87 .sum::<P>()
88 .into_iter()
89 .take(1 << log_n)
90 .sum()
91}
92
#[cfg(test)]
mod tests {
    use binius_field::{PackedBinaryGhash4x128b, Random};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Builds two buffers with `log_len = 0` from a single random packed
    /// element each and checks that both the parallel and the sequential
    /// inner product equal the product of the buffers' element at index 0.
    #[test]
    fn test_inner_product_packing_width_greater_than_buffer_length() {
        type P = PackedBinaryGhash4x128b;

        // Fixed seed keeps the test deterministic; `lhs` is drawn before `rhs`.
        let mut rng = StdRng::seed_from_u64(0);
        let lhs = FieldBuffer::new(0, vec![P::random(&mut rng)]).unwrap();
        let rhs = FieldBuffer::new(0, vec![P::random(&mut rng)]).unwrap();

        let expected = lhs.get(0).unwrap() * rhs.get(0).unwrap();

        assert_eq!(
            inner_product_par(&lhs, &rhs),
            expected,
            "inner_product_par failed for log_len=0"
        );
        assert_eq!(
            inner_product_buffers(&lhs, &rhs),
            expected,
            "inner_product_packed failed for log_len=0"
        );
    }
}