binius_field/arch/portable/byte_sliced/
invert.rs

1// Copyright 2024-2025 Irreducible Inc.
2use super::{
3	multiply::{mul_alpha, mul_main},
4	square::square_main,
5};
6use crate::{
7	tower_levels::{TowerLevel, TowerLevelWithArithOps},
8	underlier::WithUnderlier,
9	AESTowerField8b, PackedField,
10};
11
12#[inline(always)]
13pub fn invert_or_zero<P: PackedField<Scalar = AESTowerField8b>, Level: TowerLevel>(
14	field_element: &Level::Data<P>,
15	destination: &mut Level::Data<P>,
16) {
17	let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3));
18
19	inv_main::<P, Level>(field_element, destination, base_alpha);
20}
21
22#[inline(always)]
23fn inv_main<P: PackedField<Scalar = AESTowerField8b>, Level: TowerLevel>(
24	field_element: &Level::Data<P>,
25	destination: &mut Level::Data<P>,
26	base_alpha: P,
27) {
28	if Level::WIDTH == 1 {
29		destination.as_mut()[0] = field_element.as_ref()[0].invert_or_zero();
30		return;
31	}
32
33	let (a0, a1) = Level::split(field_element);
34
35	let (result0, result1) = Level::split_mut(destination);
36
37	let mut intermediate = <<Level as TowerLevel>::Base as TowerLevel>::default();
38
39	// intermediate = subfield_alpha*a1
40	mul_alpha::<true, P, Level::Base>(a1, &mut intermediate, base_alpha);
41
42	// intermediate = a0 + subfield_alpha*a1
43	Level::Base::add_into(a0, &mut intermediate);
44
45	let mut delta = <<Level as TowerLevel>::Base as TowerLevel>::default();
46
47	// delta = intermediate * a0
48	mul_main::<true, P, Level::Base>(&intermediate, a0, &mut delta, base_alpha);
49
50	// delta = intermediate * a0 + a1^2
51	square_main::<false, P, Level::Base>(a1, &mut delta, base_alpha);
52
53	let mut delta_inv = <<Level as TowerLevel>::Base as TowerLevel>::default();
54
55	// delta_inv = 1/delta
56	inv_main::<P, Level::Base>(&delta, &mut delta_inv, base_alpha);
57
58	// result0 = delta_inv*intermediate
59	mul_main::<true, P, Level::Base>(&delta_inv, &intermediate, result0, base_alpha);
60
61	// result1 = delta_inv*intermediate
62	mul_main::<true, P, Level::Base>(&delta_inv, a1, result1, base_alpha);
63}