binius_math/multilinear/
evaluate.rs1use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::rayon::prelude::*;
7
8use crate::{
9 Error, FieldBuffer,
10 inner_product::inner_product_buffers,
11 multilinear::{eq::eq_ind_partial_eval, fold::fold_highest_var_inplace},
12};
13
14pub fn evaluate<F, P, Data>(evals: &FieldBuffer<P, Data>, point: &[F]) -> Result<F, Error>
30where
31 F: Field,
32 P: PackedField<Scalar = F>,
33 Data: Deref<Target = [P]>,
34{
35 if point.len() != evals.log_len() {
36 return Err(Error::IncorrectArgumentLength {
37 arg: "point".to_string(),
38 expected: evals.log_len(),
39 });
40 }
41
42 let first_half_len = (point.len() / 2).max(P::LOG_WIDTH).min(point.len());
44 let (first_coords, remaining_coords) = point.split_at(first_half_len);
45
46 let eq_tensor = eq_ind_partial_eval::<P>(first_coords);
48
49 if remaining_coords.is_empty() {
51 return Ok(inner_product_buffers(evals, &eq_tensor));
52 }
53
54 let log_chunk_size = first_half_len;
56
57 let scalars = evals
59 .chunks_par(log_chunk_size)?
60 .map(|chunk| inner_product_buffers(&chunk, &eq_tensor))
61 .collect::<Vec<_>>();
62
63 let temp_buffer = FieldBuffer::<P>::from_values(&scalars)?;
65
66 evaluate_inplace(temp_buffer, remaining_coords)
68}
69
70pub fn evaluate_inplace<F, P, Data>(
84 mut evals: FieldBuffer<P, Data>,
85 coords: &[F],
86) -> Result<F, Error>
87where
88 F: Field,
89 P: PackedField<Scalar = F>,
90 Data: DerefMut<Target = [P]>,
91{
92 if coords.len() != evals.log_len() {
93 return Err(Error::IncorrectArgumentLength {
94 arg: "coords".to_string(),
95 expected: evals.log_len(),
96 });
97 }
98
99 for &coord in coords.iter().rev() {
101 fold_highest_var_inplace(&mut evals, coord)?;
102 }
103
104 assert_eq!(evals.len(), 1);
105 Ok(evals.get(0).expect("evals.len() == 1"))
106}
107
#[cfg(test)]
mod tests {
    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        inner_product::inner_product_par,
        test_utils::{
            B128, Packed128b, index_to_hypercube_point, random_field_buffer, random_scalars,
        },
    };

    type P = Packed128b;
    type F = B128;

    /// The reference inner-product evaluation, `evaluate_inplace`, and `evaluate` must agree.
    #[test]
    fn test_evaluate_consistency() {
        /// Reference evaluation: contract `evals` against the full eq tensor of `point`.
        fn evaluate_with_inner_product<F, P, Data>(
            evals: &FieldBuffer<P, Data>,
            point: &[F],
        ) -> Result<F, Error>
        where
            F: Field,
            P: PackedField<Scalar = F>,
            Data: Deref<Target = [P]>,
        {
            if point.len() != evals.log_len() {
                return Err(Error::IncorrectArgumentLength {
                    arg: "point".to_string(),
                    expected: evals.log_len(),
                });
            }

            let eq_tensor = eq_ind_partial_eval::<P>(point);
            Ok(inner_product_par(evals, &eq_tensor))
        }

        let mut rng = StdRng::seed_from_u64(0);

        // Sizes below and at the packing width hit the single-contraction path in `evaluate`.
        // For small packing widths, 10 splits the point evenly between the eq-tensor block and
        // the folded block, while 11 splits it unevenly.
        for log_n in [0, P::LOG_WIDTH - 1, P::LOG_WIDTH, 10, 11] {
            let buffer = random_field_buffer::<P>(&mut rng, log_n);
            let point = random_scalars::<F>(&mut rng, log_n);

            let result_inner_product = evaluate_with_inner_product(&buffer, &point).unwrap();
            let result_inplace = evaluate_inplace(buffer.clone(), &point).unwrap();
            let result_sqrt_memory = evaluate(&buffer, &point).unwrap();

            assert_eq!(result_inner_product, result_inplace);
            assert_eq!(result_inner_product, result_sqrt_memory);
        }
    }

    /// Both entry points reject a point whose length does not match the number of variables.
    #[test]
    fn test_evaluate_rejects_wrong_point_length() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 4;
        let buffer = random_field_buffer::<P>(&mut rng, log_n);

        for wrong_len in [log_n - 1, log_n + 1] {
            let point = random_scalars::<F>(&mut rng, wrong_len);

            assert!(matches!(
                evaluate(&buffer, &point),
                Err(Error::IncorrectArgumentLength { expected, .. }) if expected == log_n
            ));
            assert!(matches!(
                evaluate_inplace(buffer.clone(), &point),
                Err(Error::IncorrectArgumentLength { expected, .. }) if expected == log_n
            ));
        }
    }

    /// Evaluating at a hypercube vertex returns the stored value at that index.
    #[test]
    fn test_evaluate_at_hypercube_indices() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);

        for _ in 0..16 {
            let index = (rng.next_u32() as usize) % (1 << log_n);
            let point = index_to_hypercube_point::<F>(log_n, index);

            let eval_result = evaluate(&buffer, &point).unwrap();
            let direct_value = buffer.get(index).unwrap();

            assert_eq!(eval_result, direct_value);
        }
    }

    /// A multilinear polynomial is affine in each coordinate. Varying one coordinate therefore
    /// yields collinear (coordinate, evaluation) pairs.
    #[test]
    fn test_linearity() {
        let mut rng = StdRng::seed_from_u64(0);

        let log_n = 8;
        let buffer = random_field_buffer::<F>(&mut rng, log_n);
        let mut point = random_scalars::<F>(&mut rng, log_n);

        for coord_idx in 0..log_n {
            let coord_vals = random_scalars::<F>(&mut rng, 3);

            let evals: Vec<_> = coord_vals
                .iter()
                .map(|&coord_val| {
                    point[coord_idx] = coord_val;
                    evaluate(&buffer, &point).unwrap()
                })
                .collect();

            let (x0, x1, x2) = (coord_vals[0], coord_vals[1], coord_vals[2]);
            let (y0, y1, y2) = (evals[0], evals[1], evals[2]);

            // Collinearity check without division: (y2 - y0)/(x2 - x0) == (y1 - y0)/(x1 - x0).
            let lhs = (y2 - y0) * (x1 - x0);
            let rhs = (y1 - y0) * (x2 - x0);

            assert_eq!(lhs, rhs);
        }
    }
}