binius_field/
transpose.rs1use binius_utils::checked_arithmetics::log2_strict_usize;
4
5use super::packed::PackedField;
6use crate::{ExtensionField, Field, PackedExtension};
7
8#[derive(Clone, thiserror::Error, Debug)]
10pub enum Error {
11 #[error("the \"{param}\" argument's size is invalid: {msg}")]
12 InvalidBufferSize { param: &'static str, msg: String },
13 #[error("dimension n of square blocks must divide packing width")]
14 SquareBlockDimensionMustDivideWidth,
15 #[error("destination buffer must be castable to a packed extension field buffer")]
16 UnalignedDestination,
17}
18
19pub fn square_transpose<P: PackedField>(log_n: usize, elems: &mut [P]) -> Result<(), Error> {
31 if P::LOG_WIDTH < log_n {
32 return Err(Error::SquareBlockDimensionMustDivideWidth);
33 }
34
35 let size = elems.len();
36 if !size.is_power_of_two() {
37 return Err(Error::InvalidBufferSize {
38 param: "elems",
39 msg: "power of two size required".to_string(),
40 });
41 }
42 let log_size = log2_strict_usize(size);
43 if log_size < log_n {
44 return Err(Error::InvalidBufferSize {
45 param: "elems",
46 msg: "must have length at least 2^log_n".to_string(),
47 });
48 }
49
50 let log_w = log_size - log_n;
51
52 for i in 0..log_n {
55 for j in 0..1 << (log_n - i - 1) {
56 for k in 0..1 << (log_w + i) {
57 let idx0 = (j << (log_w + i + 1)) | k;
58 let idx1 = idx0 | (1 << (log_w + i));
59
60 let v0 = elems[idx0];
61 let v1 = elems[idx1];
62 let (v0, v1) = v0.interleave(v1, i);
63 elems[idx0] = v0;
64 elems[idx1] = v1;
65 }
66 }
67 }
68
69 Ok(())
70}
71
72pub fn square_transforms_extension_field<F: Field, FE: ExtensionField<F> + PackedExtension<F>>(
73 values: &mut [FE],
74) -> Result<(), Error> {
75 square_transpose(FE::LOG_DEGREE, FE::cast_bases_mut(values))
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PackedBinaryField128x1b;

    /// Builds a packed element whose 16 bytes all equal `byte`. With the bit-i-is-scalar-i
    /// layout the expectations below rely on, every 8-scalar chunk then carries the pattern
    /// `byte`.
    fn splat_byte(byte: u8) -> PackedBinaryField128x1b {
        PackedBinaryField128x1b::from(u128::from_ne_bytes([byte; 16]))
    }

    /// Row `r` of an 8x8 upper-triangular block: columns `r..8` set.
    fn upper_row(r: u32) -> u8 {
        0xffu8 << r
    }

    /// Row `r` of an 8x8 lower-triangular block: columns `0..=r` set.
    fn lower_row(r: u32) -> u8 {
        ((1u16 << (r + 1)) - 1) as u8
    }

    #[test]
    fn test_square_transpose_128x1b() {
        let mut elems = [
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
        ];
        square_transpose(3, &mut elems).unwrap();

        let expected = [
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
            PackedBinaryField128x1b::from(0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0u128),
        ];
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_128x1b_multi_row() {
        let mut elems = [
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0x00000000000000000000000000000000u128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
            PackedBinaryField128x1b::from(0xffffffffffffffffffffffffffffffffu128),
        ];
        square_transpose(1, &mut elems).unwrap();

        let expected = [
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
            PackedBinaryField128x1b::from(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaau128),
        ];
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_128x1b_triangular() {
        // The uniform inputs above cannot tell a correct transpose from many wrong permutations.
        // Transposing an upper-triangular block must produce a lower-triangular one.
        let mut elems: Vec<_> = (0..8u32).map(|r| splat_byte(upper_row(r))).collect();
        square_transpose(3, &mut elems[..]).unwrap();

        let expected: Vec<_> = (0..8u32).map(|r| splat_byte(lower_row(r))).collect();
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_128x1b_triangular_multi_row() {
        // Two packed elements per row: row r is elems[2r] (upper-triangular pattern) followed
        // by elems[2r + 1] (lower-triangular pattern). Each half is transposed independently.
        let mut elems: Vec<_> = (0..16u32)
            .map(|i| {
                let r = i / 2;
                splat_byte(if i % 2 == 0 { upper_row(r) } else { lower_row(r) })
            })
            .collect();
        square_transpose(3, &mut elems[..]).unwrap();

        let expected: Vec<_> = (0..16u32)
            .map(|i| {
                let r = i / 2;
                splat_byte(if i % 2 == 0 { lower_row(r) } else { upper_row(r) })
            })
            .collect();
        assert_eq!(elems, expected);
    }

    #[test]
    fn test_square_transpose_is_involution() {
        // Deterministic pseudo-random rows; transposing twice must restore the input.
        let mut state = 0x0123_4567_89ab_cdef_0f1e_2d3c_4b5a_6978u128;
        let original: Vec<_> = (0..16)
            .map(|_| {
                state = state
                    .wrapping_mul(0x2360_ed05_1fc6_5da4_4385_df64_9fcc_f645)
                    .wrapping_add(0x5851_f42d_4c95_7f2d_1405_7b7e_f767_814f);
                PackedBinaryField128x1b::from(state)
            })
            .collect();

        for log_n in 0..=4 {
            let mut elems = original.clone();
            square_transpose(log_n, &mut elems[..]).unwrap();
            square_transpose(log_n, &mut elems[..]).unwrap();
            assert_eq!(elems, original);
        }
    }

    #[test]
    fn test_square_transpose_rejects_invalid_inputs() {
        let zero = PackedBinaryField128x1b::from(0u128);

        // 128 one-bit scalars per element, so 2^8 x 2^8 blocks cannot fit.
        let mut elems = [zero; 8];
        assert!(matches!(
            square_transpose(8, &mut elems),
            Err(Error::SquareBlockDimensionMustDivideWidth)
        ));

        let mut odd_len = [zero; 6];
        assert!(matches!(
            square_transpose(1, &mut odd_len),
            Err(Error::InvalidBufferSize { param: "elems", .. })
        ));

        let mut too_short = [zero; 4];
        assert!(matches!(
            square_transpose(3, &mut too_short),
            Err(Error::InvalidBufferSize { param: "elems", .. })
        ));

        let mut empty: [PackedBinaryField128x1b; 0] = [];
        assert!(matches!(
            square_transpose(0, &mut empty),
            Err(Error::InvalidBufferSize { param: "elems", .. })
        ));
    }
}