#define TS 32
int num_devices = 0;
__global__ void shared_kernel(float* A, float* B, float* C, int M, int N, int K) {
int global_col = blockDim.x * blockIdx.x + threadIdx.x;
int global_row = blockDim.y * blockIdx.y + threadIdx.y;
int local_col = threadIdx.x;
int local_row = threadIdx.y;
if (global_row >= M || global_col >= N) return;
__shared__ float Asub[TS][TS];
__shared__ float Bsub[TS][TS];
const int num_tiles = K / TS;
float acc = 0;
for(int t = 0; t < num_tiles; t++){
const int t_row = TS * t + local_row;
const int t_col = TS * t + local_col;
Asub[local_row][local_col] = A[global_row * K + t_col];
Bsub[local_row][local_col] = B[t_row * N + global_col];
__syncthreads();
printf("[DEBUG] first sync threads, global_row: %d, global_col: %dn", global_row, global_col);
for (int k = 0; k < K; ++k) {
acc += Asub[local_row][k] * Bsub[k][local_col];
}
__syncthreads();
printf("[DEBUG] second sync threads, global_row: %d, global_col: %dn", global_row, global_col);
}
C[global_row * N + global_col] = acc;
}
static float *a_d, *b_d, *c_d;
void mat_mul(float *A, float *B, float *C, int M, int N, int K) {
cudaMemcpy(a_d, A, M * K * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(b_d, B, K * N * sizeof(float), cudaMemcpyHostToDevice);
dim3 blockDim(TS, TS);
dim3 gridDim(M/TS, N/TS);
shared_kernel<<<gridDim, blockDim>>>(a_d, b_d, c_d, M, N, K);
cudaMemcpy(C, c_d, M * N * sizeof(float), cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
}
void mat_mul_init(float *A, float *B, float *C, int M, int N, int K) {
cudaGetDeviceCount(&num_devices);
cudaSetDevice(0);
cudaMalloc(&a_d, M * K * sizeof(float));
cudaMalloc(&b_d, K * N * sizeof(float));
cudaMalloc(&c_d, M * N * sizeof(float));
}
上面的例子是一个共享内存的矩阵乘法。我在内核之上运行了dim3 blockDim(TS, TS)
和dim3 gridDim(M/TS, N/TS)
,并且M,N,K=128。
在启动内核后,我检查了float * C
是否为零值。此外,我发现在第一个__syncthreads()
之后只打印了一些global_row(从37到81(,在第二个__syncthreads()
之后没有printf
DEBUG消息。
我怀疑是__syncthreads()
导致了这个问题,但我不知道如何解决。我的代码和其他网站上的其他矩阵乘法代码几乎相同。
你能给我一些提示如何解决这个问题吗?
任何时候CUDA代码出现问题时,我建议您使用正确的CUDA错误检查,并使用compute-sanitizer
或cuda-memcheck
运行代码。对于这种类型的分析,如果不在内核printf
中使用,则会更容易。
如果你这样做,你会看到这样的输出:
========= Invalid __shared__ read of size 4
========= at 0x000002f0 in shared_kernel(float*, float*, float*, int, int, int)
========= by thread (0,2,0) in block (0,1,0)
========= Address 0x00002000 is out of bounds
========= Saved host backtrace up to driver entry point at kernel launch time
... (and more output)
因此,我们可以看到您的内核正在进行无效的__shared__
读取操作。这种情况在内核中发生在哪里?您可以使用此处的方法来识别特定的代码行。然而,这是一个相当简单的内核,并且只有一行是从共享内存中读取的,它在这里:
for (int k = 0; k < K; ++k) {
acc += Asub[local_row][k] * Bsub[k][local_col]; // shared reads here
快速检查将显示,如果您让这个循环在K=128
的范围内迭代,那么您将在这里索引越界:
for (int k = 0; k < K; ++k) {
acc += Asub[local_row][k] * Bsub[k][local_col];
^ ^
当k
大于31时,因为这将超过您的共享数组维度:
#define TS 32
__shared__ float Asub[TS][TS];
__shared__ float Bsub[TS][TS];
我不会为您编写一个固定的内核/代码,因为正如您已经指出的,这个主题在许多其他地方都有介绍,编程指南中已经提供了一个规范的示例。
FWIW,如果我把你的for循环改成这个:
for (int k = 0; k < TS; ++k) {
那么运行时错误对我来说就消失了。cuda-memcheck
报告没有错误。