我应该如何在软件中实现通用的 FMA/FMAF 指令?



FMA是一个融合的乘加指令。glibc中的fmaf (float x, float y, float z)函数调用vfmadd213ss指令。我想知道这个指令是如何实现的。据我了解:

  1. 将 x 和 y 的指数相加。
  2. 将 x 和 y 的尾数相乘。
  3. 对 x*y 的结果进行规格化,但不舍入。
  4. 与 z 的指数比较,并对指数较小者的尾数进行移位对齐。
  5. 将尾数相加,并再次对结果进行规格化。
  6. 舍入(rn,就近舍入)。

当前的 x86-64 架构实现了所谓的 FMA3 变体:由于融合乘加运算需要三个源操作数,而实现它的指令总共只有三个操作数,因此必须指定哪个源操作数同时也是目标操作数:vfmadd132ss、vfmadd213ss、vfmadd231ss。

在数学功能方面,这些指令都是等价的,都通过单次舍入计算 a*b+c,其中 a、b 和 c 是 IEEE-754 binary32(单精度)操作数,在 C 和 C++ 系列的编程语言中通常映射到 float。

问题中给出的顶层算法大纲是正确的。下面的代码演示了如何在以下限制下实现所有必要的细节:关闭 IEEE-754 浮点异常(即代码提供 IEEE-754 标准规定的屏蔽响应),并打开次正规数(subnormal)支持(许多平台,包括 x86-64,也支持非标准的"刷新到零"(flush-to-zero)和"非正规数视为零"(denormals-are-zero)模式)。当向 FMA 提供多个 NaN 源操作数时,结果可以以其中任何一个或规范 NaN 为基础。在下面的代码中,我只是简单地匹配了我工作站中至强 W2133 CPU 的行为;其他处理器可能需要进行调整。

下面的代码是一个普通的 FMA 仿真,编写时兼顾了合理的性能和合理的清晰度。如果平台提供 CLZ(计数前导零)或相关指令,通过内部函数(intrinsic)来使用它会有所帮助。正确的舍入在很大程度上取决于对舍入位(round bit)和粘滞位(sticky bit)的正确跟踪。在硬件中,这些通常是两个实际的位,但对于软件仿真,使用一个完整的无符号整数(下面代码中的 rndstk)通常更方便,其中最高有效位表示舍入位,所有其余的低位合在一起(即按位或在一起)表示粘滞位。

对于实用的(更快的)fmaf() 仿真,通常依赖于在 IEEE-754 binary64(双精度)中执行中间计算。这会导致棘手的双重舍入(double rounding)问题,并且并非所有常见开源库中的实现都能正常工作。文献中已知的最佳方法是在中间计算中使用一种特殊的舍入模式:舍入到奇数(round-to-odd)。参见:

Sylvie Boldo 和 Guillaume Melquiond,"模拟 FMA 和正确舍入的总和:使用舍入到奇数的证明算法",IEEE Transactions on Computers,第 57 卷,第 4 期,2008 年 2 月,第 462-471 页。

严格测试 FMA 实现,无论是硬件还是软件,都是一个难题。由于搜索空间巨大,简单地使用大量(例如数百亿)随机测试向量只能提供"冒烟"测试,这对于证明实现没有严重损坏很有用。下面我还添加了基于模式的测试,这些测试能够覆盖许多极端情况。尽管如此,下面的代码只应被视为经过了轻度测试。

如果您打算以任何专业身份使用 FMA 仿真,我强烈建议您投入大量时间来确保功能正确性;坦率地说,已经有太多损坏的 FMA 仿真。对于工业强度的实施,硬件供应商采用经过机械检查的正确操作数学证明。这本书由一位经验丰富的从业者撰写,很好地概述了在实践中如何运作:

David M. Russinoff,浮点设计的形式验证:一种数学方法,Springer 2019。

还有用于处理许多极端情况的专业测试套件,例如来自海法 IBM 研究实验室的 FPgen 浮点测试生成器。他们曾经在他们的网站上免费提供单精度测试向量的集合,但我似乎再也找不到它们了。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <limits.h>
#include <math.h>
/* Test-vector generation strategies that TEST_MODE can select */
#define PURELY_RANDOM  (0)
#define PATTERN_BASED  (1)
/* Strategy compiled into main() via the #if TEST_MODE == ... checks */
#define TEST_MODE      (PURELY_RANDOM)
/* Rounding mode exercised by main(); one of the round* aliases defined below */
#define ROUND_MODE     (roundNearest) 
/* vvvvvvvvvvvvvvvvvvvvvvvvvvv x86-64 specific vvvvvvvvvvvvvvvvvvvvvvvvvvvv */
#include "immintrin.h"
/* Platform-neutral names for the rounding-mode constants; ref_fmaf() hands
   these to _MM_SET_ROUNDING_MODE and fmaf_kernel() compares 'mode' to them */
#define roundMinInf  (_MM_ROUND_DOWN)
#define roundPosInf  (_MM_ROUND_UP)
#define roundZero    (_MM_ROUND_TOWARD_ZERO)
#define roundNearest (_MM_ROUND_NEAREST)
/* "Off" settings for flush-to-zero and denormals-are-zero; main() passes
   these to set_subnormal_support() */
#define ftzOff       (_MM_FLUSH_ZERO_OFF)
#define dazOff       (_MM_DENORMALS_ZERO_OFF)
/*
 * Configure how the SSE unit treats subnormal numbers.
 *
 * ftz: flush-to-zero setting (_MM_FLUSH_ZERO_ON or _MM_FLUSH_ZERO_OFF)
 * daz: denormals-are-zero setting (_MM_DENORMALS_ZERO_ON or
 *      _MM_DENORMALS_ZERO_OFF)
 *
 * Each argument must reach the macro for its own control bit. The two macros
 * mask different MXCSR bits, so feeding 'ftz' to the DAZ macro (and vice
 * versa) only happens to work when both values are the all-zero "off"
 * constants.
 */
void set_subnormal_support (uint32_t ftz, uint32_t daz)
{
    _MM_SET_FLUSH_ZERO_MODE (ftz);
    _MM_SET_DENORMALS_ZERO_MODE (daz);
}
/*
 * Reference result for a*b+c, obtained from the hardware scalar fused
 * multiply-add intrinsic _mm_fmadd_ss.
 *
 * a, b, c: operands
 * rnd:     rounding mode handed to _MM_SET_ROUNDING_MODE (one of the round*
 *          aliases above)
 * returns: the value stored from the intrinsic's result by _mm_store_ss
 *
 * The MXCSR value read on entry is written back before returning, so the
 * rounding-mode change does not outlive the call.
 *
 * NOTE(review): nothing in this function tells the compiler that the FMA
 * depends on MXCSR. Confirm the build does not move or constant-fold the
 * operation across the two control-register writes.
 */
float ref_fmaf (float a, float b, float c, uint32_t rnd)
{
__m128 r, s, t, u;
float res;
uint32_t old_mxcsr;
/* remember the caller's control/status word */
old_mxcsr = _mm_getcsr();
_MM_SET_ROUNDING_MODE (rnd);
s = _mm_set_ss (a);
t = _mm_set_ss (b);
u = _mm_set_ss (c);
r = _mm_fmadd_ss (s, t, u);
_mm_store_ss (&res, r);
/* put back the control/status word saved on entry */
_mm_setcsr (old_mxcsr);
return res;
}
/* ^^^^^^^^^^^^^^^^^^^^^^^^^^^ x86-64 specific ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ */
/* Return the object representation of a 'float' (IEEE-754 'binary32' on the
   targeted platforms) as an unsigned 32-bit integer, without any numeric
   conversion. memcpy keeps the type pun free of aliasing problems. */
uint32_t float_as_uint32 (float a)
{
    uint32_t bits;

    memcpy (&bits, &a, sizeof bits);
    return bits;
}
/* re-interpret bits of unsigned 32-bit integer as IEEE-754 'binary32' */
float uint32_as_float (uint32_t a)
{
float r;
memcpy (&r, &a, sizeof r);
return r;
}
/* Count the leading zero bits of a 32-bit value; returns 32 for an input of
   zero. Portable fallback: substitute a platform intrinsic where available. */
int clz32 (uint32_t a)
{
    int zeros = 32;

    /* binary search: whenever the upper half is populated, discard the
       lower half and credit its width */
    if (a >> 16) { zeros -= 16;  a >>= 16; }
    if (a >>  8) { zeros -=  8;  a >>=  8; }
    if (a >>  4) { zeros -=  4;  a >>=  4; }
    if (a >>  2) { zeros -=  2;  a >>=  2; }
    /* a is now in [0,3]: values 2 and 3 occupy two bits, 1 occupies one */
    zeros -= (a >= 2u) ? 2 : (int)a;
    return zeros;
}
/* Count the leading zero bits of a 64-bit value by delegating to clz32 on
   the relevant half; returns 64 for an input of zero. Substitute a platform
   intrinsic where available. */
int clz64 (uint64_t a)
{
    const uint32_t upper = (uint32_t)(a >> 32);

    if (upper != 0) {
        return clz32 (upper);
    }
    /* upper half is all zeros: 32 of them plus those in the lower half */
    return 32 + clz32 ((uint32_t)(a & 0xffffffff));
}
/* Full 64-bit product of two unsigned 32-bit integers. Widening both factors
   first keeps the upper half of the product. May be replaced by a platform
   intrinsic. */
uint64_t mul_u32_wide (uint32_t a, uint32_t b)
{
    const uint64_t wide_a = a;
    const uint64_t wide_b = b;

    return wide_a * wide_b;
}
/*
 * Integer-only emulation of single-precision fused multiply-add, x*y+z.
 *
 * x, y, z: operands given as IEEE-754 'binary32' bit patterns
 * mode:    rounding mode, compared against roundNearest, roundPosInf,
 *          roundMinInf and roundZero
 * returns: bit pattern of the result
 *
 * Conventions used throughout:
 *  - Exponents are held as "biased exponent minus 1". When the result is
 *    assembled, the significand still contains its integer bit (bit 23), and
 *    adding it to (expo << 23) supplies the missing 1.
 *  - Significands of the product and the sum live in 64 bits with the
 *    integer bit at bit 55: the upper 32 bits are the 24-bit significand,
 *    the lower 32 bits feed rounding.
 *  - rndstk: most significant bit is the round bit; all lower bits together
 *    (non-zero or zero) act as the sticky bit.
 */
uint32_t fmaf_kernel (uint32_t x, uint32_t y, uint32_t z, int mode)     
{     
const uint32_t FP32_SIGN_BIT = 0x80000000;
const uint32_t FP32_POS_ZERO = 0x00000000;
const uint32_t FP32_NEG_ZERO = 0x80000000;
const uint32_t FP32_QNAN_BIT = 0x00400000;
const uint32_t FP32_INT_BIT  = 0x00800000;
const uint32_t FP32_EXPO_MASK = 0x7f800000;
const uint32_t FP32_POS_INFINITY = 0x7f800000;
const uint32_t FP32_NEG_INFINITY = 0xff800000;
const uint32_t FP32_POS_MAX_NORMAL = 0x7f7fffff;
const uint32_t FP32_NEG_MAX_NORMAL = 0xff7fffff;
const uint32_t FP32_QNAN_INDEFINITE = 0xffc00000;
const uint32_t FP32_EXPO_BIAS = 127;
const uint32_t FP32_STORED_MANT_BITS = 23;
const uint32_t FP32_EXPO_BITS = 8;
const uint32_t FP32_MAX_NORM_EXPO_M1 = 254 - 1;
uint64_t mant_p, templl;
uint32_t mant_x, mant_y, mant_z, mant_r;
uint32_t expo_x, expo_y, expo_z, expo_r, expo_p;
uint32_t sign_z, sign_p, sign_r;
uint32_t r, shift, lz, rndstk, z_zer, temp;

/* Extract exponent fields, minus 1. A field of 0 (zero or subnormal) wraps
   to 0xffffffff and a field of 255 (infinity or NaN) becomes 254, so both
   fail the single unsigned "<= FP32_MAX_NORM_EXPO_M1" test below. */
expo_x = ((x & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
expo_y = ((y & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
expo_z = ((z & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
/* z is +0 or -0 (the shift discards the sign bit) */
z_zer = (z << 1) == 0x00000000;

/* slow path: at least one operand is zero, subnormal, infinity, or NaN */
if (!((expo_x <= FP32_MAX_NORM_EXPO_M1) &&
(expo_y <= FP32_MAX_NORM_EXPO_M1) &&
(expo_z <= FP32_MAX_NORM_EXPO_M1))) {
/* classify with the sign bit shifted out */
uint32_t x_nan = (x << 1) >  0xff000000;
uint32_t y_nan = (y << 1) >  0xff000000;
uint32_t z_nan = (z << 1) >  0xff000000;
uint32_t x_inf = (x << 1) == 0xff000000;
uint32_t y_inf = (y << 1) == 0xff000000;
uint32_t z_inf = (z << 1) == 0xff000000;
uint32_t x_zer = (x << 1) == 0x00000000;
uint32_t y_zer = (y << 1) == 0x00000000;

/* pass-through quietened NaN arguments; priority is y, then x, then z */
if (y_nan) {
return y | FP32_QNAN_BIT;
}
if (x_nan) {
return x | FP32_QNAN_BIT;
}
if (z_nan) {
return z | FP32_QNAN_BIT;
}
/* invalid operations, based on zeros and infinities: 0 * inf, or an
   infinite product added to an infinite z of opposite sign */
if (((x_zer && y_inf) || (y_zer && x_inf)) ||
(z_inf && (x_inf || y_inf) && ((int32_t)(x ^ y ^ z) < 0))) {
return FP32_QNAN_INDEFINITE;
}
/* infinity results; an infinite factor takes the sign of the product */
if (x_inf) {
return x ^ (y & FP32_SIGN_BIT);
}
if (y_inf) {
return y ^ (x & FP32_SIGN_BIT);
}
if (z_inf) {
return z;
}
/* results of negative zero: negative zero product plus negative zero z */
if ((z == FP32_NEG_ZERO) &&
(x_zer || y_zer) && ((int32_t)(x ^ y) < 0)) {
return z;
}
/* zero results: zero product plus zero z in the remaining sign
   combinations. In round-down mode the sign is the XOR of all three operand
   signs; in every other mode the result is +0 */
if (z_zer && (x_zer || y_zer)) {
return ((mode == roundMinInf) ?
((x ^ y ^ z) & FP32_SIGN_BIT) : (z & ~FP32_SIGN_BIT));
}
/* product x*y is zero: pass-through z */
if (x_zer || y_zer) {
return z;
}
/* normalize x if subnormal: shift the leading 1 up to bit 23 (the
   integer-bit position) and lower the exponent by the shift distance */
if (expo_x == (uint32_t)-1) {    
temp = x << FP32_EXPO_BITS;
lz = clz32 (temp);
temp = temp << lz;
expo_x = expo_x - lz + 1;
x = (temp >> FP32_EXPO_BITS) | (x & FP32_SIGN_BIT);
}
/* normalize y if subnormal */
if (expo_y == (uint32_t)-1) {
temp = y << FP32_EXPO_BITS;
lz = clz32 (temp);
temp = temp << lz;
expo_y = expo_y - lz + 1;
y = (temp >> FP32_EXPO_BITS) | (y & FP32_SIGN_BIT);
}
/* normalize z if subnormal */
if ((expo_z == (uint32_t)-1) && (!z_zer)) {
temp = z << FP32_EXPO_BITS;
lz = clz32 (temp);
temp = temp << lz;
expo_z = expo_z - lz + 1;
z = (temp >> FP32_EXPO_BITS) | (z & FP32_SIGN_BIT);
}
}
/* multiply x * y: a 24-bit significand times a significand left-aligned in
   32 bits. The +2 cancels the two "minus 1" offsets of the exponents. */
expo_p = expo_x + expo_y - FP32_EXPO_BIAS + 2;
sign_p = (x ^ y) & FP32_SIGN_BIT;
mant_x = (x & 0x00ffffff) | FP32_INT_BIT;
mant_y = (y << 8) | (FP32_INT_BIT << 8);
mant_p = mul_u32_wide (mant_x, mant_y);

/* normalize product x*y so that its leading 1 sits at bit 55 */
if (!(mant_p & ((uint64_t)FP32_INT_BIT << 32))) {
mant_p = mant_p << 1;
expo_p--;
}
/* add z to product x*y */
if (z_zer) {
/* nothing to add: the product alone is the unrounded result */
expo_r = expo_p;
sign_r = sign_p;
mant_r = (uint32_t)(mant_p >> 32);
rndstk = (uint32_t)(mant_p);
} else {
sign_z = z & FP32_SIGN_BIT;
mant_z = (z & 0x00ffffff) | FP32_INT_BIT;
/* z's significand placed in the same 64-bit layout as the product */
uint64_t large, small, mant_z_ext = (uint64_t)mant_z << 32;
/* sort summands by magnitude: exponent first, then significand */
if (((int)expo_p > (int)expo_z) ||
((expo_p == expo_z) && (mant_p > mant_z_ext))) {
expo_r = expo_p;
sign_r = sign_p;
large = mant_p;
small = mant_z_ext;
shift = expo_p - expo_z;
} else {
expo_r = expo_z;
sign_r = sign_z;
large = mant_z_ext;
small = mant_p;
shift = expo_z - expo_p;
}
/* denormalize small: align it to the exponent of large, collecting the
   bits shifted out of the 64-bit word in rndstk */
if (shift == 0) {
rndstk = 0;
} else if (shift > 63) {
rndstk = 1; // only sticky
small = 0;
} else {
templl = small << (64 - shift);
rndstk = (uint32_t)(templl >> 32) | (((uint32_t)templl) ? 1 : 0);
small = small >> shift;
}
/* add or subtract significands */
if (sign_p != sign_z) {
/* effective subtraction; borrow one unit when bits of small were
   shifted out, since the true subtrahend is slightly larger */
large = large - small - (rndstk ? 1 : 0);
/* complete cancellation: return 0 (negative zero only when rounding down) */
if (large == 0) {
return (mode == roundMinInf) ? FP32_NEG_ZERO : FP32_POS_ZERO;
}
/* normalize mantissa if necessary: bring the leading 1 back to bit 55 */
if (!(large & ((uint64_t)FP32_INT_BIT << 32))) {
lz = clz64 (large);
shift = lz - 8;
large = large << shift;
expo_r = expo_r - shift;
}
} else {
/* effective addition */
large = large + small;
/* normalize mantissa if necessary: a carry into bit 56 needs one right
   shift, and the bit shifted out joins rndstk */
if (large & 0x0100000000000000ULL) {
templl = large << 63;
rndstk = (uint32_t)(templl >> 32) | (rndstk ? 1 : 0);
large = large >> 1;
expo_r++;
}
}
/* split into 24-bit significand and round/sticky word; anything collected
   earlier is folded into the lowest (sticky) position */
mant_r = (uint32_t)(large >> 32);
rndstk = (uint32_t)(large) | (rndstk ? 1 : 0);
}
/* round result */
if (expo_r <= FP32_MAX_NORM_EXPO_M1) { // normal
if (mode == roundNearest) {
/* exact tie (round bit set, sticky clear) rounds to even; otherwise the
   round bit decides */
mant_r += (rndstk == 0x80000000) ? (mant_r & 1) : (rndstk >> 31);
} else if (mode == roundPosInf) {
mant_r += rndstk && !sign_r;
} else if (mode == roundMinInf) {
mant_r += rndstk && sign_r;
} else { // mode == roundZero
}
/* the integer bit of mant_r adds the 1 that was subtracted from the
   exponent */
r = sign_r + mant_r + (expo_r << 23);
return r;
} else if ((int32_t)expo_r >= 0) { // overflow: largest normal or infinity
if (mode == roundNearest) {
r = sign_r | FP32_POS_INFINITY;
} else if (mode == roundZero) {
r = sign_r | FP32_POS_MAX_NORMAL;
} else if (mode == roundPosInf) {
r = sign_r ? FP32_NEG_MAX_NORMAL : FP32_POS_INFINITY;
} else { // (mode == roundMinInf)
r = sign_r ? FP32_NEG_INFINITY : FP32_POS_MAX_NORMAL;
}
return r;
} else { /* underflow: smallest normal, subnormal, or zero */
/* expo_r is negative when read as signed; shift the significand right by
   that amount and fold the bits shifted out into rndstk. Beyond 25
   positions nothing of the significand remains and only sticky is left */
shift = 0 - expo_r;
rndstk = (shift > 25) ? 1 : ((mant_r << (32 - shift)) | (rndstk ? 1 : 0));
mant_r = (shift > 25) ? 0 : (mant_r >> shift);
if (mode == roundNearest) {
mant_r += ((rndstk == 0x80000000) ? (mant_r & 1) : (rndstk >> 31));
} else if (mode == roundPosInf) {
mant_r += rndstk && !sign_r;
} else if (mode == roundMinInf) {
mant_r += rndstk && sign_r;
} else { // mode == roundZero
}
/* no exponent term is added here */
r = sign_r + mant_r;
}
return r;
}     
/* Float-typed front end of the software FMA: converts the operands to their
   bit patterns, runs fmaf_kernel() with rounding mode 'rnd', and converts
   the resulting bit pattern back to a float. */
float my_fmaf (float a, float b, float c, uint32_t rnd)
{
    const uint32_t a_bits = float_as_uint32 (a);
    const uint32_t b_bits = float_as_uint32 (b);
    const uint32_t c_bits = float_as_uint32 (c);
    const uint32_t result_bits = fmaf_kernel (a_bits, b_bits, c_bits, rnd);

    return uint32_as_float (result_bits);
}
/* Table of mantissa bit patterns; main() fills it before testing starts */
uint32_t v[8192];
// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
// Each macro below advances its generator's state as a side effect and
// evaluates to the new value; KISS combines the three generators.
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
#define MWC  ((znew<<16)+wnew )
/* the definition spans two lines, so the first must end in a backslash */
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
              kiss_jsr^=(kiss_jsr<<5))
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
#define KISS ((MWC^CONG)+SHR3)
/*
 * Test driver: checks my_fmaf() against ref_fmaf() bit for bit under the
 * rounding mode ROUND_MODE.
 *
 * Operands are either purely random 32-bit patterns or, in pattern-based
 * mode, random sign/exponent bits combined with a mantissa drawn from the
 * pattern table v[]. The loop runs until the first mismatch, which is
 * printed before the program exits with EXIT_FAILURE. A running count of
 * passed vectors is shown every 2**24 tests.
 */
int main (void)
{
    const uint32_t rnd = ROUND_MODE;
    unsigned long long count = 0;
    float a, b, c, res, ref;
    const uint32_t nbrBits = sizeof (uint32_t) * CHAR_BIT;
    uint32_t i, j, patterns, idx = 0;
    uint32_t ai, bi, ci, resi, refi;

    /* pattern class 1: 2**i */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = ((uint32_t)1 << i);
        idx++;
    }
    /* pattern class 2: 2**i-1 */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = (((uint32_t)1 << i) - 1);
        idx++;
    }
    /* pattern class 3: 2**i+1 */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = (((uint32_t)1 << i) + 1);
        idx++;
    }
    /* pattern class 4: 2**i + 2**j */
    for (i = 0; i < nbrBits; i++) {
        for (j = 0; j < nbrBits; j++) {
            v [idx] = (((uint32_t)1 << i) + ((uint32_t)1 << j));
            idx++;
        }
    }
    /* pattern class 5: 2**i - 2**j */
    for (i = 0; i < nbrBits; i++) {
        for (j = 0; j < nbrBits; j++) {
            v [idx] = (((uint32_t)1 << i) - ((uint32_t)1 << j));
            idx++;
        }
    }
    /* pattern class 6: MAX_UINT/(2**i+1), repeating blocks of i zeros and
       i ones */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = ((~(uint32_t)0) / (((uint32_t)1 << i) + 1));
        idx++;
    }
    /* number of entries in classes 1 through 6; both complement classes
       below are derived from exactly these entries */
    patterns = idx;
    /* pattern class 7: one's complement of pattern classes 1 through 6 */
    for (i = 0; i < patterns; i++) {
        v [idx] = ~v [i];
        idx++;
    }
    /* pattern class 8: two's complement of pattern classes 1 through 6 */
    for (i = 0; i < patterns; i++) {
        v [idx] = ~v [i] + 1;
        idx++;
    }
    patterns = idx;

    printf ("testing single-precision FMA\n");
    printf ("rounding mode: ");
    if (rnd == roundZero) {
        printf ("toward zero (truncate)\n");
    } else if (rnd == roundNearest) {
        printf ("round to nearest, ties to even\n");
    } else if (rnd == roundPosInf) {
        printf ("round up (toward positive infinity)\n");
    } else if (rnd == roundMinInf) {
        printf ("round down (toward negative infinity)\n");
    } else {
        printf ("unsupported\n");
        return EXIT_FAILURE;
    }
#if TEST_MODE == PURELY_RANDOM
    printf ("using purely random test vectors\n");
#elif TEST_MODE == PATTERN_BASED
    printf ("using pattern-based test vectors\n");
    printf ("#patterns = %u\n", patterns);
#endif // TEST_MODE

    /* make sure subnormal support is turned on */
    set_subnormal_support (ftzOff, dazOff);

    do {
#if TEST_MODE == PURELY_RANDOM
        ai = KISS;
        bi = KISS;
        ci = KISS;
#elif TEST_MODE == PATTERN_BASED
        ai = KISS;
        bi = KISS;
        ci = KISS;
        /* mantissa from a randomly chosen pattern, sign and exponent from
           fresh random bits */
        ai = ((v[ai%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
        bi = ((v[bi%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
        ci = ((v[ci%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
#endif // TEST_MODE
        a = uint32_as_float (ai);
        b = uint32_as_float (bi);
        c = uint32_as_float (ci);
        res = my_fmaf (a, b, c, rnd);
        ref = ref_fmaf (a, b, c, rnd);
        /* compare bit patterns, not values, so that NaN payloads and the
           sign of zero are checked as well */
        resi = float_as_uint32 (res);
        refi = float_as_uint32 (ref);
        if (!(resi == refi)) {
            printf ("!!!! error @ a=%08x (% 15.8e)  b=%08x (% 15.8e)  c=%08x (% 15.8e)  res = %08x (% 15.8e)  ref = %08x (% 15.8e)\n",
                    ai, a, bi, b, ci, c, resi, res, refi, ref);
            return EXIT_FAILURE;
        }
        count++;
        if (!(count & 0xffffff)) {
            /* carriage return rewrites the counter in place; no newline is
               emitted, so flush explicitly to make the progress visible */
            printf ("\r%llu", count);
            fflush (stdout);
        }
    } while (1);
    return EXIT_SUCCESS;
}

最新更新