binius_field/
transpose.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use binius_utils::checked_arithmetics::log2_strict_usize;
4
5use super::packed::PackedField;
6
7/// Error thrown when a transpose operation fails.
8#[derive(Clone, thiserror::Error, Debug)]
9pub enum Error {
10	#[error("the \"{param}\" argument's size is invalid: {msg}")]
11	InvalidBufferSize { param: &'static str, msg: String },
12	#[error("dimension n of square blocks must divide packing width")]
13	SquareBlockDimensionMustDivideWidth,
14	#[error("destination buffer must be castable to a packed extension field buffer")]
15	UnalignedDestination,
16}
17
18/// Transpose square blocks of elements within packed field elements in place.
19///
20/// The input elements are interpreted as a rectangular matrix with height `n = 2^n` in row-major
21/// order. This matrix is interpreted as a vector of square matrices of field elements, and each
22/// square matrix is transposed in-place.
23///
24/// # Arguments
25///
26/// * `log_n`: The base-2 logarithm of the dimension of the n x n square matrix. Must be less than
27///   or equal to the base-2 logarithm of the packing width.
28/// * `elems`: The packed field elements, length is a power-of-two multiple of `1 << log_n`.
29pub fn square_transpose<P: PackedField>(log_n: usize, elems: &mut [P]) -> Result<(), Error> {
30	if P::LOG_WIDTH < log_n {
31		return Err(Error::SquareBlockDimensionMustDivideWidth);
32	}
33
34	let size = elems.len();
35	if !size.is_power_of_two() {
36		return Err(Error::InvalidBufferSize {
37			param: "elems",
38			msg: "power of two size required".to_string(),
39		});
40	}
41	let log_size = log2_strict_usize(size);
42	if log_size < log_n {
43		return Err(Error::InvalidBufferSize {
44			param: "elems",
45			msg: "must have length at least 2^log_n".to_string(),
46		});
47	}
48
49	let log_w = log_size - log_n;
50
51	// See Hacker's Delight, Section 7-3.
52	// https://dl.acm.org/doi/10.5555/2462741
53	for i in 0..log_n {
54		for j in 0..1 << (log_n - i - 1) {
55			for k in 0..1 << (log_w + i) {
56				let idx0 = (j << (log_w + i + 1)) | k;
57				let idx1 = idx0 | (1 << (log_w + i));
58
59				let v0 = elems[idx0];
60				let v1 = elems[idx1];
61				let (v0, v1) = v0.interleave(v1, i);
62				elems[idx0] = v0;
63				elems[idx1] = v1;
64			}
65		}
66	}
67
68	Ok(())
69}
70
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{PackedBinaryField128x1b, PackedBinaryField64x2b};

	#[test]
	fn test_square_transpose_128x1b() {
		let mut elems = [
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
		]
		.map(PackedBinaryField128x1b::from);
		square_transpose(3, &mut elems).unwrap();

		let expected = [0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128; 8].map(PackedBinaryField128x1b::from);
		assert_eq!(elems, expected);
	}

	#[test]
	fn test_square_transpose_128x1b_multi_row() {
		let mut elems = [
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
		]
		.map(PackedBinaryField128x1b::from);
		square_transpose(1, &mut elems).unwrap();

		let expected = [0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128; 8].map(PackedBinaryField128x1b::from);
		assert_eq!(elems, expected);
	}

	#[test]
	fn test_square_transpose_64x2b() {
		let mut elems = [
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0x00000000000000000000000000000000u128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
			0xffffffffffffffffffffffffffffffffu128,
		]
		.map(PackedBinaryField64x2b::from);
		square_transpose(3, &mut elems).unwrap();

		let expected = [0xff00ff00ff00ff00ff00ff00ff00ff00u128; 8].map(PackedBinaryField64x2b::from);
		assert_eq!(elems, expected);
	}

	/// Transposing every square block twice must restore the original buffer, for every valid
	/// block size, on non-trivial (non-constant) data.
	#[test]
	fn test_square_transpose_is_involution() {
		let original: [PackedBinaryField128x1b; 16] = core::array::from_fn(|i| {
			PackedBinaryField128x1b::from(
				(i as u128 + 1).wrapping_mul(0x9e3779b97f4a7c15f39cc0605cedc835u128),
			)
		});

		// 16 = 2^4 packed elements, so every log_n in 0..=4 is a valid block size.
		for log_n in 0..=4 {
			let mut elems = original;
			square_transpose(log_n, &mut elems).unwrap();
			square_transpose(log_n, &mut elems).unwrap();
			assert_eq!(elems, original, "log_n = {log_n}");
		}
	}

	#[test]
	fn test_square_transpose_rejects_block_wider_than_packing() {
		// PackedBinaryField128x1b packs 2^7 scalars, so 2^8 x 2^8 blocks cannot fit.
		let mut elems = [PackedBinaryField128x1b::from(0u128); 256];
		assert!(matches!(
			square_transpose(8, &mut elems),
			Err(Error::SquareBlockDimensionMustDivideWidth)
		));
	}

	#[test]
	fn test_square_transpose_rejects_invalid_lengths() {
		let mut empty: [PackedBinaryField128x1b; 0] = [];
		assert!(matches!(
			square_transpose(0, &mut empty),
			Err(Error::InvalidBufferSize { param: "elems", .. })
		));

		let mut not_pow2 = [PackedBinaryField128x1b::from(0u128); 3];
		assert!(matches!(
			square_transpose(1, &mut not_pow2),
			Err(Error::InvalidBufferSize { param: "elems", .. })
		));

		let mut too_short = [PackedBinaryField128x1b::from(0u128); 4];
		assert!(matches!(
			square_transpose(3, &mut too_short),
			Err(Error::InvalidBufferSize { param: "elems", .. })
		));
	}
}