binius_math/multilinear/
evaluate.rs1use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::{
9 FieldBuffer,
10 inner_product::inner_product_buffers,
11 multilinear::{eq::eq_ind_partial_eval, fold::fold_highest_var_inplace},
12};
13
14pub fn evaluate<F, P, Data>(evals: &FieldBuffer<P, Data>, point: &[F]) -> F
34where
35 F: Field,
36 P: PackedField<Scalar = F>,
37 Data: Deref<Target = [P]>,
38{
39 assert_eq!(
40 point.len(),
41 evals.log_len(),
42 "precondition: point length must equal evals log length"
43 );
44
45 let first_half_len = (point.len() / 2).max(P::LOG_WIDTH).min(point.len());
47 let (first_coords, remaining_coords) = point.split_at(first_half_len);
48
49 let eq_tensor = eq_ind_partial_eval::<P>(first_coords);
51
52 if remaining_coords.is_empty() {
54 return inner_product_buffers(evals, &eq_tensor);
55 }
56
57 let log_chunk_size = first_half_len;
59
60 let scalars = evals
62 .chunks_par(log_chunk_size)
63 .map(|chunk| inner_product_buffers(&chunk, &eq_tensor))
64 .collect::<Vec<_>>();
65
66 let temp_buffer = FieldBuffer::<P>::from_values(&scalars);
68
69 evaluate_inplace(temp_buffer, remaining_coords)
71}
72
73pub fn evaluate_inplace<F, P, Data>(mut evals: FieldBuffer<P, Data>, coords: &[F]) -> F
91where
92 F: Field,
93 P: PackedField<Scalar = F>,
94 Data: DerefMut<Target = [P]>,
95{
96 assert_eq!(
97 coords.len(),
98 evals.log_len(),
99 "precondition: coords length must equal evals log length"
100 );
101
102 for &coord in coords.iter().rev() {
104 fold_highest_var_inplace(&mut evals, coord);
105 }
106
107 assert_eq!(evals.len(), 1);
108 evals.get(0)
109}
110
#[cfg(test)]
mod tests {
    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        inner_product::inner_product_par,
        test_utils::{
            B128, Packed128b, index_to_hypercube_point, random_field_buffer, random_scalars,
        },
    };

    type P = Packed128b;
    type F = B128;

    /// Reference evaluator: expands the full equality-indicator tensor for `point` and takes
    /// its inner product with `evals`.
    fn evaluate_with_inner_product<F, P, Data>(evals: &FieldBuffer<P, Data>, point: &[F]) -> F
    where
        F: Field,
        P: PackedField<Scalar = F>,
        Data: Deref<Target = [P]>,
    {
        assert_eq!(point.len(), evals.log_len());

        let eq_tensor = eq_ind_partial_eval::<P>(point);
        inner_product_par(evals, &eq_tensor)
    }

    /// Checks that `evaluate`, `evaluate_inplace` and the reference evaluator agree for buffers
    /// packed as `Packed`.
    ///
    /// The sizes cluster around `Packed::LOG_WIDTH`, which is where the split in `evaluate`
    /// switches between its branches.
    fn check_evaluate_consistency<Packed>(rng: &mut StdRng)
    where
        Packed: PackedField<Scalar = F>,
    {
        // `saturating_sub` keeps this valid for width-1 packings (`LOG_WIDTH == 0`), where
        // `LOG_WIDTH - 1` would underflow.
        let log_ns = [
            0,
            Packed::LOG_WIDTH.saturating_sub(1),
            Packed::LOG_WIDTH,
            Packed::LOG_WIDTH + 1,
            10,
        ];

        for log_n in log_ns {
            let buffer = random_field_buffer::<Packed>(&mut *rng, log_n);
            let point = random_scalars::<F>(&mut *rng, log_n);

            let result_inner_product = evaluate_with_inner_product(&buffer, &point);
            let result_inplace = evaluate_inplace(buffer.clone(), &point);
            let result_sqrt_memory = evaluate(&buffer, &point);

            assert_eq!(result_inner_product, result_inplace, "evaluate_inplace, log_n = {log_n}");
            assert_eq!(result_inner_product, result_sqrt_memory, "evaluate, log_n = {log_n}");
        }
    }

    #[test]
    fn test_evaluate_consistency() {
        let mut rng = StdRng::seed_from_u64(0);

        check_evaluate_consistency::<P>(&mut rng);
        // The scalar type is also usable as the buffer element type (see the tests below);
        // cover that unpacked layout as well.
        check_evaluate_consistency::<F>(&mut rng);
    }

    /// Evaluating at a boolean hypercube vertex must return the stored value at that index.
    #[test]
    fn test_evaluate_at_hypercube_indices() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);

        for _ in 0..16 {
            let index = (rng.next_u32() as usize) % (1 << log_n);
            let point = index_to_hypercube_point::<F>(log_n, index);

            let eval_result = evaluate(&buffer, &point);
            let direct_value = buffer.get(index);

            assert_eq!(eval_result, direct_value);
        }
    }

    /// A multilinear polynomial is affine in each coordinate.
    ///
    /// Varying one coordinate over three values `x0, x1, x2` (others fixed) must therefore give
    /// collinear points. This is checked without division as
    /// `(y2 - y0)(x1 - x0) == (y1 - y0)(x2 - x0)`.
    #[test]
    fn test_linearity() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);
        let mut point = random_scalars::<F>(&mut rng, log_n);

        for coord_idx in 0..log_n {
            let coord_vals = random_scalars::<F>(&mut rng, 3);

            let evals: Vec<_> = coord_vals
                .iter()
                .map(|&coord_val| {
                    point[coord_idx] = coord_val;
                    evaluate(&buffer, &point)
                })
                .collect();

            let (x0, x1, x2) = (coord_vals[0], coord_vals[1], coord_vals[2]);
            let (y0, y1, y2) = (evals[0], evals[1], evals[2]);

            let lhs = (y2 - y0) * (x1 - x0);
            let rhs = (y1 - y0) * (x2 - x0);

            assert_eq!(lhs, rhs);
        }
    }
}