快速的大格数平方计算



为了加速我的大整数除法,我需要加速大整数的y = x^2操作,这些大整数表示为无符号DWORDs的动态数组。说明:

DWORD x[n+1] = { LSW, ......, MSW };
  • 其中n+1为已使用DWORDs数
  • x = x[0]+x[1]<<32 + ... x[N]<<32*(n)的so值

问题是:我如何在没有精度损失的情况下尽可能快地计算y = x^2?-使用c++ 和整数运算(带进位的32位)。

我目前的方法是应用乘法y = x*x,避免多次乘法。

例如:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n)

为简单起见,我重写一下:

x = x0+ x1 + x2 + ... + xn

其中index表示数组内的地址,因此:

y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)
y0     = x0*x0
y1     = x1*x0 + x0*x1
y2     = x2*x0 + x1*x1 + x0*x2
y3     = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-2) = xn(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n-1) = xn(n  )*x(n  )

仔细观察后,很明显,几乎所有xi*xj都出现了两次(而不是第一次和最后一次),这意味着N*N乘法可以用(N+1)*(N/2)乘法代替。P.S.32bit*32bit = 64bit,因此每个mul+add操作的结果都被处理为64+1 bit

有没有更好的方法来快速计算这个?我在搜索过程中发现的都是sqrt算法,而不是sqr…

快速sqr

! !注意,我代码中的所有数字都是MSW首先,…不像上面的测试(有LSW首先为方程的简单性,否则将是一个索引混乱)。

当前功能fsqr实现

void arbnum::sqr(const arbnum &x)
{
// O((N+1)*N/2)
arbnum c;
DWORD h, l;
int N, nx, nc, i, i0, i1, k;
c._alloc(x.siz + x.siz + 1);
nx = x.siz - 1;
nc = c.siz - 1;
N = nx + nx;
for (i=0; i<=nc; i++)
c.dat[i]=0;
for (i=1; i<N; i++)
for (i0=0; (i0<=nx) && (i0<=i); i0++)
{
i1 = i - i0;
if (i0 >= i1)
break;
if (i1 > nx)
continue;
h = x.dat[nx-i0];
if (!h)
continue;
l = x.dat[nx-i1];
if (!l)
continue;
alu.mul(h, l, h, l);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k], l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k],h);
k--;
for (; (alu.cy) && (k>=0); k--)
alu.inc(c.dat[k]);
}
c.shl(1);
for (i = 0; i <= N; i += 2)
{
i0 = i>>1;
h = x.dat[nx-i0];
if (!h)
continue;
alu.mul(h, l, h, h);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k],l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k], h);
k--;
for (; (alu.cy) && (k >= 0); k--)
alu.inc(c.dat[k]);
}
c.bits = c.siz<<5;
c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
c.sig = sig;
*this = c;
}

使用乘法

(感谢Calpis)

我实现了Karatsuba乘法,但是结果比使用简单的O(N^2)乘法要慢得多,可能是因为我看不到任何方法可以避免的可怕的递归。它的权衡必须是非常大的数字(大于数百位)…但即便如此,仍有大量的记忆传输。是否有一种方法来避免递归调用(非递归变体,…几乎所有递归算法都可以这样做)。尽管如此,我还是会试着调整一下,看看会发生什么(避免规范化等)。(也可能是代码中的一些愚蠢的错误)。无论如何,在解决了x*x情况下的Karatsuba之后,性能并没有多少提高。

优化的Karatsuba乘法

y = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits性能测试:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication
x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]
x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

在对Karatsuba进行优化后,代码比以前快得多。不过,对于较小的数字,它的速度略低于我的O(N^2)乘法的一半。对于较大的数字,它的速度比布斯乘法的复杂性更快。乘法的阈值大约是32*98位,而sqr的阈值大约是32*389位,所以如果输入比特的总和超过这个阈值,那么Karatsuba乘法将用于加速乘法,对于sqr也是如此。

顺便说一句,优化包括:

  • 使用太大的递归参数最小化堆垃圾
  • 使用带进位的32位ALU来避免任何十进制算术(+,-)。
  • 忽略0*yx*00*0案例
  • 重新格式化输入x,y数字大小,以避免重新分配
  • z1 = (x0 + x1)*(y0 + y1)实现模乘法以最小化递归

修改Schönhage-Strassen乘法到sqr的实现

我测试了FFT的使用和NTT变换以加快SQR计算。结果如下:

  1. FFT

    失去精度,因此需要高精度的复数。这实际上大大降低了速度,所以没有加速。结果不精确(可能被错误地舍入),所以FFT

    当前无法使用
  2. NTT

    NTTDFT这样就不会出现精度损失。对无符号整数modpow, modmul, modaddmodsub需要模运算。

    我使用DWORD(32位无符号整数)。NTT由于溢出问题,输入/输出矢量大小有限!!对于32位模块算法,N仅限于(2^32)/(max(input[])^2),因此bigint必须分成更小的块(我使用BYTES,因此处理的bigint的最大大小为

    )
    (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)
    

    sqr只使用1xNTT + 1xINTT而不是2xNTT + 1xINTT进行乘法运算,但是NTT在我的实现(mulsqr)中,使用率太慢,阈值大小太大,无法实际使用。

    甚至可能超过溢出限制,因此应该使用64位模块算术,这可以使事情变得更慢。所以NTT对我来说也是没用的。

一些测量:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

我实现:

void arbnum::sqr_NTT(const arbnum &x)
{
// O(N*log(N)*(log(log(N)))) - 1x NTT
// Schönhage-Strassen sqr
// To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
int i, j, k, n;
int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
i = x.siz;
for (n = 1; n < i; n<<=1)
;
if (n + n > 0x3000) {
_error(_arbnum_error_TooBigNumber);
zero();
return;
}
n <<= 3;
DWORD *xx, *yy, q, qq;
xx = new DWORD[n+n];
#ifdef _mmap_h
if (xx)
mmap_new(xx, (n+n) << 2);
#endif
if (xx==NULL) {
_error(_arbnum_error_NotEnoughMemory);
zero();
return;
}
yy = xx + n;
// Zero padding (and split DWORDs to BYTEs)
for (i--, k=0; i >= 0; i--)
{
q = x.dat[i];
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++;
}
for (;k<n;k++)
xx[k] = 0;
//NTT
fourier_NTT ntt;
ntt.NTT(yy,xx,n);    // init NTT for n
// Convolution
for (i=0; i<n; i++)
yy[i] = modmul(yy[i], yy[i], ntt.p);
//INTT
ntt.INTT(xx, yy);
//suma
q=0;
for (i = 0, j = 0; i<n; i++) {
qq = xx[i];
q += qq&0xFF;
yy[n-i-1] = q&0xFF;
q>>=8;
qq>>=8;
q+=qq;
}
// Merge WORDs to DWORDs and copy them to result
_alloc(n>>2);
for (i = 0, j = 0; i<siz; i++)
{
q  =(yy[j]<<24)&0xFF000000; j++;
q |=(yy[j]<<16)&0x00FF0000; j++;
q |=(yy[j]<< 8)&0x0000FF00; j++;
q |=(yy[j]    )&0x000000FF; j++;
dat[i] = q;
}
#ifdef _mmap_h
if (xx)
mmap_del(xx);
#endif
delete xx;
bits = siz<<5;
sig = s;
exp = exp0 + (siz<<5) - 1;
// _normalize();
}

结论

对于较小的数字,它是我的快速sqr方法的最佳选择,之后阈值Karatsuba乘法更好。但我仍然认为,我们应该忽略了一些微不足道的东西。还有别的主意吗?

NTT优化

大规模优化后(主要是NTT)):堆栈溢出问题模块算法和NTT(有限域DFT)优化。

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

所以现在NTT乘法终于比Karatsuba快了约1500*32位阈值后。

一些测量和错误发现

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

我发现我的KaratsubaLSBbignum的DWORD节段。当我研究完后,我会更新代码…

同样,在进一步NTT之后优化后的阈值改变了,所以对于NTT sqr操作数310*32 bits = 9920 bits,对于NTT mulresult1396*32 bits = 44672 bits(操作数位和)

Karatsuba代码修复感谢@greybeard

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
// Recursion for Karatsuba
// z[2n] = x[n]*y[n];
// n=2^m
int i;
for (i=0; i<n; i++)
if (x[i]) {
i=-1;
break;
} // x==0 ?
if (i < 0)
for (i = 0; i<n; i++)
if (y[i]) {
i = -1;
break;
} // y==0 ?
if (i >= 0) {
for (i = 0; i < n + n; i++)
z[i]=0;
return;
} // 0.? = 0
if (n == 1) {
alu.mul(z[0], z[1], x[0], y[0]);
return;
}
if (n< 1)
return;
int n2 = n>>1;
_mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
_mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
BYTE cx,cy;
if (q == NULL) {
_error(_arbnum_error_NotEnoughMemory);
return;
}
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
qq = q;
q0 = x + n2;
q1 = x;
i = n2 - 1;
_add;
cx = alu.cy; // =x0+x1
qq = q + n2;
q0 = y + n2;
q1 = y;
i = n2 - 1;
_add;
cy = alu.cy; // =y0+y1
_mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)
if (cx) {
qq = q + n;
q0 = qq;
q1 = q + n2;
i = n2 - 1;
_add;
cx = alu.cy;
}// += cx*(y0 + y1) << n2
if (cy) {
qq = q + n;
q0 = qq;
q1 = q;
i = n2 -1;
_add;
cy = alu.cy;
}// +=cy*(x0+x1)<<n2
qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2
DWORD ccc=0;
if (alu.cy)
ccc++;    // Handle carry from last operation
if (cx || cy)
ccc++;    // Handle carry from before last operation
if (ccc)
{
i = n2 - 1;
alu.add(z[i], z[i], ccc);
for (i--; i>=0; i--)
if (alu.cy)
alu.inc(z[i]);
else
break;
}
delete[] q;
#undef _add
#undef _sub
}
//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
// O(3*(N)^log2(3)) ~ O(3*(N^1.585))
// Karatsuba multiplication
//
int s = x.sig*y.sig;
arbnum a, b;
a = x;
b = y;
a.sig = +1;
b.sig = +1;
int i, n;
for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
;
a._realloc(n);
b._realloc(n);
_alloc(n + n);
for (i=0; i < siz; i++)
dat[i]=0;
_mul_karatsuba(dat, a.dat, b.dat, n);
bits = siz << 5;
sig = s;
exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
//    _normalize();
}
//---------------------------------------------------------------------------

我的arbnum号码表示:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
  • dat[siz]是尾号。LSDW表示最不显著DWORD。
  • expdat[0]
  • 的MSB指数。
  • 第一个非零位出现在尾数中!!

    // |-----|---------------------------|---------------|------|
    // | sig | MSB      mantisa      LSB |   exponent    | bits |
    // |-----|---------------------------|---------------|------|
    // | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
    // | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
    // | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
    // | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
    // |-----|---------------------------|---------------|------|
    

如果我正确理解你的算法,似乎O(n^2)其中n是位数。

你看过Karatsuba算法吗?它使用分治法加速乘法。也许值得一看。

你的问题很好,谢谢!

决定从头开始为您实现一个巨大的c++解决方案,基于数论变换(NTT)和离散傅立叶变换。

提前说明一下,我的FFT/NTT代码实现了330x对于数组大小为2^16个32位单词的情况,与幼稚的学校级乘法相比,在2核旧笔记本电脑上的加速。即使是大于2^20大小的更大的数组也会提供数百万倍的加速。

将一个32位的2^22个字的数平方(即400万个字)需要7秒13秒在我的FFT上,在旧的2GHz 2核笔记本电脑上,只有SSE2。

需要提醒的是,FFT和NTT给出的乘法时间为O(N * Log(N)),而朴素的年级算法给出的乘法时间为O(N^2)。这就是为什么我有这么大的加速在前面的段落中描述。

和代码在这篇文章中都有很好的描述,主要是我在写下面的代码时受到了这篇文章的启发。另一篇好文章是Nayuki的NTT文章。

我确信,对于相当大的数字,这两种变换将击败任何其他方法,如Karatsuba。

除了文章中描述的基本方法外,我还进行了许多优化:

  1. 对于NTT计算我自己的本原根和模的集合。用了最接近2^62的那个

  2. 几乎在NTT和FFT计算的每个循环中都使用多线程。通过OpenMP。

  3. 对于平方,我使用了2个变换而不是3个(用于乘法)。这给了33%的速度提升。

  4. For NTT在计算模量时对所有数组使用Montgomery Reduction。这给了大约2 -3倍的加速。

  5. 使用constexpr函数和值和模板编程在任何地方,我可以。在可能的情况下,将运行时值减少为编译时值可以大大提高速度。

  6. 重新设计的交换/洗牌功能,在每次FFT/NTT转换开始时使用。使用预计算表和缓存来重用以前的结果。还做了块交换,使缓存友好的读/写。此外,位旋转不是在循环中完成的,而是使用预先计算的位表完成的。

  7. transform主循环将W乘法器的计算分解为单独的循环,并进行预计算/缓存。这给了大约2倍的加速。

  8. 使用Intel SIMD指令集,目前为SSE2和AVX。这些仅用于FFT,因为NTT使用128位整数除法和乘法以及带进位的加法/子运算,这些在SIMD中不可用。另外,对于FFT中的SIMD,我设计了循环展开,在std::array<>中使用特殊的缓存友好存储复数。

  9. NTT/FFT乘法与naive的时间/性能测量。

  10. 对FFT内部错误率进行了分析。提醒NTT没有任何错误

我的代码是自包含的,如果你编译+运行它,它将运行测试测量速度。在测试函数中,你可以看到如何使用我的库。测试运行FFT/NTT/朴素乘法,测量时间并比较是否所有乘法结果都是正确的,即等于朴素版本。

注意:无论我如何努力通过SIMD加速FFT,但我的NTT是如此优化,它是1.3-1.8x比FFT快几倍。正如你所知,FFT给出的误差会随着数字的增大而增大。如果考虑到我的NTT更快的事实,那么NTT是你唯一的选择!

似乎FFT只能用于像2^16 32位字这样的数组大小,不能再多了,然后错误大小变得至关重要并破坏最终结果。或者您可以减少输入32位数字的大小,到10-12位,这有助于减少错误,但您不能使用大于2^18的数组大小的临界错误。你必须通过实验计算误差大小来找出最好的方法。

代码可以在CLang/MSVC/GCC中编译。也许还有其他的编译器。它没有任何外部库依赖,可能除了OpenMP库,它通常与编译器一起发货。只有原始根(NTT模)的计算需要Boost库,但只适用于MSVC,并且只使用其中的128位整数。

CODE GOES HERE. 仅仅因为代码大小是65 KB,我不能将它内联到这篇文章中,因为StackOverflow的帖子大小限制是30 000个符号。因此,我提供我的代码下面Github Gist链接。点击Try it online!链接在GodBolt的在线服务器上运行我的代码。

上网试试!

Github Gist源代码

控制台输出示例:

Using SIMD SSE2
Test FindNttMod 
FindNttEntry<T>{.k = 57, .c = 29, .p = 4179340454199820289, .g = 3, .root = 68630377364883, .plog2 = 61.86},
FindNttEntry<T>{.k = 54, .c = 177, .p = 3188548536178311169, .g = 7, .root = 3055434446054240334, .plog2 = 61.47},
FindNttEntry<T>{.k = 54, .c = 163, .p = 2936346957045563393, .g = 3, .root = 83050791888939419, .plog2 = 61.35},
FindNttEntry<T>{.k = 55, .c = 69, .p = 2485986994308513793, .g = 5, .root = 1700750308946223057, .plog2 = 61.11},
FindNttEntry<T>{.k = 54, .c = 127, .p = 2287828610704211969, .g = 3, .root = 878887558841786394, .plog2 = 60.99},
FindNttEntry<T>{.k = 55, .c = 57, .p = 2053641430080946177, .g = 7, .root = 640559856471874596, .plog2 = 60.83},
FindNttEntry<T>{.k = 56, .c = 27, .p = 1945555039024054273, .g = 5, .root = 1613915479851665306, .plog2 = 60.75},
FindNttEntry<T>{.k = 53, .c = 161, .p = 1450159080013299713, .g = 3, .root = 359678689516082930, .plog2 = 60.33},
FindNttEntry<T>{.k = 53, .c = 143, .p = 1288029493427961857, .g = 3, .root = 531113314168589713, .plog2 = 60.16},
FindNttEntry<T>{.k = 55, .c = 35, .p = 1261007895663738881, .g = 6, .root = 397650301651152680, .plog2 = 60.13},
0.025 sec
Test CompareNttMultWithReg 
Time NTT 0.035 FFT 0.081 Reg 11.614 Boost_NTT 333.588x (FFT 142.644)
Swap 0.776 (Slow 0.000) ToMontg 0.079 Main 3.056 (0.399, 2.656) Invert 0.000 All 3.911
MidMul 0.110
Swap 0.510 (Slow 0.000) ToMontg 0.000 Main 2.535 (0.336, 2.198) Invert 0.094 All 3.139
AssignComplex 0.495
Swap 1.373 FromComplex 0.309 Main 4.875 (0.382, 4.493) Invert 0.000 ToComplex 0.224 All 6.781
MidMul 0.147
Swap 1.106 FromComplex 0.296 Main 4.209 (0.277, 3.931) Invert 0.166 ToComplex 0.199 All 5.975
Round 0.143
Time NTT 7.457 FFT 14.097 Boost_NTT 1.891x
Run Time: 33.719 sec

如果你想写一个更好的指数,你可能必须用汇编语言写。这是golang的代码。

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s

最新更新