1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// Copyright 2024 Ulvetanna Inc.

use std::ops::Deref;

use crate::{
	arch::PairwiseStrategy,
	arithmetic_traits::{
		InvertOrZero, MulAlpha, Square, TaggedInvertOrZero, TaggedMul, TaggedMulAlpha,
		TaggedPackedTransformationFactory, TaggedSquare,
	},
	linear_transformation::{FieldLinearTransformation, Transformation},
	packed::{PackedBinaryField, PackedField},
};

impl<PT: PackedField> TaggedMul<PairwiseStrategy> for PT {
	/// Lane-wise multiplication: lane `k` of the result is the scalar product of
	/// lane `k` of `self` and lane `k` of `b`.
	///
	/// A single-lane packed type defers to the packed `*` operator, so this
	/// strategy can still be benchmarked for it.
	#[inline]
	fn mul(self, b: Self) -> Self {
		match PT::WIDTH {
			1 => self * b,
			_ => Self::from_fn(|lane| {
				let lhs = self.get(lane);
				let rhs = b.get(lane);
				lhs * rhs
			}),
		}
	}
}

impl<PT> TaggedSquare<PairwiseStrategy> for PT
where
	PT: PackedField,
	PT::Scalar: Square,
{
	/// Squares every lane independently using the scalar `Square` implementation.
	///
	/// A single-lane packed type uses `PackedField::square` directly, so the
	/// strategy remains benchmarkable for it.
	#[inline]
	fn square(self) -> Self {
		if PT::WIDTH != 1 {
			return Self::from_fn(|lane| {
				let scalar = self.get(lane);
				Square::square(scalar)
			});
		}
		PackedField::square(self)
	}
}

impl<PT> TaggedInvertOrZero<PairwiseStrategy> for PT
where
	PT: PackedField,
	PT::Scalar: InvertOrZero,
{
	/// Applies the scalar `InvertOrZero::invert_or_zero` to every lane independently.
	///
	/// A single-lane packed type is forwarded to `PackedField::invert_or_zero`, so
	/// the strategy remains benchmarkable for it.
	#[inline]
	fn invert_or_zero(self) -> Self {
		let is_single_lane = PT::WIDTH == 1;
		if is_single_lane {
			PackedField::invert_or_zero(self)
		} else {
			Self::from_fn(|lane| {
				let scalar = self.get(lane);
				InvertOrZero::invert_or_zero(scalar)
			})
		}
	}
}

impl<PT> TaggedMulAlpha<PairwiseStrategy> for PT
where
	PT: PackedField + MulAlpha,
	PT::Scalar: MulAlpha,
{
	/// Multiplies every lane by alpha, using the scalar `MulAlpha` implementation
	/// once per lane.
	///
	/// A single-lane packed type calls its own `MulAlpha::mul_alpha` directly, so
	/// the strategy remains benchmarkable for it.
	#[inline]
	fn mul_alpha(self) -> Self {
		match PT::WIDTH {
			1 => MulAlpha::mul_alpha(self),
			_ => Self::from_fn(|lane| {
				let scalar = self.get(lane);
				MulAlpha::mul_alpha(scalar)
			}),
		}
	}
}

/// Transformation of a packed value that applies a wrapped scalar transformation
/// to each element (lane) independently.
pub struct PairwiseTransformation<I> {
	/// The transformation applied to every individual lane.
	inner: I,
}

impl<I> PairwiseTransformation<I> {
	/// Creates a pairwise transformation that applies `inner` lane by lane.
	pub fn new(inner: I) -> Self {
		let wrapped = PairwiseTransformation { inner };
		wrapped
	}
}

impl<IP, OP, IF, OF, I> Transformation<IP, OP> for PairwiseTransformation<I>
where
	IP: PackedField<Scalar = IF>,
	OP: PackedField<Scalar = OF>,
	I: Transformation<IF, OF>,
{
	/// Builds the output packed value lane by lane. Output lane `k` is the inner
	/// transformation applied to input lane `k`.
	fn transform(&self, data: &IP) -> OP {
		let inner = &self.inner;
		OP::from_fn(|lane| {
			let input_scalar: IF = data.get(lane);
			inner.transform(&input_scalar)
		})
	}
}

impl<IP, OP> TaggedPackedTransformationFactory<PairwiseStrategy, OP> for IP
where
	IP: PackedBinaryField,
	OP: PackedBinaryField,
{
	/// The pairwise strategy wraps the scalar linear transformation so that it is
	/// applied independently to each lane.
	type PackedTransformation<Data: Deref<Target = [OP::Scalar]>> =
		PairwiseTransformation<FieldLinearTransformation<OP::Scalar, Data>>;

	/// Wraps `transformation` in a [`PairwiseTransformation`] without any
	/// precomputation.
	fn make_packed_transformation<Data>(
		transformation: FieldLinearTransformation<OP::Scalar, Data>,
	) -> Self::PackedTransformation<Data>
	where
		Data: Deref<Target = [OP::Scalar]>,
	{
		let per_lane = transformation;
		PairwiseTransformation::new(per_lane)
	}
}