1use std::iter;
4
5use binius_maybe_rayon::prelude::*;
6use binius_utils::checked_arithmetics::checked_int_div;
7
8use crate::{packed::get_packed_slice_unchecked, ExtensionField, Field, PackedField};
9
10pub fn inner_product_unchecked<F, FE>(
12 a: impl IntoIterator<Item = FE>,
13 b: impl IntoIterator<Item = F>,
14) -> FE
15where
16 F: Field,
17 FE: ExtensionField<F>,
18{
19 iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum()
20}
21
22pub fn inner_product_par<FX, PX, PY>(xs: &[PX], ys: &[PY]) -> FX
25where
26 PX: PackedField<Scalar = FX>,
27 PY: PackedField,
28 FX: ExtensionField<PY::Scalar>,
29{
30 assert!(
31 PX::WIDTH * xs.len() <= PY::WIDTH * ys.len(),
32 "Y elements has to be at least as wide as X elements"
33 );
34
35 if PX::WIDTH * xs.len() < PY::WIDTH * ys.len() {
38 return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys));
39 }
40
41 let calc_product_by_ys = |xs: &[PX], ys: &[PY]| {
42 let mut result = FX::ZERO;
43
44 for (j, y) in ys.iter().enumerate() {
45 for (k, y) in y.iter().enumerate() {
46 result += unsafe { get_packed_slice_unchecked(xs, j * PY::WIDTH + k) } * y
47 }
48 }
49
50 result
51 };
52
53 const CHUNK_SIZE: usize = 64;
57 if ys.len() < 16 * CHUNK_SIZE {
58 calc_product_by_ys(xs, ys)
59 } else {
60 ys.par_chunks(CHUNK_SIZE)
62 .enumerate()
63 .map(|(i, ys)| {
64 let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH);
65 calc_product_by_ys(&xs[offset..], ys)
66 })
67 .sum()
68 }
69}
70
71#[inline(always)]
73pub fn eq<F: Field>(x: F, y: F) -> F {
74 if F::CHARACTERISTIC == 2 {
75 x + y + F::ONE
77 } else {
78 x * y + (F::ONE - x) * (F::ONE - y)
79 }
80}
81
82pub fn powers<F: Field>(val: F) -> impl Iterator<Item = F> {
84 iter::successors(Some(F::ONE), move |&power| Some(power * val))
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PackedBinaryField4x32b;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Builds a packed element whose lane `k` holds `f(base + k)`, so that every lane (not
    /// just lane 0) carries a distinct value.
    fn packed_with_lanes(base: usize, f: impl Fn(usize) -> F) -> P {
        let mut packed = P::set_single(f(base));
        for k in 1..P::WIDTH {
            packed.set(k, f(base + k));
        }
        packed
    }

    #[test]
    fn test_inner_product_par_equal_length() {
        let xs1 = F::new(1);
        let xs2 = F::new(2);
        let xs = vec![P::set_single(xs1), P::set_single(xs2)];
        let ys1 = F::new(3);
        let ys2 = F::new(4);
        let ys = vec![P::set_single(ys1), P::set_single(ys2)];

        let result = inner_product_par::<F, P, P>(&xs, &ys);
        let expected = xs1 * ys1 + xs2 * ys2;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_inner_product_par_unequal_length() {
        let xs1 = F::new(1);
        let xs = vec![P::set_single(xs1)];
        let ys1 = F::new(2);
        let ys2 = F::new(3);
        let ys = vec![P::set_single(ys1), P::set_single(ys2)];

        let result = inner_product_par::<F, P, P>(&xs, &ys);
        let expected = xs1 * ys1;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_inner_product_par_large_input_single_threaded() {
        let size = 256;
        let xs: Vec<P> = (0..size).map(|i| P::set_single(F::new(i as u32))).collect();
        let ys: Vec<P> = (0..size)
            .map(|i| P::set_single(F::new((i + 1) as u32)))
            .collect();

        let result = inner_product_par::<F, P, P>(&xs, &ys);

        let expected = (0..size)
            .map(|i| F::new(i as u32) * F::new((i + 1) as u32))
            .sum::<F>();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_inner_product_par_large_input_par() {
        let size = 2000;
        let xs: Vec<P> = (0..size).map(|i| P::set_single(F::new(i as u32))).collect();
        let ys: Vec<P> = (0..size)
            .map(|i| P::set_single(F::new((i + 1) as u32)))
            .collect();

        let result = inner_product_par::<F, P, P>(&xs, &ys);

        let expected = (0..size)
            .map(|i| F::new(i as u32) * F::new((i + 1) as u32))
            .sum::<F>();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_inner_product_par_all_lanes_par() {
        // 1100 packed elements: above the parallel threshold (16 * 64 = 1024) and not a
        // multiple of the 64-element chunk size, so the last parallel chunk is partial.
        let size = 1100;
        let x_of = |i: usize| F::new(i as u32);
        let y_of = |i: usize| F::new((3 * i + 1) as u32);

        let xs: Vec<P> = (0..size)
            .map(|j| packed_with_lanes(j * P::WIDTH, x_of))
            .collect();
        let ys: Vec<P> = (0..size)
            .map(|j| packed_with_lanes(j * P::WIDTH, y_of))
            .collect();

        let result = inner_product_par::<F, P, P>(&xs, &ys);

        let expected = (0..size * P::WIDTH)
            .map(|i| x_of(i) * y_of(i))
            .sum::<F>();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_inner_product_par_all_lanes_unequal_length() {
        // `ys` has more scalars than `xs`; the surplus `ys` scalars must be ignored.
        let x_of = |i: usize| F::new((i + 5) as u32);
        let y_of = |i: usize| F::new((2 * i + 7) as u32);

        let xs: Vec<P> = (0..3).map(|j| packed_with_lanes(j * P::WIDTH, x_of)).collect();
        let ys: Vec<P> = (0..5).map(|j| packed_with_lanes(j * P::WIDTH, y_of)).collect();

        let result = inner_product_par::<F, P, P>(&xs, &ys);

        let expected = (0..3 * P::WIDTH)
            .map(|i| x_of(i) * y_of(i))
            .sum::<F>();

        assert_eq!(result, expected);
    }

    #[test]
    #[should_panic(expected = "Y elements has to be at least as wide as X elements")]
    fn test_inner_product_par_xs_wider_than_ys_panics() {
        let xs = vec![P::set_single(F::new(1)); 2];
        let ys = vec![P::set_single(F::new(1))];

        let _ = inner_product_par::<F, P, P>(&xs, &ys);
    }

    #[test]
    fn test_inner_product_par_empty() {
        let xs: Vec<P> = vec![];
        let ys: Vec<P> = vec![];

        let result = inner_product_par::<F, P, P>(&xs, &ys);
        let expected = F::ZERO;

        assert_eq!(result, expected);
    }
}