binius_math/
span.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! Utilities for computing subset sums and vector span operations.
4
5use std::iter;
6
7use binius_field::Field;
8
9/// Expands an array of field elements into all possible subset sums.
10///
11/// For an input array `[a, b, c]`, this computes all possible sums of subsets:
12/// `[0, a, b, a+b, c, a+c, b+c, a+b+c]`
13///
14/// This is used to create lookup tables for the Method of Four Russians optimization,
15/// where we precompute all possible combinations of a small set of values to avoid
16/// doing the additions at runtime.
17///
18/// ## Type Parameters
19///
20/// * `F` - The field element type
21/// * `N` - Size of the input array
22/// * `N_EXP2` - Size of the output array, must be 2^N
23///
24/// ## Arguments
25///
26/// * `elems` - Input array of N field elements
27///
28/// ## Returns
29///
30/// An array of size N_EXP2 containing all possible subset sums of the input elements
31///
32/// ## Preconditions
33///
34/// * N_EXP2 must equal 2^N
35///
36/// ## Example
37///
38/// ```ignore
39/// let input = [F::ONE, F::from(2)];
40/// let sums = expand_subset_sums_array(input);
41/// // sums = [F::ZERO, F::ONE, F::from(2), F::from(3)]
42/// ```
43pub fn expand_subset_sums_array<F: Field, const N: usize, const N_EXP2: usize>(
44	elems: [F; N],
45) -> [F; N_EXP2] {
46	assert_eq!(N_EXP2, 1 << N);
47
48	let mut expanded = [F::ZERO; N_EXP2];
49	for (i, elem_i) in elems.into_iter().enumerate() {
50		let span = &mut expanded[..1 << (i + 1)];
51		let (lo_half, hi_half) = span.split_at_mut(1 << i);
52		for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
53			*hi_half_i = *lo_half_i + elem_i;
54		}
55	}
56	expanded
57}
58
59/// Expands a slice of field elements into all possible subset sums.
60///
61/// For an input slice `[a, b, c]`, this computes all possible sums of subsets:
62/// `[0, a, b, a+b, c, a+c, b+c, a+b+c]`
63///
64/// This is a dynamic version of [`expand_subset_sums_array`] that works with slices
65/// and returns a Vec with length 2^n where n is the input length.
66///
67/// ## Arguments
68///
69/// * `elems` - Input slice of field elements
70///
71/// ## Returns
72///
73/// A Vec containing all possible subset sums of the input elements, with length 2^n
74/// where n is the length of the input slice.
75///
76/// ## Example
77///
78/// ```ignore
79/// let input = vec![F::ONE, F::from(2)];
80/// let sums = expand_subset_sums(&input);
81/// // sums = vec![F::ZERO, F::ONE, F::from(2), F::from(3)]
82/// ```
83pub fn expand_subset_sums<F: Field>(elems: &[F]) -> Vec<F> {
84	let n = elems.len();
85	let n_exp2 = 1 << n;
86
87	let mut expanded = vec![F::ZERO; n_exp2];
88	for (i, &elem_i) in elems.iter().enumerate() {
89		let span = &mut expanded[..1 << (i + 1)];
90		let (lo_half, hi_half) = span.split_at_mut(1 << i);
91		for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
92			*hi_half_i = *lo_half_i + elem_i;
93		}
94	}
95	expanded
96}
97
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128bGhash, Field, Random};
	use proptest::prelude::*;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	type F = BinaryField128bGhash;

	proptest! {
		// Every index of the expansion is checked, rather than one randomly drawn index
		// filtered with `prop_assume!`. Filtering a fixed `0..256` index range against
		// `2^n` rejected most generated cases for small `n`. That could exhaust proptest's
		// global reject budget and make the test fail spuriously.
		#[test]
		fn test_expand_subset_sums_correctness(
			n in 0usize..=8,  // Input length (small to avoid exponential blowup)
			seed in any::<u64>(),  // Independent seed so inputs vary beyond the length
		) {
			let mut rng = StdRng::seed_from_u64(seed);

			// Generate random input elements
			let elems: Vec<F> = (0..n).map(|_| F::random(&mut rng)).collect();

			// Compute the expansion
			let result = expand_subset_sums(&elems);

			// Verify the result length
			prop_assert_eq!(result.len(), 1 << n);

			for (index, &actual) in result.iter().enumerate() {
				// Compute expected sum based on binary representation of index
				let mut expected = F::ZERO;
				for (bit_pos, &elem) in elems.iter().enumerate() {
					if (index >> bit_pos) & 1 == 1 {
						expected += elem;
					}
				}

				prop_assert_eq!(
					actual,
					expected,
					"Index {} should have subset sum corresponding to its binary representation",
					index
				);
			}
		}
	}

	#[test]
	fn test_expand_subset_sums_array_slice_consistency() {
		let mut rng = StdRng::seed_from_u64(0);

		// Helper function to test consistency for a specific size
		fn check_consistency<const N: usize, const N_EXP2: usize>(elems_vec: &[F]) {
			assert_eq!(elems_vec.len(), N);
			assert_eq!(N_EXP2, 1 << N);

			let mut elems_array = [F::ZERO; N];
			elems_array.copy_from_slice(elems_vec);

			let result_array = expand_subset_sums_array::<_, N, N_EXP2>(elems_array);
			let result_slice = expand_subset_sums(elems_vec);
			assert_eq!(result_array.as_ref(), result_slice.as_slice());
		}

		// Test with different sizes to verify array/slice consistency
		for n in 0..=4 {
			// Generate random input elements
			let elems_vec: Vec<F> = (0..n).map(|_| F::random(&mut rng)).collect();

			// Test with the specific size n
			match n {
				0 => check_consistency::<0, 1>(&elems_vec),
				1 => check_consistency::<1, 2>(&elems_vec),
				2 => check_consistency::<2, 4>(&elems_vec),
				3 => check_consistency::<3, 8>(&elems_vec),
				4 => check_consistency::<4, 16>(&elems_vec),
				_ => unreachable!("n is constrained to 0..=4"),
			}
		}
	}
}