使用Thrust对Cuda中的2D阵列进行排序



我有一个2d数组,我想按行排序,这意味着如果数组是

 3     2     2     3     2     2     3     3     3     3
 3     3     2     2     2     2     3     3     2     2
 3     2     2     3     2     2     3     3     3     2
 2     2     2     2     2     2     2     2     2     2
 3     2     2     2     2     2     3     2     2     2
 2     2     2     2     2     2     2     2     2     2
 3     3     2     3     2     2     3     3     2     3
 3     3     2     2     2     2     3     3     3     3
 3     2     2     3     2     2     3     3     2     3
 3     3     2     3     2     2     3     3     3     3

我想拿阵列

 2     2     2     2     2     2     2     2     2     2
 2     2     2     2     2     2     2     2     2     2
 3     2     2     2     2     2     3     2     2     2
 3     2     2     3     2     2     3     3     2     3
 3     2     2     3     2     2     3     3     3     2
 3     2     2     3     2     2     3     3     3     3
 3     3     2     2     2     2     3     3     2     2
 3     3     2     2     2     2     3     3     3     3
 3     3     2     3     2     2     3     3     2     3
 3     3     2     3     2     2     3     3     3     3

我检查了纯CUDA中基数排序的一些实现,但它们看起来相当复杂。使用Thrust有没有相对简单的方法?

在推力中可以做到这一点。一种可能的方法是创建一个自定义排序函子,遍历给定给它的行(假设这些行是通过传递给函子的索引来识别的),然后决定这些行的顺序。

为了实现这一点,我们可以创建一个索引数组,每行一个索引,我们将对其进行排序。我们将根据给定的数据数组对该索引数组进行排序(使用对行排序的自定义排序函数)。

最后,我们唯一排序的是索引数组,但如果需要的话,现在它的顺序是重新排列行。

下面是一个完整的例子:

$ cat t631.cu
#include <iostream>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sort.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>
#define DWIDTH 10
typedef int mytype;
struct my_sort_functor
{
  int my_width;
  mytype *my_data;
  my_sort_functor(int _my_width, mytype * _my_data): my_width(_my_width), my_data(_my_data) {};
  __host__ __device__
  bool operator()(const int idx1, const int idx2) const
    {
      bool flip = false;
      for (int col_idx = 0; col_idx < my_width; col_idx++){
        mytype d1 = my_data[(idx1*my_width)+col_idx];
        mytype d2 = my_data[(idx2*my_width)+col_idx];
        if (d1 > d2) break;
        if (d1 < d2) {flip = true; break;}
        }
      return flip;
    }
};
int main(){
  mytype data[] = {
    3,     2,     2,     3,     2,     2,     3,     3,     3,     3,
    3,     3,     2,     2,     2,     2,     3,     3,     2,     2,
    3,     2,     2,     3,     2,     2,     3,     3,     3,     2,
    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
    3,     2,     2,     2,     2,     2,     3,     2,     2,     2,
    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
    3,     3,     2,     3,     2,     2,     3,     3,     2,     3,
    3,     3,     2,     2,     2,     2,     3,     3,     3,     3,
    3,     2,     2,     3,     2,     2,     3,     3,     2,     3,
    3,     3,     2,     3,     2,     2,     3,     3,     3,     3 };
  int cols  = DWIDTH;
  int dsize = sizeof(data)/sizeof(mytype);
  int rows  = dsize/cols;
  thrust::host_vector<mytype>   h_data(data, data+dsize);
  thrust::device_vector<mytype> d_data = h_data;
  thrust::device_vector<int> idxs(rows);
  thrust::sequence(idxs.begin(), idxs.end());
  thrust::sort(idxs.begin(), idxs.end(), my_sort_functor(cols, thrust::raw_pointer_cast(d_data.data())));
  thrust::host_vector<int> h_idxs = idxs;
  for (int i = 0; i<rows; i++){
    thrust::copy(h_data.begin()+h_idxs[i]*cols, h_data.begin()+(h_idxs[i]+1)*cols, std::ostream_iterator<mytype>(std::cout, ", "));
    std::cout << std::endl;}
  return 0;
}
$ nvcc -o t631 t631.cu
$ ./t631
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 2, 2, 2, 2, 2, 3, 2, 2, 2,
3, 2, 2, 3, 2, 2, 3, 3, 2, 3,
3, 2, 2, 3, 2, 2, 3, 3, 3, 2,
3, 2, 2, 3, 2, 2, 3, 3, 3, 3,
3, 3, 2, 2, 2, 2, 3, 3, 2, 2,
3, 3, 2, 2, 2, 2, 3, 3, 3, 3,
3, 3, 2, 3, 2, 2, 3, 3, 2, 3,
3, 3, 2, 3, 2, 2, 3, 3, 3, 3,
$

我敢肯定,如果数据可以以转置的形式传递,并重新排列代码以对列而不是行进行排序(即,根据数据数组中的列而不是列对索引向量进行排序),这将大大提高效率。对于由排序函数驱动的底层数据访问,这将更有效。

我省略了将行移动到新位置的步骤,但希望这应该是简单的。在输出结果的方法中暗示了通用方法,尽管如果需要,可以使用单个推力调用来完成。

最新更新