binius_math/multilinear/fold.rs

use std::ops::{Deref, DerefMut};

use binius_field::{Field, PackedField};
use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};

use crate::FieldBuffer;

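/// Folds the highest-indexed variable of a multilinear in place.
///
/// Treating `values` as the evaluations of an `n`-variate multilinear on the boolean hypercube,
/// this specializes the highest variable to `scalar`: every entry of the low half is replaced by
/// `lo + scalar * (hi - lo)`, and the buffer is then truncated to half its length, leaving the
/// evaluations of the resulting `(n - 1)`-variate multilinear.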
pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    scalar: P::Scalar,
) {
    let broadcast_scalar = P::broadcast(scalar);
    {
        let mut split = values.split_half_mut();
        let (mut lo, mut hi) = split.halves();
        (lo.as_mut(), hi.as_mut())
            .into_par_iter()
            .for_each(|(zero, one)| {
                *zero += broadcast_scalar * (*one - *zero);
            });
    }

    values.truncate(values.log_len() - 1);
}

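/// Folds the high variables of a multilinear with boolean (0/1) coefficients.
///
/// `bits` holds the evaluations of a multilinear in `values.log_len() + tensor.log_len()`
/// variables, indexed with the `values` position in the low bits and the `tensor` position in the
/// high bits. For each output index `i`, this sets `values[i]` to the sum of `tensor[j]` over all
/// `j` such that `bits[j << values.log_len() | i]` is set. In particular, when `tensor` is the
/// equality-indicator expansion of a point (see `eq_ind_partial_eval`), the result equals folding
/// the high variables of the bits multilinear at that point, which is what the tests below check.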
pub fn binary_fold_high<P, DataOut, DataIn>(
    values: &mut FieldBuffer<P, DataOut>,
    tensor: &FieldBuffer<P, DataIn>,
    bits: impl RandomAccessSequence<bool> + Sync,
) where
    P: PackedField,
    DataOut: DerefMut<Target = [P]>,
    DataIn: Deref<Target = [P]> + Sync,
{
    assert!(bits.len().is_power_of_two(), "precondition: bits length must be a power of two");

    let values_log_len = values.log_len();
    // Number of scalars per output packed element; this is less than the packed width when the
    // output buffer holds fewer than `P::WIDTH` scalars.
    let width = P::WIDTH.min(values.len());

    assert_eq!(
        1 << (values_log_len + tensor.log_len()),
        bits.len(),
        "precondition: bits length must equal values length times tensor length"
    );

    values
        .as_mut()
        .iter_mut()
        .enumerate()
        .for_each(|(i, packed)| {
            *packed = P::from_scalars((0..width).map(|j| {
                let scalar_index = i << P::LOG_WIDTH | j;
                let mut acc = P::Scalar::ZERO;

                // Accumulate the tensor entries selected by the bits for this output index: the
                // bit index has the tensor index in the high bits and the output index in the
                // low bits.
                for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
                    for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
                        let tensor_scalar_index = k << P::LOG_WIDTH | l;
                        if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
                            acc += tensor_scalar;
                        }
                    }
                }

                acc
            }));
        });
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use rand::prelude::*;

    use super::*;
    use crate::{
        multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
        test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
    };

    type P = Packed128b;
    type F = B128;

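    // Folding out every variable of a random multilinear, from highest to lowest, should leave a
    // single entry equal to its evaluation at `point`.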
    #[test]
    fn test_fold_highest_var_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;

        let point = random_scalars::<F>(&mut rng, n_vars);
        let mut multilinear = random_field_buffer::<P>(&mut rng, n_vars);

        let eval = evaluate(&multilinear, &point);

        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut multilinear, scalar);
        }

        assert_eq!(multilinear.get(0), eval);
    }

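    // `binary_fold_high` with the eq-indicator tensor of `point` should agree with folding the
    // high variables of the corresponding 0/1-valued field buffer one at a time.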
    fn test_binary_fold_high_conforms_to_regular_fold_high_helper(
        n_vars: usize,
        tensor_n_vars: usize,
    ) {
        let mut rng = StdRng::seed_from_u64(0);

        let point = random_scalars::<F>(&mut rng, tensor_n_vars);

        let tensor = eq_ind_partial_eval::<P>(&point);

        let bits = repeat_with(|| rng.random())
            .take(1 << n_vars)
            .collect::<Vec<bool>>();

        let bits_scalars = bits
            .iter()
            .map(|&b| if b { F::ONE } else { F::ZERO })
            .collect::<Vec<F>>();

        let mut bits_buffer = FieldBuffer::<P>::from_values(&bits_scalars);

        let mut binary_fold_result = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
        binary_fold_high(&mut binary_fold_result, &tensor, bits.as_slice());

        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut bits_buffer, scalar);
        }

        assert_eq!(bits_buffer, binary_fold_result);
    }

    #[test]
    fn test_binary_fold_high_conforms_to_regular_fold_high() {
        for (n_vars, tensor_n_vars) in [(2, 0), (2, 1), (4, 4), (10, 3)] {
            test_binary_fold_high_conforms_to_regular_fold_high_helper(n_vars, tensor_n_vars)
        }
    }
}