binius_field/
transpose.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use binius_utils::checked_arithmetics::log2_strict_usize;
4
5use super::packed::PackedField;
6use crate::{ExtensionField, Field, PackedExtension};
7
8/// Error thrown when a transpose operation fails.
9#[derive(Clone, thiserror::Error, Debug)]
10pub enum Error {
11	#[error("the \"{param}\" argument's size is invalid: {msg}")]
12	InvalidBufferSize { param: &'static str, msg: String },
13	#[error("dimension n of square blocks must divide packing width")]
14	SquareBlockDimensionMustDivideWidth,
15	#[error("destination buffer must be castable to a packed extension field buffer")]
16	UnalignedDestination,
17}
18
19/// Transpose square blocks of elements within packed field elements in place.
20///
21/// The input elements are interpreted as a rectangular matrix with height `n = 2^n` in row-major
22/// order. This matrix is interpreted as a vector of square matrices of field elements, and each
23/// square matrix is transposed in-place.
24///
25/// # Arguments
26///
27/// * `log_n`: The base-2 logarithm of the dimension of the n x n square matrix. Must be less than
28///   or equal to the base-2 logarithm of the packing width.
29/// * `elems`: The packed field elements, length is a power-of-two multiple of `1 << log_n`.
30pub fn square_transpose<P: PackedField>(log_n: usize, elems: &mut [P]) -> Result<(), Error> {
31	if P::LOG_WIDTH < log_n {
32		return Err(Error::SquareBlockDimensionMustDivideWidth);
33	}
34
35	let size = elems.len();
36	if !size.is_power_of_two() {
37		return Err(Error::InvalidBufferSize {
38			param: "elems",
39			msg: "power of two size required".to_string(),
40		});
41	}
42	let log_size = log2_strict_usize(size);
43	if log_size < log_n {
44		return Err(Error::InvalidBufferSize {
45			param: "elems",
46			msg: "must have length at least 2^log_n".to_string(),
47		});
48	}
49
50	let log_w = log_size - log_n;
51
52	// See Hacker's Delight, Section 7-3.
53	// https://dl.acm.org/doi/10.5555/2462741
54	for i in 0..log_n {
55		for j in 0..1 << (log_n - i - 1) {
56			for k in 0..1 << (log_w + i) {
57				let idx0 = (j << (log_w + i + 1)) | k;
58				let idx1 = idx0 | (1 << (log_w + i));
59
60				let v0 = elems[idx0];
61				let v1 = elems[idx1];
62				let (v0, v1) = v0.interleave(v1, i);
63				elems[idx0] = v0;
64				elems[idx1] = v1;
65			}
66		}
67	}
68
69	Ok(())
70}
71
72pub fn square_transforms_extension_field<F: Field, FE: ExtensionField<F> + PackedExtension<F>>(
73	values: &mut [FE],
74) -> Result<(), Error> {
75	square_transpose(FE::LOG_DEGREE, FE::cast_bases_mut(values))
76}
77
#[cfg(test)]
mod tests {
	use super::*;
	use crate::PackedBinaryField128x1b;

	#[test]
	fn test_square_transpose_128x1b() {
		let mut elems = [
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
		];
		square_transpose(3, &mut elems).unwrap();

		let expected = [
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
			PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
		];
		assert_eq!(elems, expected);
	}

	#[test]
	fn test_square_transpose_128x1b_multi_row() {
		let mut elems = [
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
			PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
		];
		square_transpose(1, &mut elems).unwrap();

		let expected = [
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
			PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
		];
		assert_eq!(elems, expected);
	}

	/// Builds eight distinct, non-symmetric packed values so that a wrong permutation is visible.
	fn distinct_elems() -> [PackedBinaryField128x1b; 8] {
		std::array::from_fn(|i| {
			PackedBinaryField128x1b::from(
				0x0123_4567_89ab_cdef_fedc_ba98_7654_3210u128.rotate_left(17 * i as u32),
			)
		})
	}

	#[test]
	fn test_square_transpose_log_n_zero_is_identity() {
		// 1x1 blocks: transposing them must leave the buffer untouched.
		let original = distinct_elems();
		let mut elems = original;
		square_transpose(0, &mut elems).unwrap();
		assert_eq!(elems, original);
	}

	#[test]
	fn test_square_transpose_is_an_involution() {
		// Transposing every square block twice must restore the input, for every valid log_n.
		let original = distinct_elems();
		for log_n in 0..=3 {
			let mut elems = original;
			square_transpose(log_n, &mut elems).unwrap();
			square_transpose(log_n, &mut elems).unwrap();
			assert_eq!(elems, original, "double transpose with log_n = {log_n}");
		}
	}

	#[test]
	fn test_square_transpose_rejects_block_wider_than_packing() {
		// PackedBinaryField128x1b has LOG_WIDTH = 7, so 2^8 x 2^8 blocks cannot fit.
		let mut elems = [PackedBinaryField128x1b::from(0u128); 256];
		let result = square_transpose(8, &mut elems);
		assert!(matches!(result, Err(Error::SquareBlockDimensionMustDivideWidth)));
	}

	#[test]
	fn test_square_transpose_rejects_non_power_of_two_length() {
		let mut elems = [PackedBinaryField128x1b::from(0u128); 3];
		let result = square_transpose(1, &mut elems);
		assert!(matches!(result, Err(Error::InvalidBufferSize { param: "elems", .. })));
	}

	#[test]
	fn test_square_transpose_rejects_empty_buffer() {
		let mut elems: [PackedBinaryField128x1b; 0] = [];
		let result = square_transpose(0, &mut elems);
		assert!(matches!(result, Err(Error::InvalidBufferSize { param: "elems", .. })));
	}

	#[test]
	fn test_square_transpose_rejects_buffer_shorter_than_block() {
		let mut elems = [PackedBinaryField128x1b::from(0u128); 2];
		let result = square_transpose(3, &mut elems);
		assert!(matches!(result, Err(Error::InvalidBufferSize { param: "elems", .. })));
	}
}