检测数组中大于阈值的值,并使用推力将结果存储在二进制(1/0)数组中



给定一个输入数组和一个阈值,我需要创建一个输出二进制数组,其中1表示大于阈值的值,0表示小于阈值的值。我需要使用推力。

我下面的尝试解决了问题,但看起来很笨拙如何一步到位。我的目标是在最短的计算时间内完成它

#include <thrust/replace.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/device_vector.h>
int main(int argc, char * argv[])
{
int threshold=1;
thrust::device_vector<int> S(6);
S[0] = 1;
S[1] = 2;
S[2] = 3;
S[3] = 4;
S[4] = 5;
S[5] = 6;
// fill vector with zeros
thrust::device_vector<int> A(6);
thrust::fill(thrust::device, A.begin(), A.end(), 0);
// detect indices with values greater than zero
thrust::device_vector<int> indices(6);
thrust::device_vector<int>::iterator end = thrust::copy_if(thrust::make_counting_iterator(0),thrust::make_counting_iterator(6),S.begin(),indices.begin(),                                                              thrust::placeholders::_1 > threshold);
int size = end-indices.begin();
indices.resize(size);
// use permutation iterator along with indices above to change to ones
thrust::replace(thrust::device,thrust::make_permutation_iterator(A.begin(), indices.begin()), thrust::make_permutation_iterator(A.begin(), indices.end()), 0, 1);
for (int i=0;i<6;i++)
{
std::cout << "A["<<i<<"]=" << A[i] << std::endl;
}
return 0;
}

索引检测部分取自这个堆栈溢出问题

只需使用自定义比较函数调用thrust::transform即可实现所需的功能。下面是上述方法的一个例子。

#include <thrust/execution_policy.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>    
template<class T>
struct thresher
{
T _thresh;
thresher(T thresh) : _thresh(thresh) { }
__host__ __device__ int operator()(T &x) const
{
return int(x > _thresh);
}
};
int main(int argc, char * argv[])
{
int threshold = 1;
thrust::device_vector<int> S(6);
S[0] = 1;
S[1] = 2;
S[2] = 3;
S[3] = 4;
S[4] = 5;
S[5] = 6;
thrust::device_vector<int> A(6);
thrust::transform(S.begin(), S.end(), A.begin(), thresher<int>(threshold));
for (int i=0;i<6;i++)
{
std::cout << "A["<<i<<"]=" << A[i] << std::endl;
}
return 0;
}

最新更新