我的问题是"如何使这段代码更快(学习最佳实践)?"的后续问题,该文章已被搁置(无赖)。问题是优化带有浮点数的数组上的循环,这些浮点数被测试它们是否位于给定的间隔内。数组中匹配元素的索引将存储在提供的结果数组中。
该测试包括两个条件(小于上限阈值和大于下限阈值)。测试的明显代码是 if( elem <= upper && elem >= lower ) ...
.我观察到分支(包括短路运算符中涉及的隐式分支&&)比第二个比较昂贵得多。我想出的如下。它比朴素的实现快约 20%-40%,比我预期的要快。它使用布尔值是整数类型的事实。条件测试结果用作两个结果数组的索引。其中只有一个将包含所需的数据,另一个可以丢弃。这将程序结构替换为数据结构和计算。
我对更多优化的想法感兴趣。欢迎"技术黑客"(此处提供的那种)。我还对现代C++是否可以提供更快的方法感兴趣,例如,通过使编译器能够创建并行运行代码。想想访客模式/函子。对单个 srcArr 元素的计算几乎是独立的,除了结果数组中索引的顺序取决于测试源数组元素的顺序。我会稍微放宽要求,以便结果数组中报告的匹配索引的顺序无关紧要。有人能想出一个快速的方法吗?
这是函数的源代码。下面是一个支持主干。GCC 需要 -std=C++11,因为 chrono。VS 2013 express也能够编译它(并且创建的代码比gcc -O3快40%)。
#include <cstdlib>
#include <iostream>
#include <chrono>
using namespace std;
using namespace std::chrono;
/// Check all elements in srcArr whether they lie in
/// the interval [lower, upper]. Store the indices of
/// such elements in the array pointed to by destArr[1]
/// and return the number of matching elements found.
/// This has been highly optimized, mainly to avoid branches.
int findElemsInInterval( const float srcArr[], // contains candidates
int **const destArr, // two arrays to be filled with indices
const int arrLen, // length of each array
const float lower, const float upper // interval
)
{
// Instead of branching, use the condition
// as an index into two distinct arrays. We need to keep
// separate indices for both those arrays.
int destIndices[2];
destIndices[0] = destIndices[1] = 0;
for( int srcInd=0; srcInd<arrLen; ++srcInd )
{
// If the element is inside the interval, both conditions
// are true and therefore equal. In all other cases
// exactly one condition is true so that they are not equal.
// Matching elements' indices are therefore stored in destArr[1].
// destArr[0] is a kind of a dummy (it will incidentally contain
// indices of non-matching elements).
// This used to be (with a simple int *destArr)
// if( srcArr[srcInd] <= upper && srcArr[srcInd] >= lower) destArr[destIndex++] = srcInd;
int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
destArr[isInInterval][destIndices[isInInterval]++] = srcInd;
}
return destIndices[1]; // the number of elements in the results array
}
int main(int argc, char *argv[])
{
int arrLen = 1000*1000*100;
if( argc > 1 ) arrLen = atol(argv[1]);
// destArr[1] will hold the indices of elements which
// are within the interval.
int *destArr[2];
// we don't check destination boundaries, so make them
// the same length as the source.
destArr[0] = new int[arrLen];
destArr[1] = new int[arrLen];
float *srcArr = new float[arrLen];
// Create always the same numbers for comparison (don't srand).
for( int srcInd=0; srcInd<arrLen; ++srcInd ) srcArr[srcInd] = rand();
// Create an interval in the middle of the rand() spectrum
float lowerLimit = RAND_MAX/3;
float upperLimit = lowerLimit*2;
cout << "lower = " << lowerLimit << ", upper = " << upperLimit << endl;
int numInterval;
auto t1 = high_resolution_clock::now(); // measure clock time as an approximation
// Call the function a few times to get a longer run time
for( int srcInd=0; srcInd<10; ++srcInd )
numInterval = findElemsInInterval( srcArr, destArr, arrLen, lowerLimit, upperLimit );
auto t2 = high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( t2 - t1 ).count();
cout << numInterval << " elements found in " << duration << " milliseconds. " << endl;
return 0;
}
<= x && x
你可以尝试类似的东西
const float radius = (b-a)/2;
if( fabs( x-(a+radius) ) < radius )
...
将检查减少到一个条件。
我看到大约 10% 的加速:
int destIndex = 0; // replace destIndices
int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
destArr[1][destIndex] = srcInd;
destIndex += isInInterval;
消除一对输出数组。 相反,如果您想保留结果,则只将"写入的数字"提前 1,否则只需继续覆盖"超过末尾的一个"索引。
也就是说,retval[destIndex]=curIndex; destIndex+= isInArray;
- 更好的连贯性和更少的内存浪费。
编写两个版本:一个支持固定数组长度(例如 1024 或其他),另一个支持运行时参数。 使用 template
参数删除代码重复。 假设长度小于该常量。
具有函数返回大小和 RVO std::array<unsigned, 1024>
。
编写一个合并结果的包装函数(创建所有结果,然后合并它们)。 然后将平行模式库抛出问题(因此结果在平行中计算)。
如果您允许自己使用 SSE(或更好的 AVX)指令集进行矢量化,您可以一次执行 4/8 比较,执行此操作两次,"和"结果,然后检索 4 个结果(-1 或 0)。同时,这展开了循环。
// Preload the bounds
__m128 lo= _mm_set_ps(lower);
__m128 up= _mm_set_ps(upper);
int srcIndex, dstIndex= 0;
for (srcInd= 0; srcInd + 3 < arrLen; )
{
__m128 src= _mm_load_ps(&srcArr[srcInd]); // Load 4 values
__m128 tst= _mm_and_ps(_mm_cmple_ps(src, lo), _mm_cmpge_ps(src, up)); // Test
// Copy the 4 indexes with conditional incrementation
dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[0];
dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[1];
dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[2];
dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[3];
}
注意:未选中的代码。