1use std::{
5 fmt,
6 ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr},
7};
8
9use binius_utils::serialization::{DeserializeBytes, SerializationError, SerializeBytes};
10use bytes::{Buf, BufMut};
11
/// A 64-bit machine word: a thin newtype around `u64`.
///
/// The wrapped value is public, so a `Word` can be built with `Word(x)` and
/// read back with `.0`. Equality, ordering and hashing are derived and
/// therefore follow the wrapped `u64`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Word(pub u64);
16
17impl Word {
18 pub const ZERO: Word = Word(0);
20 pub const ONE: Word = Word(1);
22 pub const ALL_ONE: Word = Word(u64::MAX);
24 pub const MASK_32: Word = Word(0x00000000FFFFFFFF);
26 pub const MSB_ONE: Word = Word(0x8000000000000000);
30}
31
32impl fmt::Debug for Word {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 write!(f, "Word({:#018x})", self.0)
35 }
36}
37
38impl BitAnd for Word {
39 type Output = Self;
40
41 fn bitand(self, rhs: Self) -> Self::Output {
42 Word(self.0 & rhs.0)
43 }
44}
45
46impl BitOr for Word {
47 type Output = Self;
48
49 fn bitor(self, rhs: Self) -> Self::Output {
50 Word(self.0 | rhs.0)
51 }
52}
53
54impl BitXor for Word {
55 type Output = Self;
56
57 fn bitxor(self, rhs: Self) -> Self::Output {
58 Word(self.0 ^ rhs.0)
59 }
60}
61
62impl Shl<u32> for Word {
63 type Output = Self;
64
65 fn shl(self, rhs: u32) -> Self::Output {
66 Word(self.0 << rhs)
67 }
68}
69
70impl Shr<u32> for Word {
71 type Output = Self;
72
73 fn shr(self, rhs: u32) -> Self::Output {
74 Word(self.0 >> rhs)
75 }
76}
77
78impl Not for Word {
79 type Output = Self;
80
81 fn not(self) -> Self::Output {
82 Word(!self.0)
83 }
84}
85
86impl Word {
87 pub fn from_u64(value: u64) -> Word {
89 Word(value)
90 }
91
92 pub fn iadd_cout_32(self, rhs: Word) -> (Word, Word) {
97 let Word(lhs) = self;
98 let Word(rhs) = rhs;
99 let full_sum = lhs.wrapping_add(rhs);
100 let sum = full_sum & 0x00000000_FFFFFFFF;
101 let cout = (lhs & rhs) | ((lhs ^ rhs) & !full_sum);
102 (Word(sum), Word(cout))
103 }
104
105 pub fn iadd_cin_cout(self, rhs: Word, cin: Word) -> (Word, Word) {
113 debug_assert!(cin == Word::ZERO || cin == Word::ONE, "cin must be 0 or 1");
114 let Word(lhs) = self;
115 let Word(rhs) = rhs;
116 let Word(cin) = cin;
117 let sum = lhs.wrapping_add(rhs).wrapping_add(cin);
118 let cout = (lhs & rhs) | ((lhs ^ rhs) & !sum);
119 (Word(sum), Word(cout))
120 }
121
122 pub fn isub_bin_bout(self, rhs: Word, bin: Word) -> (Word, Word) {
130 debug_assert!(bin == Word::ZERO || bin == Word::ONE, "bin must be 0 or 1");
131 let Word(lhs) = self;
132 let Word(rhs) = rhs;
133 let Word(bin) = bin;
134 let diff = lhs.wrapping_sub(rhs).wrapping_sub(bin);
135 let bout = (!lhs & rhs) | (!(lhs ^ rhs) & diff);
136 (Word(diff), Word(bout))
137 }
138
139 pub fn shr_32(self, n: u32) -> Word {
141 let Word(value) = self;
142 let result = (value >> n) & 0x00000000_FFFFFFFF;
144 Word(result)
145 }
146
147 pub fn sar(&self, n: u32) -> Word {
151 let Word(value) = self;
152 let value = *value as i64;
153 let result = value >> n;
154 Word(result as u64)
155 }
156
157 pub fn rotr_32(self, n: u32) -> Word {
159 let Word(value) = self;
160 let n = n % 32; let value_32 = value & 0x00000000_FFFFFFFF;
163 if n == 0 {
164 return Word(value_32); }
166 let result = ((value_32 >> n) | (value_32 << (32 - n))) & 0x00000000_FFFFFFFF;
168 Word(result)
169 }
170
171 pub fn rotr(self, n: u32) -> Word {
173 let Word(value) = self;
174 let n = n % 64; if n == 0 {
176 return self; }
178 let result = (value << (64 - n)) | (value >> n);
179 Word(result)
180 }
181
182 pub fn imul(self, rhs: Word) -> (Word, Word) {
187 let Word(lhs) = self;
188 let Word(rhs) = rhs;
189 let result = (lhs as u128) * (rhs as u128);
190
191 let hi = (result >> 64) as u64;
192 let lo = (result & 0x0000000000000000_FFFFFFFFFFFFFFFF) as u64;
193 (Word(hi), Word(lo))
194 }
195
196 pub fn smul(self, rhs: Word) -> (Word, Word) {
201 let Word(lhs) = self;
202 let Word(rhs) = rhs;
203 let a = lhs as i64;
205 let b = rhs as i64;
206 let result = (a as i128) * (b as i128);
208 let hi = (result >> 64) as u64;
210 let lo = result as u64;
211 (Word(hi), Word(lo))
212 }
213
214 pub fn wrapping_add(self, rhs: Word) -> Word {
218 Word(self.0.wrapping_add(rhs.0))
219 }
220
221 pub fn wrapping_sub(self, rhs: Word) -> Word {
225 Word(self.0.wrapping_sub(rhs.0))
226 }
227
228 pub fn as_u64(self) -> u64 {
230 self.0
231 }
232
233 pub fn is_msb_true(self) -> bool {
240 (self.0 & 0x8000000000000000) != 0
241 }
242
243 pub fn is_msb_false(self) -> bool {
250 (self.0 & 0x8000000000000000) == 0
251 }
252}
253
254impl SerializeBytes for Word {
255 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
256 self.0.serialize(write_buf)
257 }
258}
259
260impl DeserializeBytes for Word {
261 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
262 where
263 Self: Sized,
264 {
265 Ok(Word(u64::deserialize(read_buf)?))
266 }
267}
268
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// The associated constants carry the expected bit patterns.
    #[test]
    fn test_constants() {
        assert_eq!(Word::ZERO, Word(0));
        assert_eq!(Word::ONE, Word(1));
        assert_eq!(Word::ALL_ONE, Word(0xFFFFFFFFFFFFFFFF));
        assert_eq!(Word::MASK_32, Word(0x00000000FFFFFFFF));
        assert_eq!(Word::MSB_ONE, Word(0x8000000000000000));
    }

    /// `is_msb_true` / `is_msb_false` look only at bit 63.
    #[test]
    fn test_msb_bool() {
        assert!(Word::MSB_ONE.is_msb_true());
        assert!(!Word::MSB_ONE.is_msb_false());

        assert!(!Word::ZERO.is_msb_true());
        assert!(Word::ZERO.is_msb_false());

        // MSB set, with assorted low bits.
        assert!(Word(0x8000000000000000).is_msb_true());
        assert!(Word(0x8000000000000001).is_msb_true());
        assert!(Word(0x80000000FFFFFFFF).is_msb_true());
        assert!(Word(0xFFFFFFFFFFFFFFFF).is_msb_true());

        // MSB clear, with assorted low bits.
        assert!(Word(0x7FFFFFFFFFFFFFFF).is_msb_false());
        assert!(Word(0x0000000000000001).is_msb_false());
        assert!(Word(0x00000000FFFFFFFF).is_msb_false());
        assert!(Word(0x7000000000000000).is_msb_false());

        let test_word = Word(0x8123456789ABCDEF);
        assert!(test_word.is_msb_true());
        assert!(!test_word.is_msb_false());

        let test_word2 = Word(0x7123456789ABCDEF);
        assert!(!test_word2.is_msb_true());
        assert!(test_word2.is_msb_false());
    }

    proptest! {
        #[test]
        fn prop_msb_bool(val in any::<u64>()) {
            let word = Word(val);

            // The two predicates are always complementary.
            assert_eq!(word.is_msb_true(), !word.is_msb_false());
            assert_eq!(word.is_msb_false(), !word.is_msb_true());

            // They agree with a direct test of bit 63.
            let msb_set = (val & 0x8000000000000000) != 0;
            assert_eq!(word.is_msb_true(), msb_set);
            assert_eq!(word.is_msb_false(), !msb_set);

            // Forcing bit 63 on or off forces the predicate.
            let word_with_msb = Word(val | 0x8000000000000000);
            let word_without_msb = Word(val & 0x7FFFFFFFFFFFFFFF);
            assert!(word_with_msb.is_msb_true());
            assert!(word_without_msb.is_msb_false());
        }

        #[test]
        fn prop_bitwise_and(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);

            assert_eq!((wa & wb).0, a & b);

            // Identity, annihilator, idempotence, commutativity.
            assert_eq!(wa & Word::ALL_ONE, wa);
            assert_eq!(wa & Word::ZERO, Word::ZERO);
            assert_eq!(wa & wa, wa);
            assert_eq!(wa & wb, wb & wa);
        }

        #[test]
        fn prop_bitwise_or(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);

            assert_eq!((wa | wb).0, a | b);

            // Identity, annihilator, idempotence, commutativity.
            assert_eq!(wa | Word::ZERO, wa);
            assert_eq!(wa | Word::ALL_ONE, Word::ALL_ONE);
            assert_eq!(wa | wa, wa);
            assert_eq!(wa | wb, wb | wa);
        }

        #[test]
        fn prop_bitwise_xor(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);

            assert_eq!((wa ^ wb).0, a ^ b);

            // Identity, self-inverse, complement.
            assert_eq!(wa ^ Word::ZERO, wa);
            assert_eq!(wa ^ wa, Word::ZERO);
            assert_eq!(wa ^ Word::ALL_ONE, !wa);

            // Commutativity.
            assert_eq!(wa ^ wb, wb ^ wa);

            // XOR-ing with the same value twice cancels out.
            assert_eq!(wa ^ wb ^ wb, wa);
        }

        #[test]
        fn prop_bitwise_not(a in any::<u64>()) {
            let wa = Word(a);

            assert_eq!((!wa).0, !a);

            // Involution and the two extreme values.
            assert_eq!(!(!wa), wa);
            assert_eq!(!Word::ZERO, Word::ALL_ONE);
            assert_eq!(!Word::ALL_ONE, Word::ZERO);

            // De Morgan's laws against a second, related word.
            let wb = Word(a.wrapping_add(1));
            assert_eq!(!(wa & wb), !wa | !wb);
            assert_eq!(!(wa | wb), !wa & !wb);
        }

        #[test]
        fn prop_shift_left(val in any::<u64>(), shift in 0u32..64) {
            // Shift amounts of 64 or more overflow the primitive shift, so the
            // strategy deliberately stays below 64.
            let w = Word(val);
            assert_eq!((w << shift).0, val << shift);

            // Shifting by zero is the identity.
            assert_eq!(w << 0, w);
        }

        #[test]
        fn prop_shift_right(val in any::<u64>(), shift in 0u32..64) {
            // Shift amounts of 64 or more overflow the primitive shift, so the
            // strategy deliberately stays below 64.
            let w = Word(val);
            assert_eq!((w >> shift).0, val >> shift);

            // Shifting by zero is the identity.
            assert_eq!(w >> 0, w);
        }

        #[test]
        fn prop_shift_inverse(val in any::<u64>(), shift in 1u32..64) {
            let w = Word(val);

            // Left then right clears the top `shift` bits.
            let mask = (1u64 << (64 - shift)) - 1;
            assert_eq!(((w << shift) >> shift).0, val & mask);

            // Right then left clears the bottom `shift` bits.
            let high_mask = !((1u64 << shift) - 1);
            assert_eq!(((w >> shift) << shift).0, val & high_mask);
        }

        #[test]
        fn prop_sar(val in any::<u64>(), shift in 0u32..64) {
            let w = Word(val);
            let expected = ((val as i64) >> shift) as u64;
            assert_eq!(w.sar(shift).0, expected);

            // Shifting by zero is the identity.
            assert_eq!(w.sar(0), w);

            // Shifting by 63 smears the sign bit across the whole word.
            let sign_extended = if (val as i64) < 0 {
                Word::ALL_ONE
            } else {
                Word::ZERO
            };
            assert_eq!(w.sar(63), sign_extended);
        }

        #[test]
        fn prop_sar_sign_extension(val in any::<u64>(), shift in 1u32..64) {
            let result = Word(val).sar(shift);

            // The top `shift` bits of the result are copies of the sign bit.
            let mask = !((1u64 << (64 - shift)) - 1);
            let expected_top = if (val as i64) < 0 { mask } else { 0 };
            assert_eq!(result.0 & mask, expected_top);
        }

        #[test]
        fn prop_iadd_cout_32(a in any::<u32>(), b in any::<u32>()) {
            let wa = Word(a as u64);
            let wb = Word(b as u64);
            let (sum, cout) = wa.iadd_cout_32(wb);

            assert_eq!(sum.0, (a as u64 + b as u64) & 0xFFFFFFFF);

            // Per-bit carry-out; both operands fit in 32 bits, so using the
            // masked sum here matches the unmasked computation.
            let expected_cout = (a as u64 & b as u64) | ((a as u64 ^ b as u64) & !sum.0);
            assert_eq!(cout.0, expected_cout);

            // Adding zero never carries.
            let (sum0, cout0) = wa.iadd_cout_32(Word::ZERO);
            assert_eq!(sum0.0, a as u64);
            assert_eq!(cout0, Word::ZERO);
        }

        #[test]
        fn prop_iadd_cin_cout(a in any::<u64>(), b in any::<u64>(), cin in 0u64..=1) {
            let wa = Word(a);
            let wb = Word(b);
            let wcin = Word(cin);
            let (sum, cout) = wa.iadd_cin_cout(wb, wcin);

            let expected_sum = a.wrapping_add(b).wrapping_add(cin);
            assert_eq!(sum.0, expected_sum);

            let expected_cout = (a & b) | ((a ^ b) & !expected_sum);
            assert_eq!(cout.0, expected_cout);

            // With no carry-in the result is plain wrapping addition.
            let (sum0, cout0) = wa.iadd_cin_cout(wb, Word::ZERO);
            let full_sum = a.wrapping_add(b);
            assert_eq!(sum0.0, full_sum);
            assert_eq!(cout0.0, (a & b) | ((a ^ b) & !full_sum));
        }

        #[test]
        fn prop_isub_bin_bout(a in any::<u64>(), b in any::<u64>(), bin in 0u64..=1) {
            let wa = Word(a);
            let wb = Word(b);
            let wbin = Word(bin);
            let (diff, bout) = wa.isub_bin_bout(wb, wbin);

            let expected_diff = a.wrapping_sub(b).wrapping_sub(bin);
            assert_eq!(diff.0, expected_diff);

            let expected_bout = (!a & b) | (!(a ^ b) & expected_diff);
            assert_eq!(bout.0, expected_bout);

            // With no borrow-in the result is plain wrapping subtraction.
            let (diff0, bout0) = wa.isub_bin_bout(wb, Word::ZERO);
            let expected = a.wrapping_sub(b);
            assert_eq!(diff0.0, expected);
            assert_eq!(bout0.0, (!a & b) | (!(a ^ b) & expected));
        }

        #[test]
        fn prop_shr_32(val in any::<u64>(), shift in 0u32..64) {
            let w = Word(val);
            let result = w.shr_32(shift);

            assert_eq!(result.0, (val >> shift) & 0xFFFFFFFF);

            // Shifting by zero just keeps the low 32 bits.
            assert_eq!(w.shr_32(0).0, val & 0xFFFFFFFF);

            // The result never has bits above position 31.
            assert_eq!(result.0 >> 32, 0);
        }

        #[test]
        fn prop_rotr_32(val in any::<u32>(), rotate in 0u32..64) {
            let w = Word(val as u64);
            let result = w.rotr_32(rotate);

            // The rotation amount is taken modulo 32.
            let rotate_mod = rotate % 32;
            let val32 = val as u64;
            let expected = if rotate_mod == 0 {
                val32
            } else {
                ((val32 >> rotate_mod) | (val32 << (32 - rotate_mod))) & 0xFFFFFFFF
            };
            assert_eq!(result.0, expected);

            // Rotating by 0 or by a full 32 bits is the identity.
            assert_eq!(w.rotr_32(0).0, val32);
            assert_eq!(w.rotr_32(32).0, val32);
        }

        #[test]
        fn prop_rotr(val in any::<u64>(), rotate in 0u32..128) {
            let w = Word(val);
            let result = w.rotr(rotate);

            // The rotation amount is taken modulo 64.
            let rotate_mod = rotate % 64;
            let expected = if rotate_mod == 0 {
                val
            } else {
                (val >> rotate_mod) | (val << (64 - rotate_mod))
            };
            assert_eq!(result.0, expected);

            // Rotating by 0 or by a full 64 bits is the identity.
            assert_eq!(w.rotr(0), w);
            assert_eq!(w.rotr(64), w);

            // Rotating by r and then by 64 - r gets back to the start.
            assert_eq!(w.rotr(rotate_mod).rotr(64 - rotate_mod), w);
        }

        #[test]
        fn prop_imul(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);
            let (hi, lo) = wa.imul(wb);

            let result = (a as u128) * (b as u128);
            assert_eq!(hi.0, (result >> 64) as u64);
            assert_eq!(lo.0, result as u64);

            // Multiplying by zero.
            let (hi0, lo0) = wa.imul(Word::ZERO);
            assert_eq!(hi0, Word::ZERO);
            assert_eq!(lo0, Word::ZERO);

            // Multiplying by one.
            let (hi1, lo1) = wa.imul(Word::ONE);
            assert_eq!(hi1, Word::ZERO);
            assert_eq!(lo1, wa);

            // Commutativity.
            let (hi_reversed, lo_reversed) = wb.imul(wa);
            assert_eq!(hi, hi_reversed);
            assert_eq!(lo, lo_reversed);
        }

        #[test]
        fn prop_smul(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);
            let (hi, lo) = wa.smul(wb);

            let result = (a as i64 as i128) * (b as i64 as i128);
            assert_eq!(hi.0, (result >> 64) as u64);
            assert_eq!(lo.0, result as u64);

            // Multiplying by zero.
            let (hi0, lo0) = wa.smul(Word::ZERO);
            assert_eq!(hi0, Word::ZERO);
            assert_eq!(lo0, Word::ZERO);

            // Multiplying by one sign-extends `a` into the high word.
            let (hi1, lo1) = wa.smul(Word::ONE);
            let expected_hi = if (a as i64) < 0 { Word::ALL_ONE } else { Word::ZERO };
            assert_eq!(hi1, expected_hi);
            assert_eq!(lo1, wa);

            // Multiplying by -1 negates.
            let (hi_neg, lo_neg) = wa.smul(Word::ALL_ONE);
            let neg_result = -(a as i64 as i128);
            assert_eq!(hi_neg.0, (neg_result >> 64) as u64);
            assert_eq!(lo_neg.0, neg_result as u64);

            // Commutativity.
            let (hi_reversed, lo_reversed) = wb.smul(wa);
            assert_eq!(hi, hi_reversed);
            assert_eq!(lo, lo_reversed);
        }

        #[test]
        fn prop_wrapping_add(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);

            assert_eq!(wa.wrapping_add(wb).0, a.wrapping_add(b));

            // Identity and commutativity.
            assert_eq!(wa.wrapping_add(Word::ZERO), wa);
            assert_eq!(wa.wrapping_add(wb), wb.wrapping_add(wa));

            // Subtraction undoes addition.
            assert_eq!(wa.wrapping_add(wb).wrapping_sub(wb), wa);
        }

        #[test]
        fn prop_wrapping_sub(a in any::<u64>(), b in any::<u64>()) {
            let wa = Word(a);
            let wb = Word(b);
            let result = wa.wrapping_sub(wb);

            assert_eq!(result.0, a.wrapping_sub(b));

            // Subtracting zero is the identity.
            assert_eq!(wa.wrapping_sub(Word::ZERO), wa);

            // Subtracting a value from itself gives zero.
            assert_eq!(wa.wrapping_sub(wa), Word::ZERO);

            // Subtraction undoes addition.
            let sum = Word(a.wrapping_add(b));
            assert_eq!(sum.wrapping_sub(wb), wa);
        }

        #[test]
        fn prop_conversions(val in any::<u64>()) {
            let word = Word::from_u64(val);
            assert_eq!(word.as_u64(), val);
            assert_eq!(word, Word(val));

            // Round trip.
            assert_eq!(Word::from_u64(word.as_u64()), word);
        }

        #[test]
        fn prop_debug_format(val in any::<u64>()) {
            let word = Word(val);
            let debug_str = format!("{:?}", word);
            assert!(debug_str.starts_with("Word(0x"));
            assert!(debug_str.ends_with(')'));

            let expected = format!("Word({:#018x})", val);
            assert_eq!(debug_str, expected);
        }
    }
}