1use std::{
4 array,
5 ops::{Add, AddAssign, Index, IndexMut},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9
10pub trait TowerLevel: 'static {
22 const WIDTH: usize;
24 const LOG_WIDTH: usize = checked_log_2(Self::WIDTH);
25
26 type Data<T>: AsMut<[T]>
28 + AsRef<[T]>
29 + Sized
30 + Index<usize, Output = T>
31 + IndexMut<usize, Output = T>;
32 type Base: TowerLevel;
33
34 #[allow(clippy::type_complexity)]
36 fn split<T>(
37 data: &Self::Data<T>,
38 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>);
39
40 #[allow(clippy::type_complexity)]
42 fn split_mut<T>(
43 data: &mut Self::Data<T>,
44 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>);
45
46 fn join<T: Copy + Default>(
48 first: &<Self::Base as TowerLevel>::Data<T>,
49 second: &<Self::Base as TowerLevel>::Data<T>,
50 ) -> Self::Data<T>;
51
52 fn from_fn<T: Copy>(f: impl FnMut(usize) -> T) -> Self::Data<T>;
54
55 fn default<T: Copy + Default>() -> Self::Data<T> {
57 Self::from_fn(|_| T::default())
58 }
59}
60
61pub trait TowerLevelWithArithOps: TowerLevel {
62 #[inline(always)]
63 fn add_into<T: AddAssign + Copy>(
64 field_element: &Self::Data<T>,
65 destination: &mut Self::Data<T>,
66 ) {
67 for i in 0..Self::WIDTH {
68 destination[i] += field_element[i];
69 }
70 }
71
72 #[inline(always)]
73 fn copy_into<T: Copy>(field_element: &Self::Data<T>, destination: &mut Self::Data<T>) {
74 for i in 0..Self::WIDTH {
75 destination[i] = field_element[i];
76 }
77 }
78
79 #[inline(always)]
80 fn sum<T: Copy + Add<Output = T>>(
81 field_element_a: &Self::Data<T>,
82 field_element_b: &Self::Data<T>,
83 ) -> Self::Data<T> {
84 Self::from_fn(|i| field_element_a[i] + field_element_b[i])
85 }
86}
87
88impl<T: TowerLevel> TowerLevelWithArithOps for T {}
89
90pub struct TowerLevel64;
91
92impl TowerLevel for TowerLevel64 {
93 const WIDTH: usize = 64;
94
95 type Data<T> = [T; 64];
96 type Base = TowerLevel32;
97
98 #[inline(always)]
99 fn split<T>(
100 data: &Self::Data<T>,
101 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
102 ((data[0..32].try_into().unwrap()), (data[32..64].try_into().unwrap()))
103 }
104
105 #[inline(always)]
106 fn split_mut<T>(
107 data: &mut Self::Data<T>,
108 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
109 let (chunk_1, chunk_2) = data.split_at_mut(32);
110
111 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
112 }
113
114 #[inline(always)]
115 fn join<T: Copy + Default>(
116 left: &<Self::Base as TowerLevel>::Data<T>,
117 right: &<Self::Base as TowerLevel>::Data<T>,
118 ) -> Self::Data<T> {
119 let mut result = [T::default(); 64];
120 result[..32].copy_from_slice(left);
121 result[32..].copy_from_slice(right);
122 result
123 }
124
125 #[inline(always)]
126 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
127 array::from_fn(f)
128 }
129}
130
131pub struct TowerLevel32;
132
133impl TowerLevel for TowerLevel32 {
134 const WIDTH: usize = 32;
135
136 type Data<T> = [T; 32];
137 type Base = TowerLevel16;
138
139 #[inline(always)]
140 fn split<T>(
141 data: &Self::Data<T>,
142 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
143 ((data[0..16].try_into().unwrap()), (data[16..32].try_into().unwrap()))
144 }
145
146 #[inline(always)]
147 fn split_mut<T>(
148 data: &mut Self::Data<T>,
149 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
150 let (chunk_1, chunk_2) = data.split_at_mut(16);
151
152 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
153 }
154
155 #[inline(always)]
156 fn join<T: Copy + Default>(
157 left: &<Self::Base as TowerLevel>::Data<T>,
158 right: &<Self::Base as TowerLevel>::Data<T>,
159 ) -> Self::Data<T> {
160 let mut result = [T::default(); 32];
161 result[..16].copy_from_slice(left);
162 result[16..].copy_from_slice(right);
163 result
164 }
165
166 #[inline(always)]
167 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
168 array::from_fn(f)
169 }
170}
171
172pub struct TowerLevel16;
173
174impl TowerLevel for TowerLevel16 {
175 const WIDTH: usize = 16;
176
177 type Data<T> = [T; 16];
178 type Base = TowerLevel8;
179
180 #[inline(always)]
181 fn split<T>(
182 data: &Self::Data<T>,
183 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
184 ((data[0..8].try_into().unwrap()), (data[8..16].try_into().unwrap()))
185 }
186
187 #[inline(always)]
188 fn split_mut<T>(
189 data: &mut Self::Data<T>,
190 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
191 let (chunk_1, chunk_2) = data.split_at_mut(8);
192
193 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
194 }
195
196 #[inline(always)]
197 fn join<T: Copy + Default>(
198 left: &<Self::Base as TowerLevel>::Data<T>,
199 right: &<Self::Base as TowerLevel>::Data<T>,
200 ) -> Self::Data<T> {
201 let mut result = [T::default(); 16];
202 result[..8].copy_from_slice(left);
203 result[8..].copy_from_slice(right);
204 result
205 }
206
207 #[inline(always)]
208 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
209 array::from_fn(f)
210 }
211}
212
213pub struct TowerLevel8;
214
215impl TowerLevel for TowerLevel8 {
216 const WIDTH: usize = 8;
217
218 type Data<T> = [T; 8];
219 type Base = TowerLevel4;
220
221 #[inline(always)]
222 fn split<T>(
223 data: &Self::Data<T>,
224 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
225 ((data[0..4].try_into().unwrap()), (data[4..8].try_into().unwrap()))
226 }
227
228 #[inline(always)]
229 fn split_mut<T>(
230 data: &mut Self::Data<T>,
231 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
232 let (chunk_1, chunk_2) = data.split_at_mut(4);
233
234 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
235 }
236
237 #[inline(always)]
238 fn join<T: Copy + Default>(
239 left: &<Self::Base as TowerLevel>::Data<T>,
240 right: &<Self::Base as TowerLevel>::Data<T>,
241 ) -> Self::Data<T> {
242 let mut result = [T::default(); 8];
243 result[..4].copy_from_slice(left);
244 result[4..].copy_from_slice(right);
245 result
246 }
247
248 #[inline(always)]
249 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
250 array::from_fn(f)
251 }
252}
253
254pub struct TowerLevel4;
255
256impl TowerLevel for TowerLevel4 {
257 const WIDTH: usize = 4;
258
259 type Data<T> = [T; 4];
260 type Base = TowerLevel2;
261
262 #[inline(always)]
263 fn split<T>(
264 data: &Self::Data<T>,
265 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
266 ((data[0..2].try_into().unwrap()), (data[2..4].try_into().unwrap()))
267 }
268
269 #[inline(always)]
270 fn split_mut<T>(
271 data: &mut Self::Data<T>,
272 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
273 let (chunk_1, chunk_2) = data.split_at_mut(2);
274
275 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
276 }
277
278 #[inline(always)]
279 fn join<T: Copy + Default>(
280 left: &<Self::Base as TowerLevel>::Data<T>,
281 right: &<Self::Base as TowerLevel>::Data<T>,
282 ) -> Self::Data<T> {
283 let mut result = [T::default(); 4];
284 result[..2].copy_from_slice(left);
285 result[2..].copy_from_slice(right);
286 result
287 }
288
289 #[inline(always)]
290 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
291 array::from_fn(f)
292 }
293}
294
295pub struct TowerLevel2;
296
297impl TowerLevel for TowerLevel2 {
298 const WIDTH: usize = 2;
299
300 type Data<T> = [T; 2];
301 type Base = TowerLevel1;
302
303 #[inline(always)]
304 fn split<T>(
305 data: &Self::Data<T>,
306 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
307 ((data[0..1].try_into().unwrap()), (data[1..2].try_into().unwrap()))
308 }
309
310 #[inline(always)]
311 fn split_mut<T>(
312 data: &mut Self::Data<T>,
313 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
314 let (chunk_1, chunk_2) = data.split_at_mut(1);
315
316 ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
317 }
318
319 #[inline(always)]
320 fn join<T: Copy + Default>(
321 left: &<Self::Base as TowerLevel>::Data<T>,
322 right: &<Self::Base as TowerLevel>::Data<T>,
323 ) -> Self::Data<T> {
324 let mut result = [T::default(); 2];
325 result[..1].copy_from_slice(left);
326 result[1..].copy_from_slice(right);
327 result
328 }
329
330 #[inline(always)]
331 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
332 array::from_fn(f)
333 }
334}
335
336pub struct TowerLevel1;
337
338impl TowerLevel for TowerLevel1 {
339 const WIDTH: usize = 1;
340
341 type Data<T> = [T; 1];
342 type Base = Self;
343
344 #[inline(always)]
347 fn split<T>(
348 _data: &Self::Data<T>,
349 ) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
350 unreachable!()
351 }
352
353 #[inline(always)]
354 fn split_mut<T>(
355 _data: &mut Self::Data<T>,
356 ) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
357 unreachable!()
358 }
359
360 #[inline(always)]
361 fn join<T: Copy + Default>(
362 _left: &<Self::Base as TowerLevel>::Data<T>,
363 _right: &<Self::Base as TowerLevel>::Data<T>,
364 ) -> Self::Data<T> {
365 unreachable!()
366 }
367
368 #[inline(always)]
369 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
370 array::from_fn(f)
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_join() {
        // Splitting yields the two halves in order, and joining them
        // reconstructs the original container.
        fn check_split_and_join<TL: TowerLevel>() {
            let data = TL::from_fn(|i| i as u32);
            let (left, right) = TL::split(&data);
            assert_eq!(left.as_ref(), &data.as_ref()[..TL::WIDTH / 2]);
            assert_eq!(right.as_ref(), &data.as_ref()[TL::WIDTH / 2..]);

            let joined = TL::join(left, right);
            assert_eq!(joined.as_ref(), data.as_ref());
        }

        // TowerLevel1 is excluded: it is the bottom of the tower and cannot be split.
        check_split_and_join::<TowerLevel64>();
        check_split_and_join::<TowerLevel32>();
        check_split_and_join::<TowerLevel16>();
        check_split_and_join::<TowerLevel8>();
        check_split_and_join::<TowerLevel4>();
        check_split_and_join::<TowerLevel2>();
    }

    #[test]
    fn test_split_mut_join() {
        // Mutable halves expose the right elements, join round-trips, and
        // writes through the halves land in the original container.
        fn check_mut_and_join<TL: TowerLevel>() {
            let mut data = TL::from_fn(|i| i as u32);
            let half = TL::WIDTH / 2;
            let expected_left = data.as_ref()[..half].to_vec();
            let expected_right = data.as_ref()[half..].to_vec();
            let (left, right) = TL::split_mut(&mut data);

            assert_eq!(left.as_mut(), expected_left);
            assert_eq!(right.as_mut(), expected_right);

            let joined = TL::join(left, right);
            assert_eq!(joined.as_ref(), data.as_ref());

            let (left, right) = TL::split_mut(&mut data);
            left[0] = 100;
            right[0] = 200;
            assert_eq!(data.as_ref()[0], 100);
            assert_eq!(data.as_ref()[half], 200);
        }

        check_mut_and_join::<TowerLevel64>();
        check_mut_and_join::<TowerLevel32>();
        check_mut_and_join::<TowerLevel16>();
        check_mut_and_join::<TowerLevel8>();
        check_mut_and_join::<TowerLevel4>();
        check_mut_and_join::<TowerLevel2>();
    }

    #[test]
    #[should_panic]
    fn test_tower_level1_split_panics() {
        let data = TowerLevel1::from_fn(|i| i as u32);
        let _ = TowerLevel1::split(&data);
    }

    #[test]
    fn test_widths() {
        // LOG_WIDTH must be consistent with WIDTH, and Data must hold WIDTH elements.
        fn check_widths<TL: TowerLevel>() {
            assert_eq!(1usize << TL::LOG_WIDTH, TL::WIDTH);
            assert_eq!(TL::default::<u8>().as_ref().len(), TL::WIDTH);
        }

        // Every splittable level's Base must be exactly half as wide.
        fn check_base_is_half<TL: TowerLevel>() {
            assert_eq!(2 * <TL::Base as TowerLevel>::WIDTH, TL::WIDTH);
        }

        check_widths::<TowerLevel64>();
        check_widths::<TowerLevel32>();
        check_widths::<TowerLevel16>();
        check_widths::<TowerLevel8>();
        check_widths::<TowerLevel4>();
        check_widths::<TowerLevel2>();
        check_widths::<TowerLevel1>();

        check_base_is_half::<TowerLevel64>();
        check_base_is_half::<TowerLevel32>();
        check_base_is_half::<TowerLevel16>();
        check_base_is_half::<TowerLevel8>();
        check_base_is_half::<TowerLevel4>();
        check_base_is_half::<TowerLevel2>();
    }

    #[test]
    fn test_from_fn() {
        fn check_from_fn<TL: TowerLevel>() {
            let data = TL::from_fn(|i| i as u32);
            let expected = (0..TL::WIDTH).map(|i| i as u32).collect::<Vec<_>>();
            assert_eq!(data.as_ref(), expected.as_slice());
        }

        check_from_fn::<TowerLevel64>();
        check_from_fn::<TowerLevel32>();
        check_from_fn::<TowerLevel16>();
        check_from_fn::<TowerLevel8>();
        check_from_fn::<TowerLevel4>();
        check_from_fn::<TowerLevel2>();
        check_from_fn::<TowerLevel1>();
    }

    #[test]
    fn test_default() {
        fn check_default<TL: TowerLevel>() {
            let data = TL::default::<u32>();
            let expected = vec![0u32; TL::WIDTH];
            assert_eq!(data.as_ref(), expected);
        }

        check_default::<TowerLevel64>();
        check_default::<TowerLevel32>();
        check_default::<TowerLevel16>();
        check_default::<TowerLevel8>();
        check_default::<TowerLevel4>();
        check_default::<TowerLevel2>();
        check_default::<TowerLevel1>();
    }

    #[test]
    fn test_add_into() {
        // Use a non-zero destination so accumulation (not just copying) is verified.
        fn check_add_into<TL: TowerLevel>() {
            let a = TL::from_fn(|i| i as u32);
            let mut b = TL::from_fn(|i| (i as u32) * 10 + 1);
            TL::add_into(&a, &mut b);
            let expected = TL::from_fn(|i| (i as u32) * 11 + 1);
            assert_eq!(b.as_ref(), expected.as_ref());
        }

        check_add_into::<TowerLevel64>();
        check_add_into::<TowerLevel32>();
        check_add_into::<TowerLevel16>();
        check_add_into::<TowerLevel8>();
        check_add_into::<TowerLevel4>();
        check_add_into::<TowerLevel2>();
        check_add_into::<TowerLevel1>();
    }

    #[test]
    fn test_copy_into() {
        // Start from a non-default destination so every slot must be overwritten.
        fn check_copy_into<TL: TowerLevel>() {
            let a = TL::from_fn(|i| i as u32);
            let mut b = TL::from_fn(|_| u32::MAX);
            TL::copy_into(&a, &mut b);
            assert_eq!(b.as_ref(), a.as_ref());
        }

        check_copy_into::<TowerLevel64>();
        check_copy_into::<TowerLevel32>();
        check_copy_into::<TowerLevel16>();
        check_copy_into::<TowerLevel8>();
        check_copy_into::<TowerLevel4>();
        check_copy_into::<TowerLevel2>();
        check_copy_into::<TowerLevel1>();
    }

    #[test]
    fn test_sum() {
        // Distinct operands catch an implementation that reads one input twice.
        fn check_sum<TL: TowerLevel>() {
            let a = TL::from_fn(|i| i as u32);
            let b = TL::from_fn(|i| (i as u32) * 3 + 1);
            let sum_result = TL::sum(&a, &b);
            let expected = TL::from_fn(|i| (i as u32) * 4 + 1);
            assert_eq!(sum_result.as_ref(), expected.as_ref());
        }

        check_sum::<TowerLevel64>();
        check_sum::<TowerLevel32>();
        check_sum::<TowerLevel16>();
        check_sum::<TowerLevel8>();
        check_sum::<TowerLevel4>();
        check_sum::<TowerLevel2>();
        check_sum::<TowerLevel1>();
    }
}