AXV2与大于32位的源没有任何整数乘法运算。它确实提供了32 x 32->32次乘法,以及32 x 32->64次乘法1,但没有64位源。
假设我需要一个输入大于32位,但小于或等于52位的无符号乘法-我可以简单地使用浮点DP乘法或FMA指令吗?当整数输入和结果可以用52位或更少的位表示时(即,在[0,2^52-1]范围内),输出会是位精确的吗?
更一般的情况下,我想要产品的全部104位,怎么样?或者整数乘积占用超过52位的情况(即,乘积在位索引>52中具有非零值),但我只想要低52位?在后一种情况下,MUL
将给我更高的比特,并四舍五入一些较低的比特(也许这就是IFMA的帮助?)。
编辑:事实上,根据这个答案,它可能可以做任何高达2^53的事情——我忘记了尾数之前隐含的前导1
实际上又给了你一点。
1有趣的是,正如Mysticial在评论中解释的那样,64位乘积PMULDQ
操作的延迟是32位PMULLD
版本的一半,吞吐量是32位版本的两倍。
是的,这是可能的。但从AVX2开始,它不太可能比MULX/ADCX/ADOX的标量方法更好。
对于不同的输入/输出域,这种方法几乎有无限数量的变体。我只介绍其中的3个,但一旦你知道它们是如何工作的,它们就很容易概括。
免责声明:
- 这里的所有解决方案都假设舍入模式是四舍五入到偶数
- 不建议使用快速数学优化标志,因为这些解决方案依赖于严格的IEEE
范围内的有符号双打:[-251,251]
// A*B = L + H*2^52
// Input: A and B are in the range [-2^51, 2^51]
// Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256d& L, __m256d& H, __m256d A, __m256d B){
const __m256d ROUND = _mm256_set1_pd(30423614405477505635920876929024.); // 3 * 2^103
const __m256d SCALE = _mm256_set1_pd(1. / 4503599627370496); // 1 / 2^52
// Multiply and add normalization constant. This forces the multiply
// to be rounded to the correct number of bits.
H = _mm256_fmadd_pd(A, B, ROUND);
// Undo the normalization.
H = _mm256_sub_pd(H, ROUND);
// Recover the bottom half of the product.
L = _mm256_fmsub_pd(A, B, H);
// Correct the scaling of H.
H = _mm256_mul_pd(H, SCALE);
}
这是最简单的方法,也是唯一一种与标量方法有竞争力的方法。最终缩放是可选的,具体取决于您要对输出执行的操作。因此,这只能看作是3条指令。但它也是最不有用的,因为输入和输出都是浮点值。
两个FMA保持融合是至关重要的。这就是快速数学优化可以打破局面的地方。如果第一FMA被分解,则L
不再保证在范围[-2^51, 2^51]
中。如果第二个FMA被打破,L
将完全错误。
范围内的有符号整数:[-251,251]
// A*B = L + H*2^52
// Input: A and B are in the range [-2^51, 2^51]
// Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256i& L, __m256i& H, __m256i A, __m256i B){
const __m256d CONVERT_U = _mm256_set1_pd(6755399441055744); // 3*2^51
const __m256d CONVERT_D = _mm256_set1_pd(1.5);
__m256d l, h, a, b;
// Convert to double
A = _mm256_add_epi64(A, _mm256_castpd_si256(CONVERT_U));
B = _mm256_add_epi64(B, _mm256_castpd_si256(CONVERT_D));
a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);
// Get top half. Convert H to int64.
h = _mm256_fmadd_pd(a, b, CONVERT_U);
H = _mm256_sub_epi64(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));
// Undo the normalization.
h = _mm256_sub_pd(h, CONVERT_U);
// Recover bottom half.
l = _mm256_fmsub_pd(a, b, h);
// Convert L to int64
l = _mm256_add_pd(l, CONVERT_D);
L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_D));
}
在第一个例子的基础上,我们将其与快速double <-> int64
转换技巧的广义版本相结合。
这一个更有用,因为您使用的是整数。但即使使用了快速转换技巧,大部分时间也会花在转换上。幸运的是,如果要多次乘以同一操作数,则可以消除一些输入转换。
范围内的无符号整数:[0,252)
// A*B = L + H*2^52
// Input: A and B are in the range [0, 2^52)
// Output: L and H are in the range [0, 2^52)
void mul52_unsigned(__m256i& L, __m256i& H, __m256i A, __m256i B){
const __m256d CONVERT_U = _mm256_set1_pd(4503599627370496); // 2^52
const __m256d CONVERT_D = _mm256_set1_pd(1);
const __m256d CONVERT_S = _mm256_set1_pd(1.5);
__m256d l, h, a, b;
// Convert to double
A = _mm256_or_si256(A, _mm256_castpd_si256(CONVERT_U));
B = _mm256_or_si256(B, _mm256_castpd_si256(CONVERT_D));
a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);
// Get top half. Convert H to int64.
h = _mm256_fmadd_pd(a, b, CONVERT_U);
H = _mm256_xor_si256(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));
// Undo the normalization.
h = _mm256_sub_pd(h, CONVERT_U);
// Recover bottom half.
l = _mm256_fmsub_pd(a, b, h);
// Convert L to int64
l = _mm256_add_pd(l, CONVERT_S);
L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_S));
// Make Correction
H = _mm256_sub_epi64(H, _mm256_srli_epi64(L, 63));
L = _mm256_and_si256(L, _mm256_set1_epi64x(0x000fffffffffffff));
}
最后我们得到了原来问题的答案。这是在有符号整数解决方案的基础上通过调整转换和添加校正步骤构建的。
但在这一点上,我们有13条指令,其中一半是高延迟指令,还不包括大量的FP <-> int
旁路延迟。因此,这不太可能赢得任何基准。相比之下,64 x 64 -> 128-bit
SIMD乘法可以在16条指令中完成(如果预处理输入,则为14条)
如果舍入模式是向下舍入或舍入到零,则可以省略校正步骤。唯一重要的指令是h = _mm256_fmadd_pd(a, b, CONVERT_U);
。因此,在AVX512上,您可以覆盖该指令的舍入,并保留舍入模式。
最终想法:
值得注意的是,252的操作范围可以通过调整魔术常数来减小。这对于第一个解决方案(浮点解决方案)可能很有用,因为它为您提供了额外的尾数以用于浮点累加。这使您可以像前两个解决方案一样,绕过在int64和double之间不断来回转换的需要。
虽然这里的3个例子不太可能比标量方法更好,但AVX512几乎肯定会打破平衡。尤其是Knights Landing的ADCX和ADOX吞吐量较差。
当然,当AVX512-IFMA问世时,所有这些都是没有意义的。这将一个完整的52 x 52 -> 104-bit
乘积减少为2个指令,并免费提供累加。
执行多字整数运算的一种方法是使用双精度运算。让我们从一些二重乘法代码开始
#include <math.h>
typedef struct {
double hi;
double lo;
} doubledouble;
static doubledouble quick_two_sum(double a, double b) {
double s = a + b;
double e = b - (s - a);
return (doubledouble){s, e};
}
static doubledouble two_prod(double a, double b) {
double p = a*b;
double e = fma(a, b, -p);
return (doubledouble){p, e};
}
doubledouble df64_mul(doubledouble a, doubledouble b) {
doubledouble p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
}
函数two_prod
可以在两条指令中执行整数53bx53b->106b。函数df64_mul
可以做整数106bx106b->106b。
让我们将其与具有整数硬件的整数128bx128b->128b进行比较。
__int128 mul128(__int128 a, __int128 b) {
return a*b;
}
mul128
组件
imul rsi, rdx
mov rax, rdi
imul rcx, rdi
mul rdx
add rcx, rsi
add rdx, rcx
df64_mul
(用gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off
编译)的程序集
vmulsd xmm4, xmm0, xmm2
vmulsd xmm3, xmm0, xmm3
vmulsd xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd xmm3, xmm3, xmm0
vaddsd xmm1, xmm3, xmm1
vaddsd xmm0, xmm1, xmm4
vsubsd xmm4, xmm0, xmm4
vsubsd xmm1, xmm1, xmm4
mul128
进行三次标量乘法和两次标量加法/减法,而df64_mul
进行3次SIMD乘法、1次SIMD FMA和5次SIMD加法/减法。我还没有介绍这些方法,但对我来说,使用每个AVX寄存器4倍的df64_mul
可以优于mul128
(将sd
更改为pd
,将xmm
更改为ymm
),这似乎并非没有道理。
很容易说问题是切换回整数域。但为什么这是必要的呢?您可以在浮点域中执行所有操作。让我们来看看一些例子。我发现用float
进行单元测试比用double
更容易。
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = fma(a, b, -p);
return (doublefloat){p, e};
}
//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05
//hi = 15395627991040, lo = 102575, s = 15395628093615
//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488
因此,我们最终得到了不同的范围,在第二种情况下,误差(e
)甚至是负的,但总和仍然是正确的。我们甚至可以将两个双浮点值x
和y
加在一起(一旦我们知道如何进行双浮点加法,请参阅末尾的代码),得到15395628093615+2178594202488
。没有必要将结果标准化。
但是加法带来了二重算术的主要问题。也就是说,加法/减法是缓慢的,例如128b+128b->128b需要至少11个浮点加法,而对于整数,它只需要两个(add
和adc
)。
因此,如果一个算法重乘法而轻加法,那么用双精度进行多字整数运算可能会获胜。
顺便说一句,C语言足够灵活,可以实现整数完全通过浮点硬件实现的实现。int
可以是24位(来自单个浮点),long
可以是54位。并且long long
可以是106比特(来自双双)。C甚至不需要二的互补,因此整数可以像浮点一样使用负数的有符号幅度。
这里是使用双乘法和加法的C代码(我没有实现除法或其他运算,如sqrt
,但有论文显示了如何实现),以防有人想玩它。看看这是否可以针对整数进行优化会很有趣。
//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>
//#include <float.h>
typedef struct {
float hi;
float lo;
} doublefloat;
typedef union {
float f;
int i;
struct {
unsigned mantisa : 23;
unsigned exponent: 8;
unsigned sign: 1;
};
} float_cast;
void print_float(float_cast a) {
printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %un", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}
void print_doublefloat(doublefloat a) {
float_cast hi = {a.hi};
float_cast lo = {a.lo};
printf("hi: "); print_float(hi);
printf("lo: "); print_float(lo);
}
doublefloat quick_two_sum(float a, float b) {
float s = a + b;
float e = b - (s - a);
return (doublefloat){s, e};
// 3 add
}
doublefloat two_sum(float a, float b) {
float s = a + b;
float v = s - a;
float e = (a - (s - v)) + (b - v);
return (doublefloat){s, e};
// 6 add
}
doublefloat df64_add(doublefloat a, doublefloat b) {
doublefloat s, t;
s = two_sum(a.hi, b.hi);
t = two_sum(a.lo, b.lo);
s.lo += t.hi;
s = quick_two_sum(s.hi, s.lo);
s.lo += t.lo;
s = quick_two_sum(s.hi, s.lo);
return s;
// 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}
doublefloat split(float a) {
//#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
float t = (SPLITTER)*a;
float hi = t - (t - a);
float lo = a - hi;
return (doublefloat){hi, lo};
// 1 mul, 3 add
}
doublefloat split_sse(float a) {
__m128 k = _mm_set1_ps(4097.0f);
__m128 a4 = _mm_set1_ps(a);
__m128 t = _mm_mul_ps(k,a4);
__m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
__m128 lo4 = _mm_sub_ps(a4, hi4);
float tmp[4];
_mm_storeu_ps(tmp, hi4);
float hi = tmp[0];
_mm_storeu_ps(tmp, lo4);
float lo = tmp[0];
return (doublefloat){hi,lo};
}
float mult_sub(float a, float b, float c) {
doublefloat as = split(a), bs = split(b);
//print_doublefloat(as);
//print_doublefloat(bs);
return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
// 4 mul, 4 add, 2 split = 6 mul, 10 add
}
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = mult_sub(a, b, p);
return (doublefloat){p, e};
// 1 mul, one mult_sub
// 7 mul, 10 add
}
float mult_sub2(float a, float b, float c) {
doublefloat as = split(a);
return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}
doublefloat two_sqr(float a) {
float p = a*a;
float e = mult_sub2(a, a, p);
return (doublefloat){p, e};
}
doublefloat df64_mul(doublefloat a, doublefloat b) {
doublefloat p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
//two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add
//or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}
doublefloat df64_sqr(doublefloat a) {
doublefloat p = two_sqr(a.hi);
p.lo += 2*a.hi*a.lo;
return quick_two_sum(p.hi, p.lo);
}
int float2int(float a) {
int M = 0xc00000; //1100 0000 0000 0000 0000 0000
a += M;
float_cast x;
x.f = a;
return x.i - 0x4b400000;
}
doublefloat add22(doublefloat a, doublefloat b) {
float r = a.hi + b.hi;
float s = fabsf(a.hi) > fabsf(b.hi) ?
(((a.hi - r) + b.hi) + b.lo ) + a.lo :
(((b.hi - r) + a.hi) + a.lo ) + b.lo;
return two_sum(r, s);
//11 add
}
int main(void) {
//print_float((float_cast){1.0f});
//print_float((float_cast){-2.0f});
//print_float((float_cast){0.0f});
//print_float((float_cast){3.14159f});
//print_float((float_cast){1.5f});
//print_float((float_cast){3.0f});
//print_float((float_cast){7.0f});
//print_float((float_cast){15.0f});
//print_float((float_cast){31.0f});
//uint64_t t = 0xffffff;
//print_float((float_cast){1.0f*t});
//printf("%" PRId64 " %" PRIx64 "n", t*t,t*t);
/*
float_cast t1;
t1.mantisa = 0x7fffff;
t1.exponent = 0xfe;
t1.sign = 0;
print_float(t1);
*/
//doublefloat z = two_prod(1.0f*t, 1.0f*t);
//print_doublefloat(z);
//double z2 = (double)z.hi + (double)z.lo;
//printf("%.16en", z2);
doublefloat s = {0};
int64_t si = 0;
for(int i=0; i<100000; i++) {
int ai = rand()%0x800, bi = rand()%0x800000;
float a = ai, b = bi;
doublefloat z = two_prod(a,b);
int64_t zi = (int64_t)ai*bi;
//print_doublefloat(z);
//s = df64_add(s,z);
s = add22(s,z);
si += zi;
print_doublefloat(z);
printf("%d %d ", ai,bi);
int64_t h = z.hi;
int64_t l = z.lo;
int64_t t = h+l;
//if(t != zi) printf("%" PRId64 " %" PRId64 "n", h, l);
printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "n", zi, h, l, h+l);
h = s.hi;
l = s.lo;
t = h + l;
//if(si != t) printf("%" PRId64 " %" PRId64 "n", h, l);
if(si > (1LL<<48)) {
printf("overflow after %d iterationsn", i); break;
}
}
print_doublefloat(s);
printf("%" PRId64 "n", si);
int64_t x = s.hi;
int64_t y = s.lo;
int64_t z = x+y;
//int hi = float2int(s.hi);
printf("%" PRId64 " %" PRId64 " %" PRId64 "n", z,x,y);
}
好吧,您当然可以对整数执行FP通道操作。它们总是准确的:虽然有些SSE指令不能保证正确的IEEE-754精度和舍入,但毫无例外,它们是没有整数范围的指令,所以无论如何都不是你正在查看的指令。一句话:加法/减法/乘法在整数域中总是精确的,即使你是在压缩浮点运算中进行的。
至于四精度浮点(>52位尾数),不支持,而且在可预见的未来可能不会支持。只是没有太多的呼吁。它们出现在一些SPARC时代的工作站体系结构中,但老实说,它们只是开发者对如何编写数值稳定算法的不完全理解的绷带,随着时间的推移,它们逐渐消失了。
结果证明,宽整数运算非常不适合SSE。最近,当我实现一个大的整数库时,我真的试图利用它,老实说,这对我没有好处。x86是为多字运算而设计的;你可以在ADC(产生并消耗进位位)和IDIV(只要商不比被除数宽,除数就可以是被除数的两倍,这一限制使它对但的多字除法毫无用处)等操作中看到它。但多字算术本质上是顺序的,SSE本质上是并行的。如果你足够幸运,你的数字有足够的位来放入FP尾数,那么恭喜你。但如果你有大整数,SSE可能不会是你的朋友。