binius_math/
piecewise_multilinear.rs1use binius_field::Field;
4use tracing::instrument;
5
6use crate::{extrapolate_line_scalar, Error};
7
8#[instrument(skip_all)]
46pub fn evaluate_piecewise_multilinear<F: Field>(
47 point: &[F],
48 n_pieces_by_vars: &[usize],
49 piece_evals: &mut [F],
50) -> Result<F, Error> {
51 let total_n_vars = point.len();
53 let num_polys = n_pieces_by_vars.iter().sum();
55 let total_length: usize = n_pieces_by_vars
57 .iter()
58 .enumerate()
59 .map(|(i, &n)| n << i)
60 .sum();
61 if total_length > 1 << total_n_vars {
62 return Err(Error::PiecewiseMultilinearTooLong {
63 total_length,
64 total_n_vars,
65 });
66 }
67 if piece_evals.len() != num_polys {
68 return Err(Error::PiecewiseMultilinearIncompatibleEvals {
69 expected: num_polys,
70 actual: piece_evals.len(),
71 });
72 }
73
74 let mut index = piece_evals.len();
91 let mut n_to_fold = 0;
92 for (i, &point_i) in point.iter().enumerate() {
93 n_to_fold += n_pieces_by_vars.get(i).copied().unwrap_or(0);
94 fold_segment(&mut piece_evals[index - n_to_fold..index], point_i);
95 let n_folded_out = n_to_fold / 2;
96 index -= n_folded_out;
97 n_to_fold -= n_folded_out;
98 }
99
100 Ok(piece_evals[0])
101}
102
103fn fold_segment<F: Field>(segment: &mut [F], z: F) {
109 let n_full_pairs = segment.len() / 2;
110 for i in 0..n_full_pairs {
111 segment[i] = extrapolate_line_scalar(segment[2 * i], segment[2 * i + 1], z);
112 }
113 if segment.len() % 2 == 1 {
114 let i = segment.len() / 2;
115 segment[i] = extrapolate_line_scalar(segment[2 * i], F::ZERO, z);
116 }
117}
118
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::BinaryField32b;
    use binius_utils::checked_arithmetics::{log2_ceil_usize, log2_strict_usize};
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{MultilinearExtension, MultilinearQuery};

    /// Checks `evaluate_piecewise_multilinear` against a direct evaluation of the zero-padded
    /// concatenation of random pieces with the given numbers of variables.
    ///
    /// `number_of_variables` must be non-empty and sorted in non-increasing order.
    fn test_piecewise_multilinear(number_of_variables: &[usize]) {
        type F = BinaryField32b;
        assert!(
            number_of_variables.windows(2).all(|w| w[0] >= w[1]),
            "Number of variables must be sorted in non-increasing order"
        );

        // Histogram: `n_pieces_by_vars[k]` is the number of pieces with `k` variables.
        let mut n_pieces_by_vars = vec![0; number_of_variables[0] + 1];
        for &n_vars in number_of_variables {
            n_pieces_by_vars[n_vars] += 1;
        }

        let mut rng = StdRng::seed_from_u64(0);

        let multilinear_coefficients: Vec<Vec<F>> = number_of_variables
            .iter()
            .map(|&n_vars| gen_random_multilinear(n_vars, &mut rng))
            .collect();

        // Reference polynomial: all pieces concatenated, zero-padded to a power of two.
        let total_num_coeffs: usize = number_of_variables.iter().map(|&n| 1 << n).sum();
        let total_number_of_variables = log2_ceil_usize(total_num_coeffs);
        let zero_padding_length = (1 << total_number_of_variables) - total_num_coeffs;

        let concatenated_multilinear_coeffs = multilinear_coefficients
            .iter()
            .flat_map(|coeffs| coeffs.iter())
            .copied()
            .chain(repeat_with(|| F::ZERO).take(zero_padding_length))
            .collect::<Vec<_>>();

        let concatenated_multilinear = MultilinearExtension::<F, Vec<F>>::new(
            total_number_of_variables,
            concatenated_multilinear_coeffs,
        )
        .unwrap();

        let eval_point = repeat_with(|| <F as Field>::random(&mut rng))
            .take(total_number_of_variables)
            .collect::<Vec<_>>();

        let mlq_eval_point = MultilinearQuery::<F>::expand(&eval_point);
        let mut piece_evals =
            eval_multilinears_at_common_prefix(&multilinear_coefficients, &eval_point);

        let concatenate_and_evaluate = concatenated_multilinear.evaluate(&mlq_eval_point).unwrap();

        let compute_via_piecewise =
            evaluate_piecewise_multilinear(&eval_point, &n_pieces_by_vars, &mut piece_evals)
                .unwrap();
        assert_eq!(concatenate_and_evaluate, compute_via_piecewise);
    }

    #[test]
    fn test_piecewise_multilinear_4_4() {
        test_piecewise_multilinear(&[4, 4]);
    }

    #[test]
    fn test_piecewise_multilinear_5_3_2() {
        test_piecewise_multilinear(&[5, 3, 2]);
    }

    #[test]
    fn test_piecewise_multilinear_5_5_3_3_2_2_2() {
        test_piecewise_multilinear(&[5, 5, 3, 3, 2, 2, 2]);
    }

    #[test]
    fn test_piecewise_multilinear_2_2_1() {
        test_piecewise_multilinear(&[2, 2, 1]);
    }

    #[test]
    fn test_piecewise_multilinear_3_1_0_0_0() {
        test_piecewise_multilinear(&[3, 1, 0, 0, 0]);
    }

    /// A single piece that spans the whole point: no folding happens at all.
    #[test]
    fn test_piecewise_multilinear_single_piece() {
        test_piecewise_multilinear(&[4]);
    }

    #[test]
    fn test_piecewise_multilinear_too_long() {
        // The pieces need 5 entries: 4 for the 2-variable piece and 1 for the 0-variable piece.
        // A 2-variable point only addresses 4 entries.
        let point = [BinaryField32b::from(1), BinaryField32b::from(2)];
        let mut piece_evals = [BinaryField32b::from(3), BinaryField32b::from(4)];
        let result = evaluate_piecewise_multilinear(&point, &[1, 0, 1], &mut piece_evals);
        assert!(matches!(
            result,
            Err(Error::PiecewiseMultilinearTooLong {
                total_length: 5,
                total_n_vars: 2,
            })
        ));
    }

    #[test]
    fn test_piecewise_multilinear_incompatible_evals() {
        // Two pieces are declared, but three evaluations are supplied.
        let point = [BinaryField32b::from(1), BinaryField32b::from(2)];
        let mut piece_evals = [BinaryField32b::from(3); 3];
        let result = evaluate_piecewise_multilinear(&point, &[1, 1], &mut piece_evals);
        assert!(matches!(
            result,
            Err(Error::PiecewiseMultilinearIncompatibleEvals {
                expected: 2,
                actual: 3,
            })
        ));
    }

    /// Returns `2^n_vars` random field elements, used as the values of a random
    /// `n_vars`-variate multilinear.
    fn gen_random_multilinear<F: Field>(n_vars: usize, rng: &mut StdRng) -> Vec<F> {
        repeat_with(|| F::random(&mut *rng))
            .take(1 << n_vars)
            .collect()
    }

    /// Evaluates each multilinear, given by its `2^k` values, at the first `k` coordinates of
    /// `prefix`, where `k` is that multilinear's number of variables.
    fn eval_multilinears_at_common_prefix<F: Field>(
        multilinears: &[Vec<F>],
        prefix: &[F],
    ) -> Vec<F> {
        multilinears
            .iter()
            .map(|multilinear_coeffs| {
                let log_len = log2_strict_usize(multilinear_coeffs.len());
                let multilinear =
                    MultilinearExtension::<F, Vec<F>>::new(log_len, multilinear_coeffs.clone())
                        .unwrap();
                let mlq_eval_point = MultilinearQuery::<F>::expand(&prefix[..log_len]);
                multilinear.evaluate(&mlq_eval_point).unwrap()
            })
            .collect()
    }

    #[test]
    fn test_fold_segment_basic() {
        let s0 = BinaryField32b::from(2);
        let s1 = BinaryField32b::from(4);
        let s2 = BinaryField32b::from(6);
        let s3 = BinaryField32b::from(8);

        let mut segment = [s0, s1, s2, s3];
        let z = BinaryField32b::from(3);

        fold_segment(&mut segment, z);

        // Only the first half is overwritten; the tail keeps its old values.
        assert_eq!(
            segment,
            [
                extrapolate_line_scalar(s0, s1, z),
                extrapolate_line_scalar(s2, s3, z),
                s2,
                s3,
            ]
        );
    }

    #[test]
    fn test_fold_segment_single_element() {
        let s0 = BinaryField32b::from(5);
        let mut segment = [s0];
        let z = BinaryField32b::from(3);

        fold_segment(&mut segment, z);

        assert_eq!(segment, [extrapolate_line_scalar(s0, BinaryField32b::ZERO, z)]);
    }

    #[test]
    fn test_fold_segment_odd_length() {
        let s0 = BinaryField32b::from(1);
        let s1 = BinaryField32b::from(3);
        let s2 = BinaryField32b::from(5);

        let mut segment = [s0, s1, s2];
        let z = BinaryField32b::from(2);

        fold_segment(&mut segment, z);

        // The unpaired trailing element is folded against zero.
        assert_eq!(
            segment,
            [
                extrapolate_line_scalar(s0, s1, z),
                extrapolate_line_scalar(s2, BinaryField32b::ZERO, z),
                s2,
            ]
        );
    }
}