// binius_math/piecewise_multilinear.rs

// Copyright 2024-2025 Irreducible Inc.

3use binius_field::Field;
4use tracing::instrument;
5
6use crate::{extrapolate_line_scalar, Error};
7
8/// Evaluate a piecewise multilinear polynomial at a point, given the evaluations of the pieces.
9///
10/// A piecewise multilinear is defined by a sequence of multilinear polynomials, ordered from
11/// most number of variables to least. We define it over a boolean hypercube by identifying
12/// the smallest hypercube larger than the total number of hypercube evaluations of all the pieces,
13/// then concatenating the evaluations, flattened into vectors in little-endian order. The
14/// evaluation vector is right-padded with zeros up to the hypercube size of the containing cube.
15/// Equivalently, we can view the piecewise multilinear as defined by an inductive linear
16/// interpolation of pairs of polynomials.
17///
18/// This function takes a description of a piecewise multilinear polynomial and the evaluations
19/// of the pieces at prefixes of the given point, then evaluates the concatenated multilinear at
20/// the point.
21///
22/// ## Arguments
23///
24/// * `point` - the evaluation point. The length specifies the number of variables in the
25///     concatenated multilinear.
26/// * `n_pieces_by_vars` - the number of multilinear pieces, indexed by the number of variables.
27///     Entry at index `i` is the number of multilinears with `i` variables. The sum of
28///     `n_pieces_by_vars[i] * 2^i` for all indices `i` must be at most `2^n`, where `n` is the
29///     length of `point.`.
30/// * `piece_evals` - the evaluations of the multilinear pieces at the corresponding prefixes of
31///     `point`. The length must be equal to the sum of the values in `n_pieces_by_vars`. This must
32///     be in the *same order* as the corresponding multilinears are implicitly concatenated.
33///
34/// ## Example
35///
36/// Suppose we have three multilinear functions $f_0$, $f_1$, and $f_2$, which have $2$, $2$, and
37/// $1$ variables respectively. There exists a unique $4$-variate multilinear function $\tilde{f}$,
38/// which is defined by concatenating the evaluations of $f_0$, $f_1$, and $f_2$ along their
39/// respective defining Boolean hypercubes and then zero-padding. (I.e., we simply decree that the
40/// resulting list of length 16 is the ordered list of evaluations of a multilinear $\tilde{f}$
41/// on $\mathcal B_4$.) We wish to evaluate $\tilde{f}$ at the point $(r_0, r_1, r_2, r_3)$,
42/// and we are given the evaluations of $v_0:=f_0(r_0,r_1)$, $v_1:=f_1(r_0,r_1)$, and
43/// $v_2:=f_2(r_0)$. In this situation, we have the following:
44/// `n_pieces_by_vars` is `[0, 1, 2]`, and `piece_evals` is `[v_0, v_1, v_2]`.
45#[instrument(skip_all)]
46pub fn evaluate_piecewise_multilinear<F: Field>(
47	point: &[F],
48	n_pieces_by_vars: &[usize],
49	piece_evals: &mut [F],
50) -> Result<F, Error> {
51	// dimension of the big hypercube
52	let total_n_vars = point.len();
53	// number of multilinears
54	let num_polys = n_pieces_by_vars.iter().sum();
55	// total number of coefficients in all of the multilinears.
56	let total_length: usize = n_pieces_by_vars
57		.iter()
58		.enumerate()
59		.map(|(i, &n)| n << i)
60		.sum();
61	if total_length > 1 << total_n_vars {
62		return Err(Error::PiecewiseMultilinearTooLong {
63			total_length,
64			total_n_vars,
65		});
66	}
67	if piece_evals.len() != num_polys {
68		return Err(Error::PiecewiseMultilinearIncompatibleEvals {
69			expected: num_polys,
70			actual: piece_evals.len(),
71		});
72	}
73
74	// The logic is that we can iteratively compute the "root claim" via a "folding" procedure,
75	// reading from right to left.
76	// To demonstrate what is happening, suppose we have 3 multilinears f_0, f_1, f_2, of lengths
77	// 2, 2, 1 respectively.
78	// We want to evaluate the concatenation at point [r_0, r_1, r_2, r_3]. We are given the
79	// evaluations:
80	// v_0 := f_0(r_0, r_1), v_1 := f_1(r_0, r_1), and v_2 := f_2(r_0).
81	// we first replace v_2 with (1-r_1)*v_2 + r_1 * 0, which is implicitly the evaluation of the
82	// 2-variate multilinear corresponding to f_2(evals)||0, 0 at the point (r_0, r_1).
83	// we then replace v_0 with (1-r_2)*v_0 + r_2 * v_1 and v_1 with (1-r_2)*v_1 + r_2 * 0.
84	// this list of length 2 represents the evaluation claims of the trivariate multilinears
85	// corresponding to:
86	// 	1. f_0(evals)||f_1(evals); and
87	// 	2. f_2(evals)||0, 0, 0, 0
88	// at the point (r_0, r_1, r_2).
89	// finally, we fold our final list of length 2 via r_3 to obtain the root claim.
90	let mut index = piece_evals.len();
91	let mut n_to_fold = 0;
92	for (i, &point_i) in point.iter().enumerate() {
93		n_to_fold += n_pieces_by_vars.get(i).copied().unwrap_or(0);
94		fold_segment(&mut piece_evals[index - n_to_fold..index], point_i);
95		let n_folded_out = n_to_fold / 2;
96		index -= n_folded_out;
97		n_to_fold -= n_folded_out;
98	}
99
100	Ok(piece_evals[0])
101}
102
103/// Folds a sequence of pairs of field elements by linear extrapolation at a given point.
104///
105/// Given a list $[a_0,\ldots, a_{n-1}]$ and a parameter $r$, mutate to the list
106/// $[a_0(1-r) + a_1r, a_2(1-r) + a_3r, \ldots, 0, \ldots]$. If $n$ is odd, then the last
107/// potentially non-zero element is $a_{n-1}(1-r)$, at index $(n-1)/2$.
108fn fold_segment<F: Field>(segment: &mut [F], z: F) {
109	let n_full_pairs = segment.len() / 2;
110	for i in 0..n_full_pairs {
111		segment[i] = extrapolate_line_scalar(segment[2 * i], segment[2 * i + 1], z);
112	}
113	if segment.len() % 2 == 1 {
114		let i = segment.len() / 2;
115		segment[i] = extrapolate_line_scalar(segment[2 * i], F::ZERO, z);
116	}
117}
118
#[cfg(test)]
mod tests {
	use binius_field::BinaryField32b;
	use binius_utils::checked_arithmetics::{log2_ceil_usize, log2_strict_usize};
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;
	use crate::{MultilinearExtension, MultilinearQuery};

	/// Compares `evaluate_piecewise_multilinear` against a direct evaluation of the zero-padded
	/// concatenation of random pieces with the given numbers of variables.
	///
	/// `number_of_variables` must be non-empty and sorted in non-increasing order.
	fn test_piecewise_multilinear(number_of_variables: &[usize]) {
		type F = BinaryField32b;
		assert!(
			number_of_variables.windows(2).all(|w| w[0] >= w[1]),
			"Number of variables must be sorted in non-increasing order"
		);

		// Histogram of piece sizes. The list is sorted, so its first entry is the maximum.
		let max_n_vars = number_of_variables[0];
		let mut n_pieces_by_vars = vec![0; max_n_vars + 1];
		for &n_vars in number_of_variables {
			n_pieces_by_vars[n_vars] += 1;
		}

		let mut rng = StdRng::seed_from_u64(0);

		// Random pieces, sampled in order so that the RNG stream is reproducible.
		let pieces: Vec<Vec<F>> = number_of_variables
			.iter()
			.map(|&n_vars| gen_random_multilinear(n_vars, &mut rng))
			.collect();

		// Concatenate the pieces and zero-pad up to the smallest enclosing hypercube.
		let total_num_coeffs: usize = pieces.iter().map(Vec::len).sum();
		let total_number_of_variables = log2_ceil_usize(total_num_coeffs);
		let mut concatenated_coeffs = pieces.concat();
		concatenated_coeffs.resize(1 << total_number_of_variables, F::ZERO);
		let concatenated_multilinear =
			MultilinearExtension::<F, Vec<F>>::new(total_number_of_variables, concatenated_coeffs)
				.unwrap();

		let eval_point: Vec<F> = (0..total_number_of_variables)
			.map(|_| <F as Field>::random(&mut rng))
			.collect();

		// Reference value: evaluate the concatenated multilinear directly.
		let query = MultilinearQuery::<F>::expand(&eval_point);
		let expected: F = concatenated_multilinear.evaluate(&query).unwrap();

		// Value under test: combine the per-piece evaluations at prefixes of the point.
		let mut piece_evals = eval_multilinears_at_common_prefix(&pieces, &eval_point);
		let actual =
			evaluate_piecewise_multilinear(&eval_point, &n_pieces_by_vars, &mut piece_evals)
				.unwrap();
		assert_eq!(expected, actual);
	}

	#[test]
	fn test_piecewise_multilinear_4_4() {
		test_piecewise_multilinear(&[4, 4]);
	}

	#[test]
	fn test_piecewise_multilinear_5_3_2() {
		test_piecewise_multilinear(&[5, 3, 2]);
	}

	#[test]
	fn test_piecewise_multilinear_5_5_3_3_2_2_2() {
		test_piecewise_multilinear(&[5, 5, 3, 3, 2, 2, 2]);
	}

	#[test]
	fn test_piecewise_multilinear_2_2_1() {
		test_piecewise_multilinear(&[2, 2, 1]);
	}

	#[test]
	fn test_piecewise_multilinear_3_1_0_0_0() {
		test_piecewise_multilinear(&[3, 1, 0, 0, 0]);
	}

	/// Samples the `2^n_vars` hypercube evaluations of a random multilinear.
	fn gen_random_multilinear<F: Field>(n_vars: usize, rng: &mut StdRng) -> Vec<F> {
		(0..1usize << n_vars).map(|_| F::random(&mut *rng)).collect()
	}

	/// Evaluates each multilinear, given by its hypercube evaluations, at the prefix of `prefix`
	/// whose length matches its number of variables.
	fn eval_multilinears_at_common_prefix<F: Field>(
		multilinears: &[Vec<F>],
		prefix: &[F],
	) -> Vec<F> {
		multilinears
			.iter()
			.map(|coeffs| {
				let n_vars = log2_strict_usize(coeffs.len());
				let query = MultilinearQuery::<F>::expand(&prefix[..n_vars]);
				let eval: F = MultilinearExtension::<F, Vec<F>>::new(n_vars, coeffs.clone())
					.unwrap()
					.evaluate(&query)
					.unwrap();
				eval
			})
			.collect()
	}

	#[test]
	fn test_fold_segment_basic() {
		let z = BinaryField32b::from(3);
		let original = vec![
			BinaryField32b::from(2),
			BinaryField32b::from(4),
			BinaryField32b::from(6),
			BinaryField32b::from(8),
		];

		let mut folded = original.clone();
		fold_segment(&mut folded, z);

		// Both full pairs fold into the front half. The back half keeps its old values.
		let expected = vec![
			extrapolate_line_scalar(original[0], original[1], z),
			extrapolate_line_scalar(original[2], original[3], z),
			original[2],
			original[3],
		];
		assert_eq!(folded, expected);
	}

	#[test]
	fn test_fold_segment_single_element() {
		let z = BinaryField32b::from(3);
		let only = BinaryField32b::from(5);

		let mut folded = vec![only];
		fold_segment(&mut folded, z);

		assert_eq!(folded, vec![extrapolate_line_scalar(only, BinaryField32b::ZERO, z)]);
	}

	#[test]
	fn test_fold_segment_odd_length() {
		let z = BinaryField32b::from(2);
		let original = vec![
			BinaryField32b::from(1),
			BinaryField32b::from(3),
			BinaryField32b::from(5),
		];

		let mut folded = original.clone();
		fold_segment(&mut folded, z);

		// The unpaired last element is folded against an implicit zero.
		let expected = vec![
			extrapolate_line_scalar(original[0], original[1], z),
			extrapolate_line_scalar(original[2], BinaryField32b::ZERO, z),
			original[2],
		];
		assert_eq!(folded, expected);
	}
}