binius_math/
inner_product.rs1use std::{iter, ops::Deref};
4
5use binius_field::{ExtensionField, Field, FieldOps, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::FieldBuffer;
9
10#[inline]
11pub fn inner_product<F: Field>(
12 a: impl IntoIterator<Item = F>,
13 b: impl IntoIterator<Item = F>,
14) -> F {
15 inner_product_subfield(a, b)
16}
17
18#[inline]
19pub fn inner_product_scalars<F: FieldOps>(
20 a: impl IntoIterator<Item = F>,
21 b: impl IntoIterator<Item = F>,
22) -> F {
23 itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
24}
25
26#[inline]
27pub fn inner_product_subfield<F, FSub>(
28 a: impl IntoIterator<Item = FSub>,
29 b: impl IntoIterator<Item = F>,
30) -> F
31where
32 F: Field + ExtensionField<FSub>,
33 FSub: Field,
34{
35 itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
36}
37
38#[inline]
39pub fn inner_product_par<F, P, DataA, DataB>(
40 a: &FieldBuffer<P, DataA>,
41 b: &FieldBuffer<P, DataB>,
42) -> F
43where
44 F: Field,
45 P: PackedField<Scalar = F>,
46 DataA: Deref<Target = [P]>,
47 DataB: Deref<Target = [P]>,
48{
49 let n = a.len();
50 a.as_ref()
51 .par_iter()
52 .zip_eq(b.as_ref().par_iter())
53 .map(|(&a_i, &b_i)| a_i * b_i)
54 .sum::<P>()
55 .into_iter()
56 .take(n)
57 .sum()
58}
59
60#[inline]
61pub fn inner_product_buffers<F, P, DataA, DataB>(
62 a: &FieldBuffer<P, DataA>,
63 b: &FieldBuffer<P, DataB>,
64) -> F
65where
66 F: Field,
67 P: PackedField<Scalar = F>,
68 DataA: Deref<Target = [P]>,
69 DataB: Deref<Target = [P]>,
70{
71 let log_n = a.log_len();
72 inner_product_packed(log_n, a.as_ref().iter().copied(), b.as_ref().iter().copied())
73}
74
75#[inline]
81pub fn inner_product_packed<F, P>(
82 log_n: usize,
83 a: impl ExactSizeIterator<Item = P>,
84 b: impl ExactSizeIterator<Item = P>,
85) -> F
86where
87 F: Field,
88 P: PackedField<Scalar = F>,
89{
90 assert_eq!(a.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); assert_eq!(b.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); iter::zip(a, b)
94 .map(|(a_i, b_i)| a_i * b_i)
95 .sum::<P>()
96 .into_iter()
97 .take(1 << log_n)
98 .sum()
99}
100
#[cfg(test)]
mod tests {
    use binius_field::{PackedBinaryGhash4x128b, Random};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    type P = PackedBinaryGhash4x128b;

    /// A buffer of a single scalar stored in a width-4 packed element: only lane 0 is in
    /// range, and the other (random) lanes are padding that must not leak into the result.
    #[test]
    fn test_inner_product_packing_width_greater_than_buffer_length() {
        let mut rng = StdRng::seed_from_u64(0);

        let packed_a = P::random(&mut rng);
        let packed_b = P::random(&mut rng);

        let buffer_a = FieldBuffer::new(0, vec![packed_a]);
        let buffer_b = FieldBuffer::new(0, vec![packed_b]);

        let result_par = inner_product_par(&buffer_a, &buffer_b);
        let result_buffers = inner_product_buffers(&buffer_a, &buffer_b);

        let expected = buffer_a.get(0) * buffer_b.get(0);

        assert_eq!(result_par, expected, "inner_product_par failed for log_len=0");
        assert_eq!(result_buffers, expected, "inner_product_buffers failed for log_len=0");
    }

    /// Buffers spanning several packed elements must agree with the scalar inner product.
    #[test]
    fn test_inner_product_multiple_packed_elements() {
        // 2^3 = 8 scalars occupy two packed elements of width 4.
        const LOG_LEN: usize = 3;
        let n_packed = 1 << (LOG_LEN - P::LOG_WIDTH);

        let mut rng = StdRng::seed_from_u64(1);
        let packed_a = (0..n_packed).map(|_| P::random(&mut rng)).collect::<Vec<_>>();
        let packed_b = (0..n_packed).map(|_| P::random(&mut rng)).collect::<Vec<_>>();

        let buffer_a = FieldBuffer::new(LOG_LEN, packed_a);
        let buffer_b = FieldBuffer::new(LOG_LEN, packed_b);

        let expected = inner_product(
            (0..1 << LOG_LEN).map(|i| buffer_a.get(i)),
            (0..1 << LOG_LEN).map(|i| buffer_b.get(i)),
        );

        assert_eq!(
            inner_product_par(&buffer_a, &buffer_b),
            expected,
            "inner_product_par failed for log_len=3"
        );
        assert_eq!(
            inner_product_buffers(&buffer_a, &buffer_b),
            expected,
            "inner_product_buffers failed for log_len=3"
        );
    }
}
131}