我有一个类型为double
的(相当大的(标准C++
数组,具有~50,000,000
行和20
列。根据高斯分布(如果这在回答这个问题时有任何用处的话(,数组中填充了随机数据。
我已经写了一个算法来解决使用这个数组的问题。该算法的大部分时间用于逐行迭代(有时在同一行上迭代不止一次(,并为每一行返回该行中每个元素的索引,以使该元素的绝对值超过某个值(也是double
类型(。
不幸的是,算法相当慢。由于它相当大,而且简单地将代码转储到SO上所解决的问题有点复杂,所以我想从解决这个问题开始。获取多维数组行中每个元素的索引最有效(或者至少是更有效的方法(是什么?
我尝试过的:
我尝试过简单地迭代每一行(使用迭代器(,将每个值传递给fabs()
,并使用std::distance()
来获取索引。然后,我将其存储在std::set()
中(我不太关心索引是如何存储的,除非这是一个重要的速度因素,只要它们"易于访问"(。
即:
for(auto it = row.begin(); it != row.end(); ++it){
auto &element = *it;
if(fabs(element) >= threshold){
cache.insert(std::distance(row.begin(), it));
}
}
我也尝试过使用std::find_if
,类似地使用std::range
。两者都没有给出可衡量的速度改进(诚然,我没有使用特别科学的基准,但我会寻求明显的改进(。
例如:
auto exceeds_thresh = [](double x){ return x > threshold}
it = ranges::find_if(row, exceeds_thresh);
while(it != end(row)){
resuts.emplace_back(distance(begin(row), it));
it = ranges::find_if(std::next(it), std::end(row), exceeds_thresh)
}
注意,根据效率,我专注于速度
这里,11.3, 9.8, 17.5
满足条件,因此应该打印它们的索引1,3,6
。注意,在实践中,每个数组都是一个大得多的数组中的一行(如上所述(,并且每行中的元素数量要多得多:
double row_of_array[5] = {1.4, 11.3, 4.2, 9.8, 0.1, 3.2, 17.5};
double threshold = 8;
for(auto it = row_of_array.begin(); it != row_of_array.end(); ++it){
auto &element = *it;
if(fabs(element) > threshold){
std::cout << std::distance(row_of_array.begin(), it) << "n";
}
}
您可以尝试循环展开
double row_of_array[] = {1, 11, 4, 9, 0, 3, 17};
constexpr double threshold = 8;
std::vector<int> results;
results.reserve(20);
for(int i{}, e = std::ssize(row_of_array); i < e; i += 4)
{
if(std::abs(row_of_array[i]) > threshold)
results.push_back(i);
if(i + 1 < e && std::abs(row_of_array[i + 1]) > threshold)
results.push_back(i + 1);
if(i + 2 < e && std::abs(row_of_array[i + 2]) > threshold)
results.push_back(i + 2);
if(i + 3 < e && std::abs(row_of_array[i + 3]) > threshold)
results.push_back(i + 3);
}
编辑:
或者风险更大的
double row_of_array[20] = {1, 11, 4, 9, 0, 3, 17};
constexpr double threshold = 8;
std::vector<int> results;
results.reserve(20);
static_assert(std::ssize(row_of_array) % 4 == 0, "only works for mul of 4");
for(int i{}, e = std::ssize(row_of_array); i < e; i += 4)
{
if(std::abs(row_of_array[i]) > threshold) results.push_back(i);
if(std::abs(row_of_array[i + 1]) > threshold) results.push_back(i + 1);
if(std::abs(row_of_array[i + 2]) > threshold) results.push_back(i + 2);
if(std::abs(row_of_array[i + 3]) > threshold) results.push_back(i + 3);
}