binius_math/multilinear/
evaluate.rs1use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField, field::FieldOps};
6use binius_utils::rayon::prelude::*;
7
8use crate::{
9 FieldBuffer,
10 inner_product::inner_product_buffers,
11 multilinear::{eq::eq_ind_partial_eval, fold::fold_highest_var_inplace},
12};
13
14pub fn evaluate<F, P, Data>(evals: &FieldBuffer<P, Data>, point: &[F]) -> F
34where
35 F: Field,
36 P: PackedField<Scalar = F>,
37 Data: Deref<Target = [P]>,
38{
39 assert_eq!(
40 point.len(),
41 evals.log_len(),
42 "precondition: point length must equal evals log length"
43 );
44
45 let first_half_len = (point.len() / 2).max(P::LOG_WIDTH).min(point.len());
47 let (first_coords, remaining_coords) = point.split_at(first_half_len);
48
49 let eq_tensor = eq_ind_partial_eval::<P>(first_coords);
51
52 if remaining_coords.is_empty() {
54 return inner_product_buffers(evals, &eq_tensor);
55 }
56
57 let log_chunk_size = first_half_len;
59
60 let scalars = evals
62 .chunks_par(log_chunk_size)
63 .map(|chunk| inner_product_buffers(&chunk, &eq_tensor))
64 .collect::<Vec<_>>();
65
66 let temp_buffer = FieldBuffer::<P>::from_values(&scalars);
68
69 evaluate_inplace(temp_buffer, remaining_coords)
71}
72
73pub fn evaluate_inplace<F, P, Data>(mut evals: FieldBuffer<P, Data>, coords: &[F]) -> F
91where
92 F: Field,
93 P: PackedField<Scalar = F>,
94 Data: DerefMut<Target = [P]>,
95{
96 assert_eq!(
97 coords.len(),
98 evals.log_len(),
99 "precondition: coords length must equal evals log length"
100 );
101
102 for &coord in coords.iter().rev() {
104 fold_highest_var_inplace(&mut evals, coord);
105 }
106
107 assert_eq!(evals.len(), 1);
108 evals.get(0)
109}
110
111pub fn evaluate_inplace_scalars<F: FieldOps>(
127 mut evals: impl DerefMut<Target = [F]>,
128 point: &[F],
129) -> F {
130 assert_eq!(evals.len(), 1 << point.len(), "precondition: evals length must be 2^point.len()");
131
132 for (log_half_len, point_i) in point.iter().enumerate().rev() {
133 let half_len = 1 << log_half_len;
134 for j in 0..half_len {
135 let delta = evals[j + half_len].clone() - evals[j].clone();
136 evals[j] += point_i.clone() * delta;
137 }
138 }
139 evals[0].clone()
140}
141
#[cfg(test)]
mod tests {
    use rand::prelude::*;

    use super::*;
    use crate::{
        inner_product::inner_product_par,
        test_utils::{
            B128, Packed128b, index_to_hypercube_point, random_field_buffer, random_scalars,
        },
    };

    type P = Packed128b;
    type F = B128;

    /// `evaluate` and `evaluate_inplace` must both agree with a reference that builds the full
    /// equality tensor and takes a single inner product.
    #[test]
    fn test_evaluate_consistency() {
        // Reference implementation: full eq tensor over every coordinate, then one inner product.
        fn eval_via_full_eq_tensor<F, P, Data>(evals: &FieldBuffer<P, Data>, point: &[F]) -> F
        where
            F: Field,
            P: PackedField<Scalar = F>,
            Data: Deref<Target = [P]>,
        {
            assert_eq!(point.len(), evals.log_len());
            inner_product_par(evals, &eq_ind_partial_eval::<P>(point))
        }

        let mut rng = StdRng::seed_from_u64(0);

        for log_n in [0, P::LOG_WIDTH - 1, P::LOG_WIDTH, 10] {
            let buffer = random_field_buffer::<P>(&mut rng, log_n);
            let point = random_scalars::<F>(&mut rng, log_n);

            let expected = eval_via_full_eq_tensor(&buffer, &point);
            assert_eq!(evaluate_inplace(buffer.clone(), &point), expected);
            assert_eq!(evaluate(&buffer, &point), expected);
        }
    }

    /// At a Boolean hypercube vertex, the evaluation must equal the stored value at that index.
    #[test]
    fn test_evaluate_at_hypercube_indices() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);

        (0..16).for_each(|_| {
            let idx = (rng.next_u32() as usize) % (1 << log_n);
            let vertex = index_to_hypercube_point::<F>(log_n, idx);
            assert_eq!(evaluate(&buffer, &vertex), buffer.get(idx));
        });
    }

    /// The scalar-slice evaluator must match the packed in-place evaluator.
    #[test]
    fn test_evaluate_inplace_scalars_consistency() {
        let mut rng = StdRng::seed_from_u64(0);

        for log_n in [0, P::LOG_WIDTH - 1, P::LOG_WIDTH, 10] {
            let buffer = random_field_buffer::<P>(&mut rng, log_n);
            let point = random_scalars::<F>(&mut rng, log_n);

            let flat: Vec<_> = buffer.iter_scalars().collect();
            assert_eq!(
                evaluate_inplace(buffer.clone(), &point),
                evaluate_inplace_scalars(flat, &point),
                "mismatch at log_n={log_n}"
            );
        }
    }

    /// Varying one coordinate while the others stay fixed must trace an affine function, so
    /// any three samples are collinear.
    #[test]
    fn test_linearity() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);
        let mut point = random_scalars::<F>(&mut rng, log_n);

        for coord_idx in 0..log_n {
            let xs = random_scalars::<F>(&mut rng, 3);

            let mut ys = Vec::with_capacity(xs.len());
            for &x in &xs {
                point[coord_idx] = x;
                ys.push(evaluate(&buffer, &point));
            }

            // Collinearity without division: (y2 - y0)(x1 - x0) == (y1 - y0)(x2 - x0).
            assert_eq!((ys[2] - ys[0]) * (xs[1] - xs[0]), (ys[1] - ys[0]) * (xs[2] - xs[0]));
        }
    }
}