>我有两个双精度,a
和b
,它们都在[0,1]中。出于性能原因,我想要a
和b
的最小/最大值而不进行分支。
鉴于a
和b
都是正数,并且低于 1,是否有一种有效的方法来获得两者的最小值/最大值?理想情况下,我不想分支。
是的,有一种方法可以在没有任何分支的情况下计算两个double
的最大值或最小值。执行此操作C++代码如下所示:
#include <algorithm>
double FindMinimum(double a, double b)
{
return std::min(a, b);
}
double FindMaximum(double a, double b)
{
return std::max(a, b);
}
我敢打赌你以前见过这个。为了避免您不相信这是无分支的,请查看反汇编:
FindMinimum(double, double):
minsd xmm1, xmm0
movapd xmm0, xmm1
ret
FindMaximum(double, double):
maxsd xmm1, xmm0
movapd xmm0, xmm1
ret
这就是你从所有针对 x86 的流行编译器中获得的。使用 SSE2 指令集,特别是minsd
/maxsd
指令,它无分支地计算两个双精度浮点值的最小/最大值。
所有 64 位 x86 处理器都支持 SSE2;这是 AMD64 扩展所必需的。即使是大多数没有 64 位的 x86 处理器也支持 SSE2。它于2000年发布。您必须追溯很长时间才能找到不支持SSE2的处理器。但是如果你这样做了呢?好吧,即使在那里,您也可以在大多数流行的编译器上获得无分支代码:
FindMinimum(double, double):
fld QWORD PTR [esp + 12]
fld QWORD PTR [esp + 4]
fucomi st(1)
fcmovnbe st(0), st(1)
fstp st(1)
ret
FindMaximum(double, double):
fld QWORD PTR [esp + 4]
fld QWORD PTR [esp + 12]
fucomi st(1)
fxch st(1)
fcmovnbe st(0), st(1)
fstp st(1)
ret
fucomi
指令执行比较,设置标志,然后fcmovnbe
指令根据这些标志的值执行条件移动。这一切都是完全无分支的,并且依赖于1995年引入带有奔腾Pro的x86 ISA的说明,自奔腾II以来的所有x86芯片都支持。
这里唯一不会生成无分支代码的编译器是 MSVC,因为它不利用FCMOVxx
指令。相反,你会得到:
double FindMinimum(double, double) PROC
fld QWORD PTR [a]
fld QWORD PTR [b]
fcom st(1) ; compare "b" to "a"
fnstsw ax ; transfer FPU status word to AX register
test ah, 5 ; check C0 and C2 flags
jp Alt
fstp st(1) ; return "b"
ret
Alt:
fstp st(0) ; return "a"
ret
double FindMinimum(double, double) ENDP
double FindMaximum(double, double) PROC
fld QWORD PTR [b]
fld QWORD PTR [a]
fcom st(1) ; compare "b" to "a"
fnstsw ax ; transfer FPU status word to AX register
test ah, 5 ; check C0 and C2 flags
jp Alt
fstp st(0) ; return "b"
ret
Alt:
fstp st(1) ; return "a"
ret
double FindMaximum(double, double) ENDP
请注意分支JP
指令(如果设置了奇偶校验位,则跳转)。FCOM
指令用于进行比较,这是基本 x87 FPU 指令集的一部分。不幸的是,这会在 FPU 状态词中设置标志,因此为了在这些标志上进行分支,需要提取它们。这就是FNSTSW
指令的目的,它将 x87 FPU 状态字存储到通用AX
寄存器(它也可以存储到内存中,但是......为什么?然后,代码TEST
相应的位,并相应地进行分支以确保返回正确的值。除了分支之外,检索 FPU 状态字也会比较慢。这就是奔腾Pro引入FCOM
说明的原因。
但是,您不太可能通过使用位摆动操作来确定最小值/最大值来提高任何此代码的速度。有两个基本原因:
生成低效代码的唯一编译器是 MSVC,并且没有好方法可以强制它生成您想要的指令。尽管 MSVC 中支持 32 位 x86 目标的内联程序集,但在寻求性能改进时,这是愚蠢的差事。我还要引用自己的话:
内联程序集以相当显著的方式破坏优化器,因此除非您在内联程序集中编写大量代码,否则不太可能有实质性的净性能提升。此外,Microsoft的内联程序集语法非常有限。它在很大程度上牺牲了灵活性和简单性。特别是,无法指定输入值,因此您只能将输入从内存加载到寄存器中,并且调用方被迫将输入从寄存器溢出到内存以准备。这创造了一种现象,我喜欢称之为"一大堆洗牌'goin'on",或者简称为"慢代码"。在可以接受慢速代码的情况下,不要下降到内联程序集。因此,最好(至少在 MSVC 上)弄清楚如何编写 C/C++ 源代码来说服编译器发出您想要的目标代码。即使您只能接近理想的输出,这仍然比您使用内联装配所付出的代价要好得多。
为了访问浮点值的原始位,您必须执行域转换,从浮点到整数,然后再回到浮点。这很慢,尤其是在没有 SSE2 的情况下,因为从 x87 FPU 获取值到 ALU 中的通用整数寄存器的唯一方法是通过内存间接获取。
如果您无论如何都想采用此策略(例如,对其进行基准测试),则可以利用浮点值根据其IEEE 754表示形式按字典顺序排序的事实,符号位除外。因此,由于您假设两个值都是正数:
FindMinimumOfTwoPositiveDoubles(double a, double b):
mov rax, QWORD PTR [a]
mov rdx, QWORD PTR [b]
sub rax, rdx ; subtract bitwise representation of the two values
shr rax, 63 ; isolate the sign bit to see if the result was negative
ret
FindMaximumOfTwoPositiveDoubles(double a, double b):
mov rax, QWORD PTR [b] ; reverse order of parameters
mov rdx, QWORD PTR [a] ; / for the SUB operation
sub rax, rdx
shr rax, 63
ret
或者,要避免内联装配:
bool FindMinimumOfTwoPositiveDoubles(double a, double b)
{
static_assert(sizeof(a) == sizeof(uint64_t),
"A double must be the same size as a uint64_t for this bit manipulation to work.");
const uint64_t aBits = *(reinterpret_cast<uint64_t*>(&a));
const uint64_t bBits = *(reinterpret_cast<uint64_t*>(&b));
return ((aBits - bBits) >> ((sizeof(uint64_t) * CHAR_BIT) - 1));
}
bool FindMaximumOfTwoPositiveDoubles(double a, double b)
{
static_assert(sizeof(a) == sizeof(uint64_t),
"A double must be the same size as a uint64_t for this bit manipulation to work.");
const uint64_t aBits = *(reinterpret_cast<uint64_t*>(&a));
const uint64_t bBits = *(reinterpret_cast<uint64_t*>(&b));
return ((bBits - aBits) >> ((sizeof(uint64_t) * CHAR_BIT) - 1));
}
请注意,此实现存在严重的警告。特别是,如果两个浮点值具有不同的符号,或者两个值都是负数,它将中断。如果两个值都是负数,则可以修改代码以翻转它们的符号,进行比较,然后返回相反的值。为了处理两个值具有不同符号的情况,可以添加代码来检查符号位。
// ...
// Enforce two's-complement lexicographic ordering.
if (aBits < 0)
{
aBits = ((1 << ((sizeof(uint64_t) * CHAR_BIT) - 1)) - aBits);
}
if (bBits < 0)
{
bBits = ((1 << ((sizeof(uint64_t) * CHAR_BIT) - 1)) - bBits);
}
// ...
处理负零也将是一个问题。IEEE 754 表示 +0.0 等于 −0.0,因此您的比较函数必须决定是将这些值视为不同的值,还是向比较例程添加特殊代码,以确保将负零和正零视为等效。
添加所有这些特殊情况的代码肯定会降低性能,以至于我们将通过朴素的浮点比较实现收支平衡,并且很可能最终会变慢。