// binius_math/inner_product.rs

// Copyright 2024-2025 Irreducible Inc.

3use std::{iter, ops::Deref};
4
5use binius_field::{ExtensionField, Field, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::FieldBuffer;
9
10#[inline]
11pub fn inner_product<F: Field>(
12	a: impl IntoIterator<Item = F>,
13	b: impl IntoIterator<Item = F>,
14) -> F {
15	inner_product_subfield(a, b)
16}
17
18#[inline]
19pub fn inner_product_subfield<F, FSub>(
20	a: impl IntoIterator<Item = FSub>,
21	b: impl IntoIterator<Item = F>,
22) -> F
23where
24	F: Field + ExtensionField<FSub>,
25	FSub: Field,
26{
27	itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
28}
29
30#[inline]
31pub fn inner_product_par<F, P, DataA, DataB>(
32	a: &FieldBuffer<P, DataA>,
33	b: &FieldBuffer<P, DataB>,
34) -> F
35where
36	F: Field,
37	P: PackedField<Scalar = F>,
38	DataA: Deref<Target = [P]>,
39	DataB: Deref<Target = [P]>,
40{
41	let n = a.len();
42	a.as_ref()
43		.par_iter()
44		.zip_eq(b.as_ref().par_iter())
45		.map(|(&a_i, &b_i)| a_i * b_i)
46		.sum::<P>()
47		.into_iter()
48		.take(n)
49		.sum()
50}
51
52#[inline]
53pub fn inner_product_buffers<F, P, DataA, DataB>(
54	a: &FieldBuffer<P, DataA>,
55	b: &FieldBuffer<P, DataB>,
56) -> F
57where
58	F: Field,
59	P: PackedField<Scalar = F>,
60	DataA: Deref<Target = [P]>,
61	DataB: Deref<Target = [P]>,
62{
63	let log_n = a.log_len();
64	inner_product_packed(log_n, a.as_ref().iter().copied(), b.as_ref().iter().copied())
65}
66
67/// Compute the inner product of two scalar sequences generated by iterators of packed elements.
68///
69/// ## Preconditions
70///
71/// * `a` and `b` have length `1 << log_n.saturating_sub(P::LOG_WIDTH)`
72#[inline]
73pub fn inner_product_packed<F, P>(
74	log_n: usize,
75	a: impl ExactSizeIterator<Item = P>,
76	b: impl ExactSizeIterator<Item = P>,
77) -> F
78where
79	F: Field,
80	P: PackedField<Scalar = F>,
81{
82	assert_eq!(a.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); // pre-condition
83	assert_eq!(b.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); // pre-condition
84
85	iter::zip(a, b)
86		.map(|(a_i, b_i)| a_i * b_i)
87		.sum::<P>()
88		.into_iter()
89		.take(1 << log_n)
90		.sum()
91}
92
#[cfg(test)]
mod tests {
	use binius_field::{PackedBinaryGhash4x128b, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	#[test]
	fn test_inner_product_packing_width_greater_than_buffer_length() {
		type P = PackedBinaryGhash4x128b;

		let mut rng = StdRng::seed_from_u64(0);

		// Create buffers with log_len = 0 (1 element), but packing width = 4
		let packed_a = P::random(&mut rng);
		let packed_b = P::random(&mut rng);

		let buffer_a = FieldBuffer::new(0, vec![packed_a]).unwrap();
		let buffer_b = FieldBuffer::new(0, vec![packed_b]).unwrap();

		// Compute inner product using both functions
		let result_par = inner_product_par(&buffer_a, &buffer_b);
		let result_buffers = inner_product_buffers(&buffer_a, &buffer_b);

		// Compute expected result manually - only first element should be used
		let expected = buffer_a.get(0).unwrap() * buffer_b.get(0).unwrap();

		assert_eq!(result_par, expected, "inner_product_par failed for log_len=0");
		assert_eq!(result_buffers, expected, "inner_product_buffers failed for log_len=0");
	}

	#[test]
	fn test_inner_product_multiple_packed_elements() {
		type P = PackedBinaryGhash4x128b;

		let mut rng = StdRng::seed_from_u64(1);

		// A buffer longer than the packing width spans several packed elements, so every lane of
		// every packed element must contribute to the result.
		let log_len: usize = 3;
		let n_scalars: usize = 1 << log_len;
		let n_packed: usize = 1 << log_len.saturating_sub(P::LOG_WIDTH);

		let values_a = (0..n_packed).map(|_| P::random(&mut rng)).collect::<Vec<_>>();
		let values_b = (0..n_packed).map(|_| P::random(&mut rng)).collect::<Vec<_>>();

		let buffer_a = FieldBuffer::new(log_len, values_a).unwrap();
		let buffer_b = FieldBuffer::new(log_len, values_b).unwrap();

		// Reference value computed scalar by scalar.
		let expected: <P as PackedField>::Scalar = (0..n_scalars)
			.map(|i| buffer_a.get(i).unwrap() * buffer_b.get(i).unwrap())
			.sum();

		assert_eq!(
			inner_product_par(&buffer_a, &buffer_b),
			expected,
			"inner_product_par failed for log_len=3"
		);
		assert_eq!(
			inner_product_buffers(&buffer_a, &buffer_b),
			expected,
			"inner_product_buffers failed for log_len=3"
		);
	}
}