1use binius_field::{BinaryField1b, BinaryField128b, Field, PackedField, packed::set_packed_slice};
4use binius_macros::{DeserializeBytes, SerializeBytes, erased_serialize_bytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{DeserializeBytes, bail};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
/// A multilinear polynomial over `n_vars` variables that selects a single hypercube vertex.
///
/// On the boolean hypercube it evaluates to 1 at the vertex whose coordinates are the bits of
/// `index` (least-significant bit first: bit `i` is the coordinate of variable `i`) and to 0 at
/// every other vertex. At an arbitrary point `q` it evaluates to
/// `prod_i (q_i if bit i of index is set, else 1 - q_i)`.
///
/// Instances built with [`SelectRow::new`] always satisfy `index < 2^n_vars`.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct SelectRow {
	// NOTE(review): field order presumably determines the derived byte encoding; do not reorder.
	/// Number of variables of the polynomial.
	n_vars: usize,
	/// Hypercube vertex, read as a little-endian bit string, at which the polynomial is 1.
	index: usize,
}
27
28inventory::submit! {
29 <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
30 "SelectRow",
31 |buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?))
32 )
33}
34
35impl SelectRow {
36 pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
37 if index >= (1 << n_vars) {
38 bail!(Error::ArgumentRangeError {
39 arg: "index".into(),
40 range: 0..(1 << n_vars),
41 })
42 }
43 Ok(Self { n_vars, index })
44 }
45
46 pub fn multilinear_extension<P: PackedField<Scalar = BinaryField1b>>(
47 &self,
48 ) -> Result<MultilinearExtension<P>, Error> {
49 if self.n_vars < P::LOG_WIDTH {
50 bail!(Error::PackedFieldNotFilled {
51 length: 1 << self.n_vars,
52 packed_width: 1 << P::LOG_WIDTH,
53 });
54 }
55 let mut result = vec![P::zero(); 1 << (self.n_vars - P::LOG_WIDTH)];
56 set_packed_slice(&mut result, self.index, P::Scalar::ONE);
57 Ok(MultilinearExtension::from_values(result)?)
58 }
59}
60
61#[erased_serialize_bytes]
62impl<F: Field> MultivariatePoly<F> for SelectRow {
63 fn degree(&self) -> usize {
64 self.n_vars
65 }
66
67 fn n_vars(&self) -> usize {
68 self.n_vars
69 }
70
71 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
72 let n_vars = MultivariatePoly::<F>::n_vars(self);
73 if query.len() != n_vars {
74 bail!(Error::IncorrectQuerySize {
75 expected: n_vars,
76 actual: query.len()
77 });
78 }
79 let mut k = self.index;
80 let mut result = F::ONE;
81 for q in query {
82 if k & 1 == 1 {
83 result *= q;
85 } else {
86 result *= F::ONE - q;
88 }
89 k >>= 1;
90 }
91 Ok(result)
92 }
93
94 fn binary_tower_level(&self) -> usize {
95 0
96 }
97}
98
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
	};
	use binius_utils::felts;

	use super::SelectRow;
	use crate::polynomial::{
		Error,
		test_utils::{hypercube_evals_from_oracle, packed_slice},
	};

	#[test]
	fn test_select_row_evals_without_packing_simple_cases() {
		assert_eq!(select_row_evals::<BinaryField1b>(2, 0), felts!(BinaryField1b[1, 0, 0, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 1), felts!(BinaryField1b[0, 1, 0, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 2), felts!(BinaryField1b[0, 0, 1, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 3), felts!(BinaryField1b[0, 0, 0, 1]));
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 0),
			felts!(BinaryField1b[1, 0, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 1),
			felts!(BinaryField1b[0, 1, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 2),
			felts!(BinaryField1b[0, 0, 1, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 3),
			felts!(BinaryField1b[0, 0, 0, 1, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 4),
			felts!(BinaryField1b[0, 0, 0, 0, 1, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 5),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 1, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 6),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 1, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 7),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 0, 1])
		);
	}

	#[test]
	fn test_select_row_evals_without_packing() {
		assert_eq!(
			select_row_evals::<BinaryField1b>(9, 314),
			packed_slice::<BinaryField1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(10, 555),
			packed_slice::<BinaryField1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(11, 1),
			packed_slice::<BinaryField1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_select_row_evals_packing_128() {
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(9, 314),
			packed_slice::<PackedBinaryField128x1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(10, 555),
			packed_slice::<PackedBinaryField128x1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(11, 1),
			packed_slice::<PackedBinaryField128x1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_select_row_evals_packing_256() {
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(9, 314),
			packed_slice::<PackedBinaryField256x1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(10, 555),
			packed_slice::<PackedBinaryField256x1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(11, 1),
			packed_slice::<PackedBinaryField256x1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		for n_vars in 1..5 {
			for index in 0..(1 << n_vars) {
				let select_row = SelectRow::new(n_vars, index).unwrap();
				assert_eq!(
					hypercube_evals_from_oracle::<BinaryField1b>(&select_row),
					select_row
						.multilinear_extension::<BinaryField1b>()
						.unwrap()
						.evals()
				);
			}
		}
	}

	#[test]
	fn test_new_rejects_index_outside_hypercube() {
		// The 0-variable hypercube has exactly one vertex, index 0.
		assert!(SelectRow::new(0, 0).is_ok());
		assert!(matches!(SelectRow::new(0, 1), Err(Error::ArgumentRangeError { .. })));
		// Boundary of the 3-variable hypercube: 7 is the last vertex, 8 is out of range.
		assert!(SelectRow::new(3, 7).is_ok());
		assert!(matches!(SelectRow::new(3, 8), Err(Error::ArgumentRangeError { .. })));
		assert!(matches!(SelectRow::new(3, usize::MAX), Err(Error::ArgumentRangeError { .. })));
	}

	#[test]
	fn test_multilinear_extension_rejects_underfilled_packing() {
		// 2^3 = 8 evaluations cannot fill a single 128-wide packed element.
		let select_row = SelectRow::new(3, 5).unwrap();
		assert!(matches!(
			select_row.multilinear_extension::<PackedBinaryField128x1b>(),
			Err(Error::PackedFieldNotFilled { .. })
		));
	}

	fn select_row_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		SelectRow::new(n_vars, index)
			.unwrap()
			.multilinear_extension::<P>()
			.unwrap()
			.evals()
			.to_vec()
	}
}