优化具有两个条件的数组元素的比较;C++抽象机制



我的问题是"如何使这段代码更快(学习最佳实践)?"的后续问题,该文章已被搁置(无赖)。问题是优化带有浮点数的数组上的循环,这些浮点数被测试它们是否位于给定的间隔内。数组中匹配元素的索引将存储在提供的结果数组中。

该测试包括两个条件(小于上限阈值和大于下限阈值)。测试的明显代码是 if( elem <= upper && elem >= lower ) ... .我观察到分支(包括短路运算符中涉及的隐式分支&&)比第二个比较昂贵得多。我想出的如下。它比朴素的实现快约 20%-40%,比我预期的要快。它使用布尔值是整数类型的事实。条件测试结果用作两个结果数组的索引。其中只有一个将包含所需的数据,另一个可以丢弃。这将程序结构替换为数据结构和计算。

我对更多优化的想法感兴趣。欢迎"技术黑客"(此处提供的那种)。我还对现代C++是否可以提供更快的方法感兴趣,例如,通过使编译器能够创建并行运行代码。想想访客模式/函子。对单个 srcArr 元素的计算几乎是独立的,除了结果数组中索引的顺序取决于测试源数组元素的顺序。我会稍微放宽要求,以便结果数组中报告的匹配索引的顺序无关紧要。有人能想出一个快速的方法吗?

这是函数的源代码。下面是一个支持主干。GCC 需要 -std=C++11,因为 chrono。VS 2013 express也能够编译它(并且创建的代码比gcc -O3快40%)。

#include <cstdlib>
#include <iostream>
#include <chrono>
using namespace std;
using namespace std::chrono;
/// Check all elements in srcArr whether they lie in 
/// the interval [lower, upper]. Store the indices of
/// such elements in the array pointed to by destArr[1]
/// and return the number of matching elements found.
/// This has been highly optimized, mainly to avoid branches.
int findElemsInInterval(    const float srcArr[],   // contains candidates
                            int **const destArr,    // two arrays to be filled with indices
                            const int arrLen,       // length of each array
                            const float lower, const float upper // interval
                        )
{
    // Instead of branching, use the condition 
    // as an index into two distinct arrays. We need to keep
    // separate indices for both those arrays.
    int destIndices[2];     
    destIndices[0] = destIndices[1] = 0;
    for( int srcInd=0; srcInd<arrLen; ++srcInd )
    {
        // If the element is inside the interval, both conditions
        // are true and therefore equal. In all other cases 
        // exactly one condition is true so that they are not equal.
        // Matching elements' indices are therefore stored in destArr[1].
        // destArr[0] is a kind of a dummy (it will incidentally contain
        // indices of non-matching elements).
        // This used to be (with a simple int *destArr)
        // if( srcArr[srcInd] <= upper && srcArr[srcInd] >= lower) destArr[destIndex++] = srcInd;
        int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
        destArr[isInInterval][destIndices[isInInterval]++] = srcInd;    
    }
    return destIndices[1];  // the number of elements in the results array 
}

int main(int argc, char *argv[])
{
    int arrLen = 1000*1000*100;
    if( argc > 1 ) arrLen = atol(argv[1]);
    // destArr[1] will hold the indices of elements which
    // are within the interval.
    int *destArr[2];
    // we don't check destination boundaries, so make them 
    // the same length as the source.
    destArr[0] = new int[arrLen];   
    destArr[1] = new int[arrLen];
    float *srcArr = new float[arrLen];
    // Create always the same numbers for comparison (don't srand).
    for( int srcInd=0; srcInd<arrLen; ++srcInd ) srcArr[srcInd] = rand();
    // Create an interval in the middle of the rand() spectrum
    float lowerLimit = RAND_MAX/3;
    float upperLimit = lowerLimit*2;
    cout << "lower = " << lowerLimit << ", upper = " << upperLimit << endl;
    int numInterval; 
    auto t1 = high_resolution_clock::now(); // measure clock time as an approximation
    // Call the function a few times to get a longer run time
    for( int srcInd=0; srcInd<10; ++srcInd )  
        numInterval = findElemsInInterval( srcArr, destArr, arrLen, lowerLimit, upperLimit );
    auto t2 = high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( t2 - t1 ).count();
    cout << numInterval << " elements found in " << duration << " milliseconds. " << endl;
    return 0;
}
想到将

<= x && x

你可以尝试类似的东西

const float radius = (b-a)/2;
if( fabs( x-(a+radius) ) < radius )
    ...

将检查减少到一个条件。

我看到大约 10% 的加速:

int destIndex = 0;  // replace destIndices
int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
destArr[1][destIndex] = srcInd;
destIndex += isInInterval;

消除一对输出数组。 相反,如果您想保留结果,则只将"写入的数字"提前 1,否则只需继续覆盖"超过末尾的一个"索引。

也就是说,retval[destIndex]=curIndex; destIndex+= isInArray; - 更好的连贯性和更少的内存浪费。

编写两个版本:一个支持固定数组长度(例如 1024 或其他),另一个支持运行时参数。 使用 template 参数删除代码重复。 假设长度小于该常量。

具有函数返回大小和 RVO std::array<unsigned, 1024>

编写一个合并结果的包装函数(创建所有结果,然后合并它们)。 然后将平行模式库抛出问题(因此结果在平行中计算)。

如果您允许自己使用 SSE(或更好的 AVX)指令集进行矢量化,您可以一次执行 4/8 比较,执行此操作两次,"和"结果,然后检索 4 个结果(-1 或 0)。同时,这展开了循环。

// Preload the bounds
__m128 lo= _mm_set_ps(lower);
__m128 up= _mm_set_ps(upper);
int srcIndex, dstIndex= 0;
for (srcInd= 0; srcInd + 3 < arrLen; )
{
  __m128 src= _mm_load_ps(&srcArr[srcInd]); // Load 4 values
  __m128 tst= _mm_and_ps(_mm_cmple_ps(src, lo), _mm_cmpge_ps(src, up)); // Test
  // Copy the 4 indexes with conditional incrementation
  dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[0];
  dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[1];
  dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[2];
  dstArr[dstIndex]= srcInd++; destIndex-= tst.m128i_i32[3];
}

注意:未选中的代码。

最新更新