Skip to main content

binius_math/
inner_product.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, ops::Deref};
4
5use binius_field::{ExtensionField, Field, FieldOps, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::FieldBuffer;
9
10#[inline]
11pub fn inner_product<F: Field>(
12	a: impl IntoIterator<Item = F>,
13	b: impl IntoIterator<Item = F>,
14) -> F {
15	inner_product_subfield(a, b)
16}
17
18#[inline]
19pub fn inner_product_scalars<F: FieldOps>(
20	a: impl IntoIterator<Item = F>,
21	b: impl IntoIterator<Item = F>,
22) -> F {
23	itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
24}
25
26#[inline]
27pub fn inner_product_subfield<F, FSub>(
28	a: impl IntoIterator<Item = FSub>,
29	b: impl IntoIterator<Item = F>,
30) -> F
31where
32	F: Field + ExtensionField<FSub>,
33	FSub: Field,
34{
35	itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
36}
37
38#[inline]
39pub fn inner_product_par<F, P, DataA, DataB>(
40	a: &FieldBuffer<P, DataA>,
41	b: &FieldBuffer<P, DataB>,
42) -> F
43where
44	F: Field,
45	P: PackedField<Scalar = F>,
46	DataA: Deref<Target = [P]>,
47	DataB: Deref<Target = [P]>,
48{
49	let n = a.len();
50	a.as_ref()
51		.par_iter()
52		.zip_eq(b.as_ref().par_iter())
53		.map(|(&a_i, &b_i)| a_i * b_i)
54		.sum::<P>()
55		.into_iter()
56		.take(n)
57		.sum()
58}
59
60#[inline]
61pub fn inner_product_buffers<F, P, DataA, DataB>(
62	a: &FieldBuffer<P, DataA>,
63	b: &FieldBuffer<P, DataB>,
64) -> F
65where
66	F: Field,
67	P: PackedField<Scalar = F>,
68	DataA: Deref<Target = [P]>,
69	DataB: Deref<Target = [P]>,
70{
71	let log_n = a.log_len();
72	inner_product_packed(log_n, a.as_ref().iter().copied(), b.as_ref().iter().copied())
73}
74
75/// Compute the inner product of two scalar sequences generated by iterators of packed elements.
76///
77/// ## Preconditions
78///
79/// * `a` and `b` have length `1 << log_n.saturating_sub(P::LOG_WIDTH)`
80#[inline]
81pub fn inner_product_packed<F, P>(
82	log_n: usize,
83	a: impl ExactSizeIterator<Item = P>,
84	b: impl ExactSizeIterator<Item = P>,
85) -> F
86where
87	F: Field,
88	P: PackedField<Scalar = F>,
89{
90	assert_eq!(a.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); // pre-condition
91	assert_eq!(b.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH)); // pre-condition
92
93	iter::zip(a, b)
94		.map(|(a_i, b_i)| a_i * b_i)
95		.sum::<P>()
96		.into_iter()
97		.take(1 << log_n)
98		.sum()
99}
100
#[cfg(test)]
mod tests {
	use binius_field::{PackedBinaryGhash4x128b, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	#[test]
	fn test_inner_product_packing_width_greater_than_buffer_length() {
		type P = PackedBinaryGhash4x128b;

		let mut rng = StdRng::seed_from_u64(0);

		// Create buffers with log_len = 0 (1 element), but packing width = 4
		let packed_a = P::random(&mut rng);
		let packed_b = P::random(&mut rng);

		let buffer_a = FieldBuffer::new(0, vec![packed_a]);
		let buffer_b = FieldBuffer::new(0, vec![packed_b]);

		// Compute inner product using both buffer-level functions
		let result_par = inner_product_par(&buffer_a, &buffer_b);
		let result_buffers = inner_product_buffers(&buffer_a, &buffer_b);

		// Compute expected result manually - only first element should be used
		let expected = buffer_a.get(0) * buffer_b.get(0);

		assert_eq!(result_par, expected, "inner_product_par failed for log_len=0");
		assert_eq!(result_buffers, expected, "inner_product_buffers failed for log_len=0");
	}

	#[test]
	fn test_inner_product_multiple_packed_elements() {
		type P = PackedBinaryGhash4x128b;

		let mut rng = StdRng::seed_from_u64(1);

		// log_len = 3 gives 8 scalars spread over 2 packed elements of width 4.
		let log_len = 3;
		let buffer_a = FieldBuffer::new(log_len, vec![P::random(&mut rng), P::random(&mut rng)]);
		let buffer_b = FieldBuffer::new(log_len, vec![P::random(&mut rng), P::random(&mut rng)]);

		// Reference value: the scalar-by-scalar inner product over every lane.
		let expected: <P as PackedField>::Scalar = (0..1 << log_len)
			.map(|i| buffer_a.get(i) * buffer_b.get(i))
			.sum();

		let result_par = inner_product_par(&buffer_a, &buffer_b);
		let result_buffers = inner_product_buffers(&buffer_a, &buffer_b);
		let result_packed = inner_product_packed(
			log_len,
			buffer_a.as_ref().iter().copied(),
			buffer_b.as_ref().iter().copied(),
		);

		assert_eq!(result_par, expected, "inner_product_par failed for log_len=3");
		assert_eq!(result_buffers, expected, "inner_product_buffers failed for log_len=3");
		assert_eq!(result_packed, expected, "inner_product_packed failed for log_len=3");
	}
}