binius_core/transparent/select_row.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{packed::set_packed_slice, BinaryField128b, BinaryField1b, Field, PackedField};
4use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{bail, DeserializeBytes};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10/// Represents a multilinear F2-polynomial whose evaluations over the hypercube is 1 at
11/// a specific hypercube index, and 0 everywhere else.
12///
13/// ```txt
14///     (1 << n_vars)
15/// <-------------------->
16/// 0,0 .. 0,0,1,0, .. 0,0
17///            ^
18///            index of 1
19/// ```
20///
21/// This is useful for defining boundary constraints
22#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
23pub struct SelectRow {
24	n_vars: usize,
25	index: usize,
26}
27
28inventory::submit! {
29	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
30		"SelectRow",
31		|buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?))
32	)
33}
34
35impl SelectRow {
36	pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
37		if index >= (1 << n_vars) {
38			bail!(Error::ArgumentRangeError {
39				arg: "index".into(),
40				range: 0..(1 << n_vars),
41			})
42		}
43		Ok(Self { n_vars, index })
44	}
45
46	pub fn multilinear_extension<P: PackedField<Scalar = BinaryField1b>>(
47		&self,
48	) -> Result<MultilinearExtension<P>, Error> {
49		if self.n_vars < P::LOG_WIDTH {
50			bail!(Error::PackedFieldNotFilled {
51				length: 1 << self.n_vars,
52				packed_width: 1 << P::LOG_WIDTH,
53			});
54		}
55		let mut result = vec![P::zero(); 1 << (self.n_vars - P::LOG_WIDTH)];
56		set_packed_slice(&mut result, self.index, P::Scalar::ONE);
57		Ok(MultilinearExtension::from_values(result)?)
58	}
59}
60
61#[erased_serialize_bytes]
62impl<F: Field> MultivariatePoly<F> for SelectRow {
63	fn degree(&self) -> usize {
64		self.n_vars
65	}
66
67	fn n_vars(&self) -> usize {
68		self.n_vars
69	}
70
71	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
72		let n_vars = MultivariatePoly::<F>::n_vars(self);
73		if query.len() != n_vars {
74			bail!(Error::IncorrectQuerySize { expected: n_vars });
75		}
76		let mut k = self.index;
77		let mut result = F::ONE;
78		for q in query {
79			if k & 1 == 1 {
80				// interpolate a line that is 0 at 0 and `result` at 1, at the point q
81				result *= q;
82			} else {
83				// interpolate a line that is `result` at 0 and 0 at 1, at the point q
84				result *= F::ONE - q;
85			}
86			k >>= 1;
87		}
88		Ok(result)
89	}
90
91	fn binary_tower_level(&self) -> usize {
92		0
93	}
94}
95
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField128b, BinaryField1b, Field, PackedBinaryField128x1b, PackedBinaryField256x1b,
		PackedField,
	};
	use binius_utils::felts;

	use super::SelectRow;
	use crate::polynomial::{
		test_utils::{hypercube_evals_from_oracle, packed_slice},
		MultivariatePoly,
	};

	#[test]
	fn test_select_row_evals_without_packing_simple_cases() {
		assert_eq!(select_row_evals::<BinaryField1b>(2, 0), felts!(BinaryField1b[1, 0, 0, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 1), felts!(BinaryField1b[0, 1, 0, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 2), felts!(BinaryField1b[0, 0, 1, 0]));
		assert_eq!(select_row_evals::<BinaryField1b>(2, 3), felts!(BinaryField1b[0, 0, 0, 1]));
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 0),
			felts!(BinaryField1b[1, 0, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 1),
			felts!(BinaryField1b[0, 1, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 2),
			felts!(BinaryField1b[0, 0, 1, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 3),
			felts!(BinaryField1b[0, 0, 0, 1, 0, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 4),
			felts!(BinaryField1b[0, 0, 0, 0, 1, 0, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 5),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 1, 0, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 6),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 1, 0])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(3, 7),
			felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 0, 1])
		);
	}

	#[test]
	fn test_select_row_evals_without_packing() {
		assert_eq!(
			select_row_evals::<BinaryField1b>(9, 314),
			packed_slice::<BinaryField1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(10, 555),
			packed_slice::<BinaryField1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<BinaryField1b>(11, 1),
			packed_slice::<BinaryField1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_select_row_evals_packing_128() {
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(9, 314),
			packed_slice::<PackedBinaryField128x1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(10, 555),
			packed_slice::<PackedBinaryField128x1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField128x1b>(11, 1),
			packed_slice::<PackedBinaryField128x1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_select_row_evals_packing_256() {
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(9, 314),
			packed_slice::<PackedBinaryField256x1b>(&[(0..314, 0), (314..315, 1), (315..512, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(10, 555),
			packed_slice::<PackedBinaryField256x1b>(&[(0..555, 0), (555..556, 1), (556..1024, 0)])
		);
		assert_eq!(
			select_row_evals::<PackedBinaryField256x1b>(11, 1),
			packed_slice::<PackedBinaryField256x1b>(&[(0..1, 0), (1..2, 1), (2..2048, 0)])
		);
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		for n_vars in 1..5 {
			for index in 0..(1 << n_vars) {
				let select_row = SelectRow::new(n_vars, index).unwrap();
				assert_eq!(
					hypercube_evals_from_oracle::<BinaryField1b>(&select_row),
					select_row
						.multilinear_extension::<BinaryField1b>()
						.unwrap()
						.evals()
				);
			}
		}
	}

	#[test]
	fn test_new_rejects_out_of_range_index() {
		// With zero variables the hypercube has a single point, index 0.
		assert!(SelectRow::new(0, 0).is_ok());
		assert!(SelectRow::new(0, 1).is_err());
		// The last valid index is `2^n_vars - 1`.
		assert!(SelectRow::new(2, 3).is_ok());
		assert!(SelectRow::new(2, 4).is_err());
		assert!(SelectRow::new(3, usize::MAX).is_err());
	}

	#[test]
	fn test_multilinear_extension_rejects_underfilled_packing() {
		// 2^3 evaluations cannot fill a single 128-wide packed element.
		let select_row = SelectRow::new(3, 5).unwrap();
		assert!(select_row
			.multilinear_extension::<PackedBinaryField128x1b>()
			.is_err());
	}

	#[test]
	fn test_evaluate_checks_query_size() {
		let select_row = SelectRow::new(2, 1).unwrap();
		assert!(MultivariatePoly::<BinaryField128b>::evaluate(&select_row, &[]).is_err());
		let too_long = [BinaryField128b::ONE; 3];
		assert!(MultivariatePoly::<BinaryField128b>::evaluate(&select_row, &too_long).is_err());
	}

	#[test]
	fn test_evaluate_zero_variables_is_one() {
		// The empty product is one: a 0-variable selector is the constant 1.
		let select_row = SelectRow::new(0, 0).unwrap();
		assert_eq!(
			MultivariatePoly::<BinaryField128b>::evaluate(&select_row, &[]).unwrap(),
			BinaryField128b::ONE
		);
	}

	/// Builds a `SelectRow` and returns its packed hypercube evaluations.
	fn select_row_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		SelectRow::new(n_vars, index)
			.unwrap()
			.multilinear_extension::<P>()
			.unwrap()
			.evals()
			.to_vec()
	}
}