binius_math/multilinear/
fold.rs1use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};
7
8use crate::{Error, FieldBuffer};
9
10pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
15 values: &mut FieldBuffer<P, Data>,
16 scalar: P::Scalar,
17) -> Result<(), Error> {
18 let broadcast_scalar = P::broadcast(scalar);
19 values.split_half_mut(|lo, hi| {
20 (lo.as_mut(), hi.as_mut())
21 .into_par_iter()
22 .for_each(|(zero, one)| {
23 *zero += broadcast_scalar * (*one - *zero);
24 });
25 })?;
26
27 values.truncate(values.log_len() - 1);
28 Ok(())
29}
30
31pub fn binary_fold_high<P, DataOut, DataIn>(
45 values: &mut FieldBuffer<P, DataOut>,
46 tensor: &FieldBuffer<P, DataIn>,
47 bits: impl RandomAccessSequence<bool> + Sync,
48) -> Result<(), Error>
49where
50 P: PackedField,
51 DataOut: DerefMut<Target = [P]>,
52 DataIn: Deref<Target = [P]> + Sync,
53{
54 if !bits.len().is_power_of_two() {
55 return Err(Error::PowerOfTwoLengthRequired);
56 }
57
58 let values_log_len = values.log_len();
59 let width = P::WIDTH.min(values.len());
60
61 if 1 << (values_log_len + tensor.log_len()) != bits.len() {
62 return Err(Error::FoldLengthMismatch);
63 }
64
65 values
66 .as_mut()
67 .iter_mut()
68 .enumerate()
69 .for_each(|(i, packed)| {
70 *packed = P::from_scalars((0..width).map(|j| {
71 let scalar_index = i << P::LOG_WIDTH | j;
72 let mut acc = P::Scalar::ZERO;
73
74 for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
75 for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
76 let tensor_scalar_index = k << P::LOG_WIDTH | l;
77 if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
78 acc += tensor_scalar;
79 }
80 }
81 }
82
83 acc
84 }));
85 });
86
87 Ok(())
88}
89
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use rand::prelude::*;

    use super::*;
    use crate::{
        multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
        test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
    };

    type P = Packed128b;
    type F = B128;

    /// Folding every variable from the highest down must yield the multilinear's evaluation.
    #[test]
    fn test_fold_highest_var_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;

        let point = random_scalars::<F>(&mut rng, n_vars);
        let mut multilinear = random_field_buffer::<P>(&mut rng, n_vars);

        let eval = evaluate(&multilinear, &point).unwrap();

        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut multilinear, scalar).unwrap();
        }

        assert_eq!(multilinear.get(0).unwrap(), eval);
    }

    /// Checks `binary_fold_high` against repeated `fold_highest_var_inplace` on the same bits
    /// lifted to field elements.
    fn test_binary_fold_high_conforms_to_regular_fold_high_helper(
        n_vars: usize,
        tensor_n_vars: usize,
    ) {
        let mut rng = StdRng::seed_from_u64(0);

        let point = random_scalars::<F>(&mut rng, tensor_n_vars);

        let tensor = eq_ind_partial_eval::<P>(&point);

        let bits = repeat_with(|| rng.random())
            .take(1 << n_vars)
            .collect::<Vec<bool>>();

        let bits_scalars = bits
            .iter()
            .map(|&b| if b { F::ONE } else { F::ZERO })
            .collect::<Vec<F>>();

        let mut bits_buffer = FieldBuffer::<P>::from_values(&bits_scalars).unwrap();

        let mut binary_fold_result = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
        binary_fold_high(&mut binary_fold_result, &tensor, bits.as_slice()).unwrap();

        for &scalar in point.iter().rev() {
            fold_highest_var_inplace(&mut bits_buffer, scalar).unwrap();
        }

        assert_eq!(bits_buffer, binary_fold_result);
    }

    #[test]
    fn test_binary_fold_high_conforms_to_regular_fold_high() {
        for (n_vars, tensor_n_vars) in [(2, 0), (2, 1), (4, 4), (10, 3)] {
            test_binary_fold_high_conforms_to_regular_fold_high_helper(n_vars, tensor_n_vars);
        }
    }

    /// Bit sequences whose length is not a power of two (including empty) are rejected.
    #[test]
    fn test_binary_fold_high_rejects_non_power_of_two_bits() {
        let mut values = FieldBuffer::<P>::zeros(1);
        let tensor = eq_ind_partial_eval::<P>(&[F::ONE]);

        for len in [0, 3, 5] {
            let bits = vec![false; len];
            assert!(matches!(
                binary_fold_high(&mut values, &tensor, bits.as_slice()),
                Err(Error::PowerOfTwoLengthRequired)
            ));
        }
    }

    /// Power-of-two bit sequences of the wrong size are rejected. Here the expected size is
    /// 2^(1 + 1) = 4.
    #[test]
    fn test_binary_fold_high_rejects_length_mismatch() {
        let mut values = FieldBuffer::<P>::zeros(1);
        let tensor = eq_ind_partial_eval::<P>(&[F::ONE]);

        for len in [1, 2, 8] {
            let bits = vec![false; len];
            assert!(matches!(
                binary_fold_high(&mut values, &tensor, bits.as_slice()),
                Err(Error::FoldLengthMismatch)
            ));
        }
    }
}