1use std::iter;
4
5use binius_maybe_rayon::prelude::*;
6use binius_utils::checked_arithmetics::checked_int_div;
7
8use crate::{ExtensionField, Field, PackedField, packed::get_packed_slice_unchecked};
9
10pub fn inner_product_unchecked<F, FE>(
12 a: impl IntoIterator<Item = FE>,
13 b: impl IntoIterator<Item = F>,
14) -> FE
15where
16 F: Field,
17 FE: ExtensionField<F>,
18{
19 iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum()
20}
21
22pub fn inner_product_par<FX, PX, PY>(xs: &[PX], ys: &[PY]) -> FX
25where
26 PX: PackedField<Scalar = FX>,
27 PY: PackedField,
28 FX: ExtensionField<PY::Scalar>,
29{
30 assert!(
31 PX::WIDTH * xs.len() <= PY::WIDTH * ys.len(),
32 "Y elements has to be at least as wide as X elements"
33 );
34
35 if PX::WIDTH * xs.len() < PY::WIDTH * ys.len() {
38 return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys));
39 }
40
41 let calc_product_by_ys = |xs: &[PX], ys: &[PY]| {
42 let mut result = FX::ZERO;
43
44 for (j, y) in ys.iter().enumerate() {
45 for (k, y) in y.iter().enumerate() {
46 result += unsafe { get_packed_slice_unchecked(xs, j * PY::WIDTH + k) } * y
47 }
48 }
49
50 result
51 };
52
53 const CHUNK_SIZE: usize = 64;
57 if ys.len() < 16 * CHUNK_SIZE {
58 calc_product_by_ys(xs, ys)
59 } else {
60 ys.par_chunks(CHUNK_SIZE)
63 .enumerate()
64 .map(|(i, ys)| {
65 let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH);
66 calc_product_by_ys(&xs[offset..], ys)
67 })
68 .sum()
69 }
70}
71
72#[inline(always)]
74pub fn eq<F: Field>(x: F, y: F) -> F {
75 if F::CHARACTERISTIC == 2 {
76 x + y + F::ONE
78 } else {
79 x * y + (F::ONE - x) * (F::ONE - y)
80 }
81}
82
83pub fn powers<F: Field>(val: F) -> impl Iterator<Item = F> {
85 iter::successors(Some(F::ONE), move |&power| Some(power * val))
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PackedBinaryField4x32b;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Builds one packed element per scalar via `P::set_single`.
    fn singles(values: impl IntoIterator<Item = F>) -> Vec<P> {
        values.into_iter().map(P::set_single).collect()
    }

    /// Checks `inner_product_par` on xs = [0, 1, .., size-1] and ys = [1, 2, .., size].
    ///
    /// Each value occupies its own packed element.
    fn check_consecutive_products(size: u32) {
        let xs = singles((0..size).map(F::new));
        let ys = singles((1..=size).map(F::new));

        let expected: F = (0..size).map(|i| F::new(i) * F::new(i + 1)).sum();

        assert_eq!(inner_product_par::<F, P, P>(&xs, &ys), expected);
    }

    #[test]
    fn test_inner_product_par_equal_length() {
        let (x0, x1) = (F::new(1), F::new(2));
        let (y0, y1) = (F::new(3), F::new(4));

        let xs = singles([x0, x1]);
        let ys = singles([y0, y1]);

        assert_eq!(inner_product_par::<F, P, P>(&xs, &ys), x0 * y0 + x1 * y1);
    }

    #[test]
    fn test_inner_product_par_unequal_length() {
        let x0 = F::new(1);
        let (y0, y1) = (F::new(2), F::new(3));

        let xs = singles([x0]);
        let ys = singles([y0, y1]);

        // The second packed `ys` element lies past the last `xs` scalar, so it is not paired.
        assert_eq!(inner_product_par::<F, P, P>(&xs, &ys), x0 * y0);
    }

    #[test]
    fn test_inner_product_par_large_input_single_threaded() {
        // Below the parallel threshold: exercises the sequential path.
        check_consecutive_products(256);
    }

    #[test]
    fn test_inner_product_par_large_input_par() {
        // Above the parallel threshold: exercises the chunked parallel path.
        check_consecutive_products(2000);
    }

    #[test]
    fn test_inner_product_par_empty() {
        let empty: Vec<P> = Vec::new();

        assert_eq!(inner_product_par::<F, P, P>(&empty, &empty), F::ZERO);
    }
}