立方体函数的输入和输出重叠



我正在使用用于矩阵乘法的cublas库处理一些大数据。为了节省内存空间,我想要类似A=A*B的东西,其中AB都是n乘n的平方矩阵,即对输出和其中一个输入矩阵使用相同的内存空间。

虽然一些老帖子说这在cublas库中是不允许的,但我实际上是使用cublasZgemmStridedBatched()函数实现的。令人惊讶的是,计算是完全正确的,并且经过反复运行是稳定的。所以我想知道当前的cublas库是否支持重叠的输入和输出。如果是,它实际节省了多少内存?我的意思是,直观地说,该函数至少需要一些额外的内存来存储中间计算,因为Aij= AikBkj依赖于一整行的A。这对于批处理gems来说是不是特别节省内存?

虽然一些旧帖子说这在cublas库中是不允许的,

它们是完全正确的(请注意,"旧帖子"指的是标准GEMM调用,而不是您所询问的批处理实现(。

我实际上是使用cublasZgemmStridedBatched()函数实现的。令人惊讶的是,计算是完全正确的,并且在重复运行的情况下是稳定的

这并没有被证明是安全的,我怀疑你可能只是幸运地得到了稳定的结果,因为小矩阵可能被预加载到共享内存或寄存器中,因此就地操作可以工作。如果你使用更大的矩阵,我想你会看到失败,因为最终会出现这样的情况:在一个写周期后,如果不多次访问源矩阵,就无法执行单个GEMM,这会损坏源矩阵。

我不建议就地操作,即使你发现它适用于一种情况。不同的问题大小、库版本和硬件可能会产生您根本没有测试过的故障。选择和相关风险由您决定。

最新更新