binius_field/
transpose.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use binius_utils::checked_arithmetics::log2_strict_usize;
4
5use super::{packed::PackedField, Field, PackedFieldIndexable, RepackedExtension};
6
7/// Error thrown when a transpose operation fails.
8#[derive(Clone, thiserror::Error, Debug)]
9pub enum Error {
10	#[error("the \"{param}\" argument's size is invalid: {msg}")]
11	InvalidBufferSize { param: &'static str, msg: String },
12	#[error("dimension n of square blocks must divide packing width")]
13	SquareBlockDimensionMustDivideWidth,
14	#[error("destination buffer must be castable to a packed extension field buffer")]
15	UnalignedDestination,
16}
17
18/// Transpose square blocks of elements within packed field elements in place.
19///
20/// The input elements are interpreted as a rectangular matrix with height `n = 2^n` in row-major
21/// order. This matrix is interpreted as a vector of square matrices of field elements, and each
22/// square matrix is transposed in-place.
23///
24/// # Arguments
25///
26/// * `log_n`: The base-2 logarithm of the dimension of the n x n square matrix. Must be less than
27///   or equal to the base-2 logarithm of the packing width.
28/// * `elems`: The packed field elements, length is a power-of-two multiple of `1 << log_n`.
29pub fn square_transpose<P: PackedField>(log_n: usize, elems: &mut [P]) -> Result<(), Error> {
30	if P::LOG_WIDTH < log_n {
31		return Err(Error::SquareBlockDimensionMustDivideWidth);
32	}
33
34	let size = elems.len();
35	if !size.is_power_of_two() {
36		return Err(Error::InvalidBufferSize {
37			param: "elems",
38			msg: "power of two size required".to_string(),
39		});
40	}
41	let log_size = log2_strict_usize(size);
42	if log_size < log_n {
43		return Err(Error::InvalidBufferSize {
44			param: "elems",
45			msg: "must have length at least 2^log_n".to_string(),
46		});
47	}
48
49	let log_w = log_size - log_n;
50
51	// See Hacker's Delight, Section 7-3.
52	// https://dl.acm.org/doi/10.5555/2462741
53	for i in 0..log_n {
54		for j in 0..1 << (log_n - i - 1) {
55			for k in 0..1 << (log_w + i) {
56				let idx0 = (j << (log_w + i + 1)) | k;
57				let idx1 = idx0 | (1 << (log_w + i));
58
59				let v0 = elems[idx0];
60				let v1 = elems[idx1];
61				let (v0, v1) = v0.interleave(v1, i);
62				elems[idx0] = v0;
63				elems[idx1] = v1;
64			}
65		}
66	}
67
68	Ok(())
69}
70
71/// Transpose the scalars within a slice of packed extension field elements.
72///
73/// The `src` buffer is vector of `n` field extension field elements, or alternatively viewed as an
74/// n x d matrix of base field elements, where `d` is the extension degree. This transposes the
75/// base field elements into a d x n matrix in row-major order.
76pub fn transpose_scalars<P, FE, PE>(src: &[PE], dst: &mut [P]) -> Result<(), Error>
77where
78	P: PackedField,
79	FE: Field,
80	PE: PackedFieldIndexable<Scalar = FE> + RepackedExtension<P>,
81{
82	let len = src.len();
83	if !len.is_power_of_two() {
84		return Err(Error::InvalidBufferSize {
85			param: "elems",
86			msg: "power of two size required".to_string(),
87		});
88	}
89	if dst.len() != len {
90		return Err(Error::InvalidBufferSize {
91			param: "dst",
92			msg: "must have equal length to src buffer".to_string(),
93		});
94	}
95
96	let log_d = FE::LOG_DEGREE;
97	let log_n = log2_strict_usize(src.len()) + PE::LOG_WIDTH;
98
99	if log_n < log_d {
100		return Err(Error::InvalidBufferSize {
101			param: "src",
102			msg: "must have length at least 2^{d - w} where d is the extension degree and w is \
103			the extension packing width"
104				.to_string(),
105		});
106	}
107
108	{
109		let dst_ext = PE::cast_exts_mut(dst);
110		transpose::transpose(
111			PE::unpack_scalars(src),
112			PE::unpack_scalars_mut(dst_ext),
113			1 << log_d,
114			1 << (log_n - log_d),
115		);
116	}
117	square_transpose(log_d, dst)
118}
119
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{
		BinaryField32b, PackedBinaryField128x1b, PackedBinaryField16x8b, PackedBinaryField4x32b,
		PackedBinaryField64x2b, PackedExtension,
	};

	/// Eight packed values: the first four have every bit cleared, the last four every bit set.
	fn zeros_then_ones<P: From<u128>>() -> [P; 8] {
		core::array::from_fn(|row| P::from(if row < 4 { 0 } else { u128::MAX }))
	}

	/// Packs eight rows of four raw 32-bit values into `PackedBinaryField4x32b` elements.
	fn pack_4x32b(rows: [[u32; 4]; 8]) -> [PackedBinaryField4x32b; 8] {
		rows.map(|row| PackedBinaryField4x32b::from_scalars(row.map(BinaryField32b::new)))
	}

	#[test]
	fn test_square_transpose_128x1b() {
		let mut elems: [PackedBinaryField128x1b; 8] = zeros_then_ones();
		square_transpose(3, &mut elems).unwrap();

		let expected =
			[PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128); 8];
		assert_eq!(elems, expected);
	}

	#[test]
	fn test_square_transpose_128x1b_multi_row() {
		let mut elems: [PackedBinaryField128x1b; 8] = zeros_then_ones();
		square_transpose(1, &mut elems).unwrap();

		let expected =
			[PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128); 8];
		assert_eq!(elems, expected);
	}

	#[test]
	fn test_square_transpose_64x2b() {
		let mut elems: [PackedBinaryField64x2b; 8] = zeros_then_ones();
		square_transpose(3, &mut elems).unwrap();

		let expected =
			[PackedBinaryField64x2b::from(0xff00ff00ff00ff00ff00ff00ff00ff00u128); 8];
		assert_eq!(elems, expected);
	}

	#[test]
	#[rustfmt::skip]
	fn test_transpose_scalars() {
		let elems = pack_4x32b([
			[0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c],
			[0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c],
			[0x23222120, 0x27262524, 0x2b2a2928, 0x2f2e2d2c],
			[0x33323130, 0x37363534, 0x3b3a3938, 0x3f3e3d3c],
			[0x43424140, 0x47464544, 0x4b4a4948, 0x4f4e4d4c],
			[0x53525150, 0x57565554, 0x5b5a5958, 0x5f5e5d5c],
			[0x63626160, 0x67666564, 0x6b6a6968, 0x6f6e6d6c],
			[0x73727170, 0x77767574, 0x7b7a7978, 0x7f7e7d7c],
		]);

		let expected = pack_4x32b([
			[0x0c080400, 0x1c181410, 0x2c282420, 0x3c383430],
			[0x4c484440, 0x5c585450, 0x6c686460, 0x7c787470],

			[0x0d090501, 0x1d191511, 0x2d292521, 0x3d393531],
			[0x4d494541, 0x5d595551, 0x6d696561, 0x7d797571],

			[0x0e0a0602, 0x1e1a1612, 0x2e2a2622, 0x3e3a3632],
			[0x4e4a4642, 0x5e5a5652, 0x6e6a6662, 0x7e7a7672],

			[0x0f0b0703, 0x1f1b1713, 0x2f2b2723, 0x3f3b3733],
			[0x4f4b4743, 0x5f5b5753, 0x6f6b6763, 0x7f7b7773],
		]);

		let mut dst = [PackedBinaryField4x32b::default(); 8];
		transpose_scalars::<PackedBinaryField16x8b, _, _>(
			&elems,
			PackedBinaryField4x32b::cast_bases_mut(&mut dst),
		)
		.unwrap();
		assert_eq!(dst, expected);
	}
}