Skip to content

Commit f456aa8

Browse files
committed
Refactor the fma modules
Move implementations to `generic/` like the other functions. This also allows us to combine the `fma` and `fma_wide` modules.
1 parent 91963f5 commit f456aa8

File tree

6 files changed

+179
-175
lines changed

6 files changed

+179
-175
lines changed

etc/function-definitions.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@
350350
"fmaf": {
351351
"sources": [
352352
"libm/src/math/arch/aarch64.rs",
353-
"libm/src/math/fma_wide.rs"
353+
"libm/src/math/fma.rs"
354354
],
355355
"type": "f32"
356356
},

libm/src/math/fma.rs

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/* SPDX-License-Identifier: MIT */
2+
/* origin: musl src/math/fma.c, fmaf.c Ported to generic Rust algorithm in 2025, TG. */
3+
4+
use super::generic;
5+
use crate::support::Round;
6+
7+
// Placeholder so we can have `fmaf16` in the `Float` trait.
8+
#[allow(unused)]
9+
#[cfg(f16_enabled)]
10+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
11+
pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
12+
unimplemented!()
13+
}
14+
15+
/// Floating multiply add (f32)
16+
///
17+
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
18+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
19+
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
20+
select_implementation! {
21+
name: fmaf,
22+
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
23+
args: x, y, z,
24+
}
25+
26+
generic::fma_wide_round(x, y, z, Round::Nearest).val
27+
}
28+
29+
/// Fused multiply add (f64)
30+
///
31+
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
32+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
33+
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
34+
select_implementation! {
35+
name: fma,
36+
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
37+
args: x, y, z,
38+
}
39+
40+
generic::fma_round(x, y, z, Round::Nearest).val
41+
}
42+
43+
/// Fused multiply add (f128)
44+
///
45+
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
46+
#[cfg(f128_enabled)]
47+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
48+
pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 {
49+
generic::fma_round(x, y, z, Round::Nearest).val
50+
}
51+
52+
#[cfg(test)]
53+
mod tests {
54+
use super::*;
55+
use crate::support::{CastFrom, CastInto, Float, FpResult, HInt, MinInt, Round, Status};
56+
57+
/// Test the generic `fma_round` algorithm for a given float.
58+
fn spec_test<F>(f: impl Fn(F, F, F) -> F)
59+
where
60+
F: Float,
61+
F: CastFrom<F::SignedInt>,
62+
F: CastFrom<i8>,
63+
F::Int: HInt,
64+
u32: CastInto<F::Int>,
65+
{
66+
let x = F::from_bits(F::Int::ONE);
67+
let y = F::from_bits(F::Int::ONE);
68+
let z = F::ZERO;
69+
70+
// 754-2020 says "When the exact result of (a × b) + c is non-zero yet the result of
71+
// fusedMultiplyAdd is zero because of rounding, the zero result takes the sign of the
72+
// exact result"
73+
assert_biteq!(f(x, y, z), F::ZERO);
74+
assert_biteq!(f(x, -y, z), F::NEG_ZERO);
75+
assert_biteq!(f(-x, y, z), F::NEG_ZERO);
76+
assert_biteq!(f(-x, -y, z), F::ZERO);
77+
}
78+
79+
#[test]
80+
fn spec_test_f32() {
81+
spec_test::<f32>(fmaf);
82+
83+
// Also do a small check that the non-widening version works for f32 (this should ideally
84+
// get tested some more).
85+
spec_test::<f32>(|x, y, z| generic::fma_round(x, y, z, Round::Nearest).val);
86+
}
87+
88+
#[test]
89+
fn spec_test_f64() {
90+
spec_test::<f64>(fma);
91+
92+
let expect_underflow = [
93+
(
94+
hf64!("0x1.0p-1070"),
95+
hf64!("0x1.0p-1070"),
96+
hf64!("0x1.ffffffffffffp-1023"),
97+
hf64!("0x0.ffffffffffff8p-1022"),
98+
),
99+
(
100+
// FIXME: we raise underflow but this should only be inexact (based on C and
101+
// `rustc_apfloat`).
102+
hf64!("0x1.0p-1070"),
103+
hf64!("0x1.0p-1070"),
104+
hf64!("-0x1.0p-1022"),
105+
hf64!("-0x1.0p-1022"),
106+
),
107+
];
108+
109+
for (x, y, z, res) in expect_underflow {
110+
let FpResult { val, status } = generic::fma_round(x, y, z, Round::Nearest);
111+
assert_biteq!(val, res);
112+
assert_eq!(status, Status::UNDERFLOW);
113+
}
114+
}
115+
116+
#[test]
117+
#[cfg(f128_enabled)]
118+
fn spec_test_f128() {
119+
spec_test::<f128>(fmaf128);
120+
}
121+
122+
#[test]
123+
fn issue_263() {
124+
let a = f32::from_bits(1266679807);
125+
let b = f32::from_bits(1300234242);
126+
let c = f32::from_bits(1115553792);
127+
let expected = f32::from_bits(1501560833);
128+
assert_eq!(fmaf(a, b, c), expected);
129+
}
130+
131+
#[test]
132+
fn fma_segfault() {
133+
// These two inputs cause fma to segfault on release due to overflow:
134+
assert_eq!(
135+
fma(
136+
-0.0000000000000002220446049250313,
137+
-0.0000000000000002220446049250313,
138+
-0.0000000000000002220446049250313
139+
),
140+
-0.00000000000000022204460492503126,
141+
);
142+
143+
let result = fma(-0.992, -0.992, -0.992);
144+
//force rounding to storage format on x87 to prevent superious errors.
145+
#[cfg(all(target_arch = "x86", not(target_feature = "sse2")))]
146+
let result = force_eval!(result);
147+
assert_eq!(result, -0.007936000000000007,);
148+
}
149+
150+
#[test]
151+
fn fma_sbb() {
152+
assert_eq!(
153+
fma(-(1.0 - f64::EPSILON), f64::MIN, f64::MIN),
154+
-3991680619069439e277
155+
);
156+
}
157+
158+
#[test]
159+
fn fma_underflow() {
160+
assert_eq!(
161+
fma(1.1102230246251565e-16, -9.812526705433188e-305, 1.0894e-320),
162+
0.0,
163+
);
164+
}
165+
}

libm/src/math/generic/fma.rs

+4-129
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,9 @@
11
/* SPDX-License-Identifier: MIT */
22
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
33

4-
use super::support::{DInt, FpResult, HInt, IntTy, Round, Status};
5-
use super::{CastFrom, CastInto, Float, Int, MinInt};
6-
7-
/// Fused multiply add (f64)
8-
///
9-
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
10-
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
11-
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
12-
select_implementation! {
13-
name: fma,
14-
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
15-
args: x, y, z,
16-
}
17-
18-
fma_round(x, y, z, Round::Nearest).val
19-
}
20-
21-
/// Fused multiply add (f128)
22-
///
23-
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
24-
#[cfg(f128_enabled)]
25-
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
26-
pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 {
27-
fma_round(x, y, z, Round::Nearest).val
28-
}
4+
use crate::support::{
5+
CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status,
6+
};
297

308
/// Fused multiply-add that works when there is not a larger float size available. Computes
319
/// `(x * y) + z`.
@@ -234,7 +212,7 @@ where
234212
}
235213

236214
// Use our exponent to scale the final value.
237-
FpResult::new(super::generic::scalbn(r, e), status)
215+
FpResult::new(super::scalbn(r, e), status)
238216
}
239217

240218
/// Representation of `F` that has handled subnormals.
@@ -298,106 +276,3 @@ impl<F: Float> Norm<F> {
298276
self.e > Self::ZERO_INF_NAN as i32
299277
}
300278
}
301-
302-
#[cfg(test)]
303-
mod tests {
304-
use super::*;
305-
306-
/// Test the generic `fma_round` algorithm for a given float.
307-
fn spec_test<F>()
308-
where
309-
F: Float,
310-
F: CastFrom<F::SignedInt>,
311-
F: CastFrom<i8>,
312-
F::Int: HInt,
313-
u32: CastInto<F::Int>,
314-
{
315-
let x = F::from_bits(F::Int::ONE);
316-
let y = F::from_bits(F::Int::ONE);
317-
let z = F::ZERO;
318-
319-
let fma = |x, y, z| fma_round(x, y, z, Round::Nearest).val;
320-
321-
// 754-2020 says "When the exact result of (a × b) + c is non-zero yet the result of
322-
// fusedMultiplyAdd is zero because of rounding, the zero result takes the sign of the
323-
// exact result"
324-
assert_biteq!(fma(x, y, z), F::ZERO);
325-
assert_biteq!(fma(x, -y, z), F::NEG_ZERO);
326-
assert_biteq!(fma(-x, y, z), F::NEG_ZERO);
327-
assert_biteq!(fma(-x, -y, z), F::ZERO);
328-
}
329-
330-
#[test]
331-
fn spec_test_f32() {
332-
spec_test::<f32>();
333-
}
334-
335-
#[test]
336-
fn spec_test_f64() {
337-
spec_test::<f64>();
338-
339-
let expect_underflow = [
340-
(
341-
hf64!("0x1.0p-1070"),
342-
hf64!("0x1.0p-1070"),
343-
hf64!("0x1.ffffffffffffp-1023"),
344-
hf64!("0x0.ffffffffffff8p-1022"),
345-
),
346-
(
347-
// FIXME: we raise underflow but this should only be inexact (based on C and
348-
// `rustc_apfloat`).
349-
hf64!("0x1.0p-1070"),
350-
hf64!("0x1.0p-1070"),
351-
hf64!("-0x1.0p-1022"),
352-
hf64!("-0x1.0p-1022"),
353-
),
354-
];
355-
356-
for (x, y, z, res) in expect_underflow {
357-
let FpResult { val, status } = fma_round(x, y, z, Round::Nearest);
358-
assert_biteq!(val, res);
359-
assert_eq!(status, Status::UNDERFLOW);
360-
}
361-
}
362-
363-
#[test]
364-
#[cfg(f128_enabled)]
365-
fn spec_test_f128() {
366-
spec_test::<f128>();
367-
}
368-
369-
#[test]
370-
fn fma_segfault() {
371-
// These two inputs cause fma to segfault on release due to overflow:
372-
assert_eq!(
373-
fma(
374-
-0.0000000000000002220446049250313,
375-
-0.0000000000000002220446049250313,
376-
-0.0000000000000002220446049250313
377-
),
378-
-0.00000000000000022204460492503126,
379-
);
380-
381-
let result = fma(-0.992, -0.992, -0.992);
382-
//force rounding to storage format on x87 to prevent superious errors.
383-
#[cfg(all(target_arch = "x86", not(target_feature = "sse2")))]
384-
let result = force_eval!(result);
385-
assert_eq!(result, -0.007936000000000007,);
386-
}
387-
388-
#[test]
389-
fn fma_sbb() {
390-
assert_eq!(
391-
fma(-(1.0 - f64::EPSILON), f64::MIN, f64::MIN),
392-
-3991680619069439e277
393-
);
394-
}
395-
396-
#[test]
397-
fn fma_underflow() {
398-
assert_eq!(
399-
fma(1.1102230246251565e-16, -9.812526705433188e-305, 1.0894e-320),
400-
0.0,
401-
);
402-
}
403-
}

libm/src/math/generic/fma_wide.rs

+3-41
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,6 @@
1-
/* SPDX-License-Identifier: MIT */
2-
/* origin: musl src/math/fmaf.c. Ported to generic Rust algorithm in 2025, TG. */
3-
4-
use super::support::{FpResult, IntTy, Round, Status};
5-
use super::{CastFrom, CastInto, DFloat, Float, HFloat, MinInt};
6-
7-
// Placeholder so we can have `fmaf16` in the `Float` trait.
8-
#[allow(unused)]
9-
#[cfg(f16_enabled)]
10-
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
11-
pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
12-
unimplemented!()
13-
}
14-
15-
/// Floating multiply add (f32)
16-
///
17-
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
18-
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
19-
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
20-
select_implementation! {
21-
name: fmaf,
22-
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
23-
args: x, y, z,
24-
}
25-
26-
fma_wide_round(x, y, z, Round::Nearest).val
27-
}
1+
use crate::support::{
2+
CastFrom, CastInto, DFloat, Float, FpResult, HFloat, IntTy, MinInt, Round, Status,
3+
};
284

295
/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
306
/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding.
@@ -95,17 +71,3 @@ where
9571

9672
FpResult::ok(B::from_bits(ui).narrow())
9773
}
98-
99-
#[cfg(test)]
100-
mod tests {
101-
use super::*;
102-
103-
#[test]
104-
fn issue_263() {
105-
let a = f32::from_bits(1266679807);
106-
let b = f32::from_bits(1300234242);
107-
let c = f32::from_bits(1115553792);
108-
let expected = f32::from_bits(1501560833);
109-
assert_eq!(fmaf(a, b, c), expected);
110-
}
111-
}

libm/src/math/generic/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mod copysign;
66
mod fabs;
77
mod fdim;
88
mod floor;
9+
mod fma;
10+
mod fma_wide;
911
mod fmax;
1012
mod fmaximum;
1113
mod fmaximum_num;
@@ -24,6 +26,8 @@ pub use copysign::copysign;
2426
pub use fabs::fabs;
2527
pub use fdim::fdim;
2628
pub use floor::floor;
29+
pub use fma::fma_round;
30+
pub use fma_wide::fma_wide_round;
2731
pub use fmax::fmax;
2832
pub use fmaximum::fmaximum;
2933
pub use fmaximum_num::fmaximum_num;

0 commit comments

Comments
 (0)