1use binius_utils::checked_arithmetics::log2_strict_usize;
4
5use super::{packed::PackedField, Field, PackedFieldIndexable, RepackedExtension};
6
7#[derive(Clone, thiserror::Error, Debug)]
9pub enum Error {
10 #[error("the \"{param}\" argument's size is invalid: {msg}")]
11 InvalidBufferSize { param: &'static str, msg: String },
12 #[error("dimension n of square blocks must divide packing width")]
13 SquareBlockDimensionMustDivideWidth,
14 #[error("destination buffer must be castable to a packed extension field buffer")]
15 UnalignedDestination,
16}
17
18pub fn square_transpose<P: PackedField>(log_n: usize, elems: &mut [P]) -> Result<(), Error> {
30 if P::LOG_WIDTH < log_n {
31 return Err(Error::SquareBlockDimensionMustDivideWidth);
32 }
33
34 let size = elems.len();
35 if !size.is_power_of_two() {
36 return Err(Error::InvalidBufferSize {
37 param: "elems",
38 msg: "power of two size required".to_string(),
39 });
40 }
41 let log_size = log2_strict_usize(size);
42 if log_size < log_n {
43 return Err(Error::InvalidBufferSize {
44 param: "elems",
45 msg: "must have length at least 2^log_n".to_string(),
46 });
47 }
48
49 let log_w = log_size - log_n;
50
51 for i in 0..log_n {
54 for j in 0..1 << (log_n - i - 1) {
55 for k in 0..1 << (log_w + i) {
56 let idx0 = (j << (log_w + i + 1)) | k;
57 let idx1 = idx0 | (1 << (log_w + i));
58
59 let v0 = elems[idx0];
60 let v1 = elems[idx1];
61 let (v0, v1) = v0.interleave(v1, i);
62 elems[idx0] = v0;
63 elems[idx1] = v1;
64 }
65 }
66 }
67
68 Ok(())
69}
70
71pub fn transpose_scalars<P, FE, PE>(src: &[PE], dst: &mut [P]) -> Result<(), Error>
77where
78 P: PackedField,
79 FE: Field,
80 PE: PackedFieldIndexable<Scalar = FE> + RepackedExtension<P>,
81{
82 let len = src.len();
83 if !len.is_power_of_two() {
84 return Err(Error::InvalidBufferSize {
85 param: "elems",
86 msg: "power of two size required".to_string(),
87 });
88 }
89 if dst.len() != len {
90 return Err(Error::InvalidBufferSize {
91 param: "dst",
92 msg: "must have equal length to src buffer".to_string(),
93 });
94 }
95
96 let log_d = FE::LOG_DEGREE;
97 let log_n = log2_strict_usize(src.len()) + PE::LOG_WIDTH;
98
99 if log_n < log_d {
100 return Err(Error::InvalidBufferSize {
101 param: "src",
102 msg: "must have length at least 2^{d - w} where d is the extension degree and w is \
103 the extension packing width"
104 .to_string(),
105 });
106 }
107
108 {
109 let dst_ext = PE::cast_exts_mut(dst);
110 transpose::transpose(
111 PE::unpack_scalars(src),
112 PE::unpack_scalars_mut(dst_ext),
113 1 << log_d,
114 1 << (log_n - log_d),
115 );
116 }
117 square_transpose(log_d, dst)
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        BinaryField32b, PackedBinaryField128x1b, PackedBinaryField16x8b, PackedBinaryField4x32b,
        PackedBinaryField64x2b, PackedExtension,
    };

    /// Returns four all-zero packed elements followed by four all-one packed elements.
    fn zeros_then_ones<P: From<u128> + Copy>() -> [P; 8] {
        let zero = P::from(0u128);
        let ones = P::from(u128::MAX);
        [zero, zero, zero, zero, ones, ones, ones, ones]
    }

    #[test]
    fn test_square_transpose_128x1b() {
        // With log_n = 3 and eight elements, every row is a single packed element.
        let mut elems = zeros_then_ones::<PackedBinaryField128x1b>();
        square_transpose(3, &mut elems).unwrap();

        let expected =
            [PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128); 8];
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_128x1b_multi_row() {
        // With log_n = 1 the eight elements form two rows of four packed elements each.
        let mut elems = zeros_then_ones::<PackedBinaryField128x1b>();
        square_transpose(1, &mut elems).unwrap();

        let expected =
            [PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128); 8];
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_64x2b() {
        let mut elems = zeros_then_ones::<PackedBinaryField64x2b>();
        square_transpose(3, &mut elems).unwrap();

        let expected =
            [PackedBinaryField64x2b::from(0xff00ff00ff00ff00ff00ff00ff00ff00u128); 8];
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_log_n_zero_is_noop() {
        let original = zeros_then_ones::<PackedBinaryField64x2b>();
        let mut elems = original;
        square_transpose(0, &mut elems).unwrap();
        assert_eq!(elems, original);
    }

    #[test]
    fn test_square_transpose_rejects_block_wider_than_packing() {
        // 16x8b packs 2^4 scalars, so a 2^5 x 2^5 block cannot fit in one element.
        let mut elems = [PackedBinaryField16x8b::default(); 32];
        assert!(matches!(
            square_transpose(5, &mut elems),
            Err(Error::SquareBlockDimensionMustDivideWidth)
        ));
    }

    #[test]
    fn test_square_transpose_rejects_non_power_of_two_length() {
        let mut elems = [PackedBinaryField128x1b::default(); 6];
        assert!(matches!(
            square_transpose(1, &mut elems),
            Err(Error::InvalidBufferSize { param: "elems", .. })
        ));
    }

    #[test]
    fn test_square_transpose_rejects_too_few_rows() {
        let mut elems = [PackedBinaryField128x1b::default(); 4];
        assert!(matches!(
            square_transpose(3, &mut elems),
            Err(Error::InvalidBufferSize { param: "elems", .. })
        ));
    }

    #[test]
    #[rustfmt::skip]
    fn test_transpose_scalars() {
        let elems = [
            [
                0x03020100,
                0x07060504,
                0x0b0a0908,
                0x0f0e0d0c,
            ],
            [
                0x13121110,
                0x17161514,
                0x1b1a1918,
                0x1f1e1d1c,
            ],
            [
                0x23222120,
                0x27262524,
                0x2b2a2928,
                0x2f2e2d2c,
            ],
            [
                0x33323130,
                0x37363534,
                0x3b3a3938,
                0x3f3e3d3c,
            ],
            [
                0x43424140,
                0x47464544,
                0x4b4a4948,
                0x4f4e4d4c,
            ],
            [
                0x53525150,
                0x57565554,
                0x5b5a5958,
                0x5f5e5d5c,
            ],
            [
                0x63626160,
                0x67666564,
                0x6b6a6968,
                0x6f6e6d6c,
            ],
            [
                0x73727170,
                0x77767574,
                0x7b7a7978,
                0x7f7e7d7c,
            ],
        ].map(|vals| PackedBinaryField4x32b::from_scalars(vals.map(BinaryField32b::new)));

        let expected = [
            [0x0c080400, 0x1c181410, 0x2c282420, 0x3c383430],
            [0x4c484440, 0x5c585450, 0x6c686460, 0x7c787470],

            [0x0d090501, 0x1d191511, 0x2d292521, 0x3d393531],
            [0x4d494541, 0x5d595551, 0x6d696561, 0x7d797571],

            [0x0e0a0602, 0x1e1a1612, 0x2e2a2622, 0x3e3a3632],
            [0x4e4a4642, 0x5e5a5652, 0x6e6a6662, 0x7e7a7672],

            [0x0f0b0703, 0x1f1b1713, 0x2f2b2723, 0x3f3b3733],
            [0x4f4b4743, 0x5f5b5753, 0x6f6b6763, 0x7f7b7773],
        ].map(|vals| PackedBinaryField4x32b::from_scalars(vals.map(BinaryField32b::new)));

        let mut dst = [PackedBinaryField4x32b::default(); 8];
        transpose_scalars::<PackedBinaryField16x8b,_,_>(&elems, PackedBinaryField4x32b::cast_bases_mut(&mut dst)).unwrap();
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_transpose_scalars_rejects_mismatched_dst() {
        let src = [PackedBinaryField4x32b::default(); 8];
        let mut dst = [PackedBinaryField16x8b::default(); 4];
        assert!(matches!(
            transpose_scalars::<PackedBinaryField16x8b, _, _>(&src, &mut dst),
            Err(Error::InvalidBufferSize { param: "dst", .. })
        ));
    }

    #[test]
    fn test_transpose_scalars_rejects_non_power_of_two_src() {
        let src = [PackedBinaryField4x32b::default(); 3];
        let mut dst = [PackedBinaryField16x8b::default(); 3];
        assert!(matches!(
            transpose_scalars::<PackedBinaryField16x8b, _, _>(&src, &mut dst),
            Err(Error::InvalidBufferSize { .. })
        ));
    }
}