binius_math/multilinear/
fold.rs1use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};
7
8use crate::{Error, FieldBuffer};
9
10pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
15 values: &mut FieldBuffer<P, Data>,
16 scalar: P::Scalar,
17) -> Result<(), Error> {
18 let broadcast_scalar = P::broadcast(scalar);
19 {
20 let mut split = values.split_half_mut()?;
21 let (mut lo, mut hi) = split.halves();
22 (lo.as_mut(), hi.as_mut())
23 .into_par_iter()
24 .for_each(|(zero, one)| {
25 *zero += broadcast_scalar * (*one - *zero);
26 });
27 }
28
29 values.truncate(values.log_len() - 1);
30 Ok(())
31}
32
33pub fn binary_fold_high<P, DataOut, DataIn>(
47 values: &mut FieldBuffer<P, DataOut>,
48 tensor: &FieldBuffer<P, DataIn>,
49 bits: impl RandomAccessSequence<bool> + Sync,
50) -> Result<(), Error>
51where
52 P: PackedField,
53 DataOut: DerefMut<Target = [P]>,
54 DataIn: Deref<Target = [P]> + Sync,
55{
56 if !bits.len().is_power_of_two() {
57 return Err(Error::PowerOfTwoLengthRequired);
58 }
59
60 let values_log_len = values.log_len();
61 let width = P::WIDTH.min(values.len());
62
63 if 1 << (values_log_len + tensor.log_len()) != bits.len() {
64 return Err(Error::FoldLengthMismatch);
65 }
66
67 values
68 .as_mut()
69 .iter_mut()
70 .enumerate()
71 .for_each(|(i, packed)| {
72 *packed = P::from_scalars((0..width).map(|j| {
73 let scalar_index = i << P::LOG_WIDTH | j;
74 let mut acc = P::Scalar::ZERO;
75
76 for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
77 for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
78 let tensor_scalar_index = k << P::LOG_WIDTH | l;
79 if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
80 acc += tensor_scalar;
81 }
82 }
83 }
84
85 acc
86 }));
87 });
88
89 Ok(())
90}
91
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use rand::prelude::*;

    use super::*;
    use crate::{
        multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
        test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
    };

    type P = Packed128b;
    type F = B128;

    #[test]
    fn test_fold_highest_var_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;

        let point = random_scalars::<F>(&mut rng, n_vars);
        let mut multilinear = random_field_buffer::<P>(&mut rng, n_vars);

        let eval = evaluate(&multilinear, &point).unwrap();

        // Folding from the highest variable down consumes the point's coordinates in reverse.
        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut multilinear, scalar).unwrap();
        }

        assert_eq!(multilinear.get_checked(0).unwrap(), eval);
    }

    /// Checks that `binary_fold_high` with an eq-indicator tensor agrees with repeatedly
    /// applying `fold_highest_var_inplace` to the 0/1 field encoding of the same bits.
    fn test_binary_fold_high_conforms_to_regular_fold_high_helper(
        n_vars: usize,
        tensor_n_vars: usize,
    ) {
        let mut rng = StdRng::seed_from_u64(0);

        let point = random_scalars::<F>(&mut rng, tensor_n_vars);

        let tensor = eq_ind_partial_eval::<P>(&point);

        let bits = repeat_with(|| rng.random())
            .take(1 << n_vars)
            .collect::<Vec<bool>>();

        let bits_scalars = bits
            .iter()
            .map(|&b| if b { F::ONE } else { F::ZERO })
            .collect::<Vec<F>>();

        let mut bits_buffer = FieldBuffer::<P>::from_values(&bits_scalars).unwrap();

        let mut binary_fold_result = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
        binary_fold_high(&mut binary_fold_result, &tensor, bits.as_slice()).unwrap();

        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut bits_buffer, scalar).unwrap();
        }

        assert_eq!(bits_buffer, binary_fold_result);
    }

    #[test]
    fn test_binary_fold_high_conforms_to_regular_fold_high() {
        // Covers the degenerate cases (no variables; all variables folded), an empty tensor,
        // and mixed sizes.
        for (n_vars, tensor_n_vars) in [(0, 0), (1, 1), (2, 0), (2, 1), (4, 4), (5, 2), (10, 3)] {
            test_binary_fold_high_conforms_to_regular_fold_high_helper(n_vars, tensor_n_vars)
        }
    }

    #[test]
    fn test_binary_fold_high_rejects_non_power_of_two_bits() {
        let mut values = FieldBuffer::<P>::zeros(1);
        let tensor = FieldBuffer::<P>::zeros(1);
        let bits = [true, false, true];

        assert!(matches!(
            binary_fold_high(&mut values, &tensor, bits.as_slice()),
            Err(Error::PowerOfTwoLengthRequired)
        ));
    }

    #[test]
    fn test_binary_fold_high_rejects_length_mismatch() {
        let mut values = FieldBuffer::<P>::zeros(2);
        let tensor = FieldBuffer::<P>::zeros(1);
        // 2^(2 + 1) = 8 bits are required, but only 4 are supplied.
        let bits = [false; 4];

        assert!(matches!(
            binary_fold_high(&mut values, &tensor, bits.as_slice()),
            Err(Error::FoldLengthMismatch)
        ));
    }
}