1use std::iter;
6
7use binius_field::Field;
8
9pub fn expand_subset_sums_array<F: Field, const N: usize, const N_EXP2: usize>(
44 elems: [F; N],
45) -> [F; N_EXP2] {
46 assert_eq!(N_EXP2, 1 << N);
47
48 let mut expanded = [F::ZERO; N_EXP2];
49 for (i, elem_i) in elems.into_iter().enumerate() {
50 let span = &mut expanded[..1 << (i + 1)];
51 let (lo_half, hi_half) = span.split_at_mut(1 << i);
52 for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
53 *hi_half_i = *lo_half_i + elem_i;
54 }
55 }
56 expanded
57}
58
59pub fn expand_subset_sums<F: Field>(elems: &[F]) -> Vec<F> {
84 let n = elems.len();
85 let n_exp2 = 1 << n;
86
87 let mut expanded = vec![F::ZERO; n_exp2];
88 for (i, &elem_i) in elems.iter().enumerate() {
89 let span = &mut expanded[..1 << (i + 1)];
90 let (lo_half, hi_half) = span.split_at_mut(1 << i);
91 for (lo_half_i, hi_half_i) in iter::zip(lo_half, hi_half) {
92 *hi_half_i = *lo_half_i + elem_i;
93 }
94 }
95 expanded
96}
97
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128bGhash, Field, Random};
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    type F = BinaryField128bGhash;

    proptest! {
        // Every entry of the expanded table must equal the sum of the elements selected by
        // the binary representation of its index.
        #[test]
        fn test_expand_subset_sums_correctness(
            // Draw `index` from `0..2^n` directly instead of rejecting out-of-range values with
            // `prop_assume!`. The old scheme discarded ~78% of generated cases, which put the
            // run close to proptest's global-reject limit and made the test flaky.
            (n, index) in (0usize..=8).prop_flat_map(|n| (Just(n), 0usize..(1usize << n))),
            // Vary the field elements across cases instead of deriving them from `n` alone.
            seed in any::<u64>(),
        ) {
            let mut rng = StdRng::seed_from_u64(seed);

            let elems: Vec<F> = (0..n).map(|_| F::random(&mut rng)).collect();

            let result = expand_subset_sums(&elems);

            prop_assert_eq!(result.len(), 1 << n);

            // Reference value: add up the elements whose bit is set in `index`.
            let mut expected = F::ZERO;
            for (bit_pos, &elem) in elems.iter().enumerate() {
                if (index >> bit_pos) & 1 == 1 {
                    expected += elem;
                }
            }

            prop_assert_eq!(
                result[index],
                expected,
                "Index {} should have subset sum corresponding to its binary representation",
                index
            );
        }
    }

    #[test]
    fn test_expand_subset_sums_array_slice_consistency() {
        let mut rng = StdRng::seed_from_u64(0);

        // Asserts that the const-generic array variant and the slice variant agree on `elems_vec`.
        fn check_consistency<const N: usize, const N_EXP2: usize>(elems_vec: &[F]) {
            assert_eq!(elems_vec.len(), N);
            assert_eq!(N_EXP2, 1 << N);

            let mut elems_array = [F::ZERO; N];
            elems_array.copy_from_slice(elems_vec);

            let result_array = expand_subset_sums_array::<_, N, N_EXP2>(elems_array);
            let result_slice = expand_subset_sums(elems_vec);
            assert_eq!(result_array.as_slice(), result_slice.as_slice());
        }

        for n in 0..=4 {
            let elems_vec: Vec<F> = (0..n).map(|_| F::random(&mut rng)).collect();

            // Const generics need literal sizes, so dispatch each `n` to its instantiation.
            match n {
                0 => check_consistency::<0, 1>(&elems_vec),
                1 => check_consistency::<1, 2>(&elems_vec),
                2 => check_consistency::<2, 4>(&elems_vec),
                3 => check_consistency::<3, 8>(&elems_vec),
                4 => check_consistency::<4, 16>(&elems_vec),
                _ => unreachable!("n is constrained to 0..=4"),
            }
        }
    }
}