int SWAR(unsigned int i)
{
i = i - ((i >> 1) & 0x55555555);
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}
我看过这段代码,它计算 32 位整数的位数等于1
,我注意到它的性能比__builtin_popcount
好,但我无法理解它的工作方式。
有人可以详细说明此代码的工作原理吗?
好的,让我们逐行浏览代码:
第 1 行:
i = i - ((i >> 1) & 0x55555555);
首先,常量0x55555555
的意义在于,使用Java/GCC风格的二进制文字表示法编写),
0x55555555 = 0b01010101010101010101010101010101
也就是说,它所有的奇数位(将最低位计为位 1 = 奇数)都1
,并且所有偶数位都0
。
因此,表达式((i >> 1) & 0x55555555)
将i
位右移 1,然后将所有偶数位设置为零。 (等价地,我们可以先用& 0xAAAAAAAA
将所有i
的奇数位设置为零,然后将结果向右移动一位。 为方便起见,我们将此中间值称为j
。
当我们从原始i
中减去这个j
时会发生什么? 好吧,让我们看看如果i
只有两个位会发生什么:
i j i - j
----------------------------------
0 = 0b00 0 = 0b00 0 = 0b00
1 = 0b01 0 = 0b00 1 = 0b01
2 = 0b10 1 = 0b01 1 = 0b01
3 = 0b11 1 = 0b01 2 = 0b10
嘿! 我们已经设法计算了两位数字的位数!
好的,但是如果i
设置了两个以上的位怎么办? 事实上,很容易检查i - j
的最低两位是否仍由上表给出,第三位和第四位,第五位和第六位也是如此,等等。 特别:
尽管有
>> 1
,i - j
的最低两位不受i
的第三位或更高位的影响,因为它们会被& 0x55555555
屏蔽j
;因为
j
的最低两位永远不会有比i
更大的数值,减法永远不会借用i
的第三位:因此,i
的最低两位也不能影响i - j
的第三位或更高位。
事实上,通过重复相同的参数,我们可以看到,这一行的计算实际上将上表应用于并行i
的 16 个两位块中的每一个。 也就是说,在执行此行后,新值i
的最低两位现在将包含原始值i
中相应位之间设置的位数,接下来的两个位也是如此,依此类推。
第 2 行:
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
与第一行相比,这一行非常简单。 首先,请注意
0x33333333 = 0b00110011001100110011001100110011
因此,i & 0x33333333
采用上面计算的两位计数并每隔一秒丢弃一次,而(i >> 2) & 0x33333333
在将i
右移两位后执行相同的操作。 然后我们将结果加在一起。
因此,实际上,此行的作用是获取原始输入的最低两位和第二低两位的位数,在前一行上计算,并将它们相加以得出输入的最低四位的位数。 同样,它对输入的所有 8 个四位块(= 十六进制数字)并行执行此操作。
第 3 行:
return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
好了,这是怎么回事?
嗯,首先,(i + (i >> 4)) & 0x0F0F0F0F
与前一行完全相同,只是它将相邻的四位位数相加,以给出输入的每个八位块(即字节)的位数。 (在这里,与上一行不同,我们可以将&
移到加法之外,因为我们知道八位位数永远不会超过 8,因此可以容纳在四位而不会溢出。
现在我们有一个由四个 8 位字节组成的 32 位数字,每个字节包含原始输入的该字节中的 1 位数字。 (我们将这些字节称为A
、B
、C
和D
。 那么当我们将这个值(我们称之为k
)乘以0x01010101
时会发生什么?
好吧,自0x01010101 = (1 << 24) + (1 << 16) + (1 << 8) + 1
以来,我们有:
k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k
因此,结果的最高字节最终为:
- 由于
k
项,其原始值加上 - 由于
k << 8
项,下一个较低字节的值加上 - 由于
k << 16
项,第二个较低字节的值加上 - 由于
k << 24
项,第四个也是最低字节的值。
(一般来说,也可以有来自较低字节的进位,但由于我们知道每个字节的值最多为 8,我们知道加法永远不会溢出并创建进位。
也就是说,k * 0x01010101
的最高字节最终是输入的所有字节的位计数之和,即 32 位输入数字的总位计数。 然后,最后的>> 24
只是将此值从最高字节向下移动到最低字节。
只需将0x01010101
更改为0x0101010101010101
,将>> 24
更改为>> 56
,即可轻松扩展到 64 位整数。事实上,相同的方法甚至适用于 128 位整数;然而,256 位需要添加一个额外的移位/添加/掩码步骤,因为数字 256 不再完全适合 8 位字节。
我更喜欢这个,它更容易理解。
x = (x & 0x55555555) + ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);
x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
x = (x & 0x0000ffff) + ((x >> 16) &0x0000ffff);
这是对Ilamari回答的评论。 由于格式问题,我将其作为答案:
第 1 行:
i = i - ((i >> 1) & 0x55555555); // (1)
这一行是从这句更容易理解的行派生出来的:
i = (i & 0x55555555) + ((i >> 1) & 0x55555555); // (2)
如果我们打电话
i = input value
j0 = i & 0x55555555
j1 = (i >> 1) & 0x55555555
k = output value
我们可以重写(1)和(2)以使解释更清晰:
k = i - j1; // (3)
k = j0 + j1; // (4)
我们想证明 (3) 可以从 (4) 派生。
i
可以写成其偶数位和奇数位的相加(将最低位计为位 1 = 奇数):
i = iodd + ieven =
= (i & 0x55555555) + (i & 0xAAAAAAAA) =
= (i & modd) + (i & meven)
由于meven
面具清除了最后一点i
, 最后一个等式可以这样写:
i = (i & modd) + ((i >> 1) & modd) << 1 =
= j0 + 2*j1
那是:
j0 = i - 2*j1 (5)
最后,将 (5) 替换为 (4) 我们实现 (3):
k = j0 + j1 = i - 2*j1 + j1 = i - j1
这是对耶尔回答的解释:
int SWAR(unsigned int i) {
i = (i & 0x55555555) + ((i >> 1) & 0x55555555); // A
i = (i & 0x33333333) + ((i >> 2) & 0x33333333); // B
i = (i & 0x0f0f0f0f) + ((i >> 4) & 0x0f0f0f0f); // C
i = (i & 0x00ff00ff) + ((i >> 8) & 0x00ff00ff); // D
i = (i & 0x0000ffff) + ((i >> 16) &0x0000ffff); // E
return i;
}
让我们用A行作为我解释的基础。
i = (i & 0x55555555) + ((i >> 1) & 0x55555555)
让我们重命名上面的表达式,如下所示:
i = (i & mask) + ((i >> 1) & mask)
= A1 + A2
首先,不要将i
视为 32 位,而是 16 组的数组,每个组 2 位。A1
是大小为 16 的 count 数组,每个组包含1
s 的计数,位于i
中相应组的最右侧位:
i = yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx
mask = 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01
i & mask = 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x
同样,A2
正在"计数"i
中每个组的最左边位。请注意,我可以将A2 = (i >> 1) & mask
重写为A2 = (i & mask2) >> 1
:
i = yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx
mask2 = 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
(i & mask2) = y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0
(i & mask2) >> 1 = 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y
(请注意,mask2 = 0xaaaaaaaa
)
因此,A1 + A2
将A1
数组和A2
数组的计数相加,得到一个包含 16 个组的数组,每个组现在包含每个组中的位数计数。
转到 B 行,我们可以按如下方式重命名该行:
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
= (i & mask) + ((i >> 2) & mask)
= B1 + B2
B1 + B2
遵循与以前A1 + A2
相同的"形式"。i
不再视为 16 组 2 位,而是 8 组 4 位。与之前类似,B1 + B2
将B1
和B2
的计数相加,其中B1
是组右侧的1
s 计数,B2
是组左侧的计数。 因此,B1 + B2
是每个组中的位数。
C 到 E 行现在变得更容易理解:
int SWAR(unsigned int i) {
// A: 16 groups of 2 bits, each group contains number of 1s in that group.
i = (i & 0x55555555) + ((i >> 1) & 0x55555555);
// B: 8 groups of 4 bits, each group contains number of 1s in that group.
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
// C: 4 groups of 8 bits, each group contains number of 1s in that group.
i = (i & 0x0f0f0f0f) + ((i >> 4) & 0x0f0f0f0f);
// D: 2 groups of 16 bits, each group contains number of 1s in that group.
i = (i & 0x00ff00ff) + ((i >> 8) & 0x00ff00ff);
// E: 1 group of 32 bits, containing the number of 1s in that group.
i = (i & 0x0000ffff) + ((i >> 16) &0x0000ffff);
return i;
}