1use binius_field::{packed::set_packed_slice, BinaryField128b, BinaryField1b, Field, PackedField};
4use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{bail, DeserializeBytes};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
/// Multilinear polynomial that is one at a single vertex of the boolean hypercube and zero at
/// every other vertex.
///
/// The selected vertex is `index`, read little-endian: bit `i` of `index` gives the value of
/// variable `i` (see `evaluate`, which pairs the lowest remaining bit with each successive query
/// coordinate).
///
/// [`SelectRow::new`] rejects `index >= 2^n_vars`.
///
/// NOTE(review): the derived `DeserializeBytes` presumably builds the struct directly, without
/// going through `new`. If so, deserialized values are not range-checked. Confirm.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct SelectRow {
	/// Number of variables of the polynomial.
	n_vars: usize,
	/// Hypercube vertex (row) at which the polynomial evaluates to one.
	index: usize,
}
27
// Registers a deserializer named "SelectRow" for `dyn MultivariatePoly<BinaryField128b>`.
// The closure decodes a `SelectRow` from `buf` and boxes it. Presumably this lets a serialized
// trait object be reconstructed by name; confirm against `register_deserializer`.
// Only the `BinaryField128b` instantiation is registered here.
inventory::submit! {
	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
		"SelectRow",
		|buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?))
	)
}
34
35impl SelectRow {
36 pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
37 if index >= (1 << n_vars) {
38 bail!(Error::ArgumentRangeError {
39 arg: "index".into(),
40 range: 0..(1 << n_vars),
41 })
42 }
43 Ok(Self { n_vars, index })
44 }
45
46 pub fn multilinear_extension<P: PackedField<Scalar = BinaryField1b>>(
47 &self,
48 ) -> Result<MultilinearExtension<P>, Error> {
49 if self.n_vars < P::LOG_WIDTH {
50 bail!(Error::PackedFieldNotFilled {
51 length: 1 << self.n_vars,
52 packed_width: 1 << P::LOG_WIDTH,
53 });
54 }
55 let mut result = vec![P::zero(); 1 << (self.n_vars - P::LOG_WIDTH)];
56 set_packed_slice(&mut result, self.index, P::Scalar::ONE);
57 Ok(MultilinearExtension::from_values(result)?)
58 }
59}
60
61#[erased_serialize_bytes]
62impl<F: Field> MultivariatePoly<F> for SelectRow {
63 fn degree(&self) -> usize {
64 self.n_vars
65 }
66
67 fn n_vars(&self) -> usize {
68 self.n_vars
69 }
70
71 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
72 let n_vars = MultivariatePoly::<F>::n_vars(self);
73 if query.len() != n_vars {
74 bail!(Error::IncorrectQuerySize { expected: n_vars });
75 }
76 let mut k = self.index;
77 let mut result = F::ONE;
78 for q in query {
79 if k & 1 == 1 {
80 result *= q;
82 } else {
83 result *= F::ONE - q;
85 }
86 k >>= 1;
87 }
88 Ok(result)
89 }
90
91 fn binary_tower_level(&self) -> usize {
92 0
93 }
94}
95
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, Field, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
	};

	use super::SelectRow;
	use crate::polynomial::test_utils::{hypercube_evals_from_oracle, packed_slice};

	/// `(n_vars, index)` pairs large enough to span several words at every packing width tested.
	const WIDE_CASES: [(usize, usize); 3] = [(9, 314), (10, 555), (11, 1)];

	/// Builds `SelectRow::new(n_vars, index)` and returns its evaluation table packed into `P`.
	fn select_row_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		let select_row = SelectRow::new(n_vars, index).unwrap();
		let extension = select_row.multilinear_extension::<P>().unwrap();
		extension.evals().to_vec()
	}

	#[test]
	fn test_select_row_evals_without_packing_simple_cases() {
		// Every row of the 2- and 3-variable hypercubes must produce a one-hot vector.
		let zero = <BinaryField1b as Field>::ZERO;
		let one = <BinaryField1b as Field>::ONE;
		for n_vars in 2..=3 {
			let n_rows = 1usize << n_vars;
			for index in 0..n_rows {
				let one_hot: Vec<BinaryField1b> = (0..n_rows)
					.map(|row| if row == index { one } else { zero })
					.collect();
				assert_eq!(select_row_evals::<BinaryField1b>(n_vars, index), one_hot);
			}
		}
	}

	#[test]
	fn test_select_row_evals_without_packing() {
		for &(n_vars, index) in &WIDE_CASES {
			let expected = packed_slice::<BinaryField1b>(&[
				(0..index, 0),
				(index..index + 1, 1),
				(index + 1..1 << n_vars, 0),
			]);
			assert_eq!(select_row_evals::<BinaryField1b>(n_vars, index), expected);
		}
	}

	#[test]
	fn test_select_row_evals_packing_128() {
		for &(n_vars, index) in &WIDE_CASES {
			let expected = packed_slice::<PackedBinaryField128x1b>(&[
				(0..index, 0),
				(index..index + 1, 1),
				(index + 1..1 << n_vars, 0),
			]);
			assert_eq!(select_row_evals::<PackedBinaryField128x1b>(n_vars, index), expected);
		}
	}

	#[test]
	fn test_select_row_evals_packing_256() {
		for &(n_vars, index) in &WIDE_CASES {
			let expected = packed_slice::<PackedBinaryField256x1b>(&[
				(0..index, 0),
				(index..index + 1, 1),
				(index + 1..1 << n_vars, 0),
			]);
			assert_eq!(select_row_evals::<PackedBinaryField256x1b>(n_vars, index), expected);
		}
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		// The oracle's pointwise evaluations must agree with the materialized table.
		for n_vars in 1..5 {
			for index in 0..1 << n_vars {
				let select_row = SelectRow::new(n_vars, index).unwrap();
				let from_oracle = hypercube_evals_from_oracle::<BinaryField1b>(&select_row);
				let extension = select_row.multilinear_extension::<BinaryField1b>().unwrap();
				assert_eq!(from_oracle, extension.evals());
			}
		}
	}
}