使用jCUDA对复杂矩阵进行运算



使用jCuda处理复数的最佳方法是什么?我应该使用cuComplex格式还是有其他解决方案(比如一个实部和虚部相继的数组)?我真的很欣赏使用这种类型计算的java代码示例。

由于我的目的是使用GPU求解具有复数的大型线性方程组,我不想只关注jCuda。使用GPU进行此类计算的替代方法是什么?

首先,关于在GPU上使用Java进行计算的问题,我在这里写了几句话。

你的申请案例似乎很具体。您可能应该更详细地描述您的实际意图,因为这将支配所有的设计决策。到目前为止,我只能给出一些基本的提示。由您决定哪种解决方案最合适。


在Java世界和GPU世界之间进行桥接时的主要困难之一是内存处理的根本不同。

C/C中的内存布局++

CUDA中的cuComplex结构体定义为

typedef float2 cuFloatComplex
typedef cuFloatComplex cuComplex;

其中float2基本上类似于

struct float2 {
float x; 
float y; 
};

(带有一些用于对齐等的附加说明符)

现在,当你在C/C++程序中分配一个cuComplex值的数组时,你只需要写一些类似的东西

cuComplex *c = new cuComplex[100];

在这种情况下,保证所有这些cuComplex值的内存将是单个连续内存块。这个内存块只是由复数的所有xy值组成,一个接一个:

_____________________________
c -> | x0 | y0 | x1 | y1 | x2 | y2 |... 
|____|____|____|____|____|____|

这个连续的内存块可以很容易地复制到设备上:一个获取指针,并调用类似的调用

cudaMemcpy(device, c, sizeof(cuComplex)*n, cudaMemcpyHostToDevice);

Java中的内存布局

考虑这样一种情况:创建一个在结构上等于cuComplex结构的Java类,并分配一个数组:

class cuComplex {
public float x;
public float y;
}
cuComplex c[] = new cuComplex[100];

那么您就不会有一个连续的float值内存块。相反,您有一个引用cuComplex对象的数组,并且相应的xy值分散在各处:

____________________
c -> |  c0  |  c1  |  c2  |... 
|______|______|______|
|       |      |
v       v      v
[x,y]   [x,y]  [x,y]

这里的关键点是:

您无法将cuComplex对象的(Java)数组复制到设备


这有几个含义。在注释中,您已经提到了以cuComplex数组为参数的cublasSetVector方法,我试图强调这不是最有效的解决方案,但这只是为了方便起见。事实上,这种方法的工作方式是在内部创建一个新的ByteBuffer,以便具有连续的内存块,用cuComplex[]数组中的值填充该ByteBuffer,然后将该ByteBuffer复制到设备。

当然,这会增加您在性能关键型应用程序中最希望避免的开销。


有几个选项可以解决这个问题。幸运的是,对于复数,解决方案相对简单:

不要使用cuComplex结构来表示复数数组

相反,您应该将复数数组表示为单个连续的内存块,其中复数的实部和虚部交错,分别为单个floatdouble值。这将允许不同后端之间实现最大的互操作性(省去某些细节,如对齐要求)。

不幸的是,这可能会造成一些不便并引发一些问题,而且没有一个一刀切的解决方案。

如果试图对此进行概括,不仅指复数,而且指一般的"结构",那么就可以应用一种"模式":可以为结构创建接口,并创建这些结构的集合,该集合是实现该接口的类的实例列表,这些实例都由一个连续的内存块支持。这可能适用于某些情况。但对于复数,为每个复数拥有一个Java对象的内存开销可能会大得令人望而却步。

仅处理原始float[]double[]阵列的另一个极端也可能不是最佳解决方案。例如:如果您有一个代表复数的float值数组,那么如何将其中一个复数与另一个相乘?

一个"中间"解决方案可以创建一个接口,允许访问复数的实部和虚部。在实现中,这些复数存储在单个阵列中,如上所述。


我在这里勾画了这样一个实现。

注意:

这只是一个例子,展示了基本思想,并展示了它如何与JCublas这样的东西协同工作。对你来说,根据你的实际目标,不同的策略可能更合适:除了JCuda之外,还应该有哪些其他后端?在Java端处理复数应该有多"方便"?Java端处理复数的结构(类/接口)应该是什么样子的?

简而言之:在继续实现之前,您应该非常清楚地了解应用程序/库应该能够做什么。

import static jcuda.jcublas.JCublas2.*;
import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N;
import static jcuda.runtime.JCuda.*;
import java.util.Random;
import jcuda.*;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.cudaMemcpyKind;
// An interface describing an array of complex numbers, residing
// on the host, with methods for accessing the real and imaginary
// parts of the complex numbers, as well as methods for copying
// the underlying data from and to the device
interface cuComplexHostArray
{
int size();
float getReal(int i);
float getImag(int i);
void setReal(int i, float real);
void setImag(int i, float imag);
void set(int i, cuComplex c);
void set(int i, float real, float imag);
cuComplex get(int i, cuComplex c);
void copyToDevice(Pointer devicePointer);
void copyFromDevice(Pointer devicePointer);
}
// A default implementation of a cuComplexHostArray, backed
// by a single float[] array
class DefaultCuComplexHostArray implements cuComplexHostArray
{
private final int size;
private final float data[];
DefaultCuComplexHostArray(int size)
{
this.size = size;
this.data = new float[size * 2];
}
@Override
public int size()
{
return size;
}
@Override
public float getReal(int i)
{
return data[i+i];
}
@Override
public float getImag(int i)
{
return data[i+i+1];
}
@Override
public void setReal(int i, float real)
{
data[i+i] = real;
}
@Override
public void setImag(int i, float imag)
{
data[i+i+1] = imag;
}
@Override
public void set(int i, cuComplex c)
{
data[i+i+0] = c.x;
data[i+i+1] = c.y;
}
@Override
public void set(int i, float real, float imag)
{
data[i+i+0] = real;
data[i+i+1] = imag;
}
@Override
public cuComplex get(int i, cuComplex c)
{
float real = getReal(i);
float imag = getImag(i);
if (c != null)
{
c.x = real;
c.y = imag;
return c;
}
return cuComplex.cuCmplx(real, imag);
}
@Override
public void copyToDevice(Pointer devicePointer)
{
cudaMemcpy(devicePointer, Pointer.to(data),
size * Sizeof.FLOAT * 2,
cudaMemcpyKind.cudaMemcpyHostToDevice);
}
@Override
public void copyFromDevice(Pointer devicePointer)
{
cudaMemcpy(Pointer.to(data), devicePointer,
size * Sizeof.FLOAT * 2,
cudaMemcpyKind.cudaMemcpyDeviceToHost);
}
}
// An example that performs a "gemm" with complex numbers, once
// in Java and once in JCublas2, and verifies the result
public class JCublas2ComplexSample
{
public static void main(String args[])
{
testCgemm(500);
}
public static void testCgemm(int n)
{
cuComplex alpha = cuComplex.cuCmplx(0.3f, 0.2f);
cuComplex beta  = cuComplex.cuCmplx(0.1f, 0.7f);
int nn = n * n;
System.out.println("Creating input data...");
Random random = new Random(0);
cuComplex[] rhA = createRandomComplexRawArray(nn, random);
cuComplex[] rhB = createRandomComplexRawArray(nn, random);
cuComplex[] rhC = createRandomComplexRawArray(nn, random);
random = new Random(0);
cuComplexHostArray hA = createRandomComplexHostArray(nn, random);
cuComplexHostArray hB = createRandomComplexHostArray(nn, random);
cuComplexHostArray hC = createRandomComplexHostArray(nn, random);
System.out.println("Performing Cgemm with Java...");
cgemmJava(n, alpha, rhA, rhB, beta, rhC);
System.out.println("Performing Cgemm with JCublas...");
cgemmJCublas(n, alpha, hA, hB, beta, hC);
boolean passed = isCorrectResult(hC, rhC);
System.out.println("testCgemm "+(passed?"PASSED":"FAILED"));
}
private static void cgemmJCublas(
int n,
cuComplex alpha,
cuComplexHostArray A,
cuComplexHostArray B,
cuComplex beta,
cuComplexHostArray C)
{
int nn = n * n;
// Create a CUBLAS handle
cublasHandle handle = new cublasHandle();
cublasCreate(handle);
// Allocate memory on the device
Pointer dA = new Pointer();
Pointer dB = new Pointer();
Pointer dC = new Pointer();
cudaMalloc(dA, nn * Sizeof.FLOAT * 2);
cudaMalloc(dB, nn * Sizeof.FLOAT * 2);
cudaMalloc(dC, nn * Sizeof.FLOAT * 2);
// Copy the memory from the host to the device
A.copyToDevice(dA);
B.copyToDevice(dB);
C.copyToDevice(dC);
// Execute cgemm
Pointer pAlpha = Pointer.to(new float[]{alpha.x, alpha.y});
Pointer pBeta = Pointer.to(new float[]{beta.x, beta.y});
cublasCgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
pAlpha, dA, n, dB, n, pBeta, dC, n);
// Copy the result from the device to the host
C.copyFromDevice(dC);
// Clean up
cudaFree(dA);
cudaFree(dB);
cudaFree(dC);
cublasDestroy(handle);
}
private static void cgemmJava(
int n,
cuComplex alpha,
cuComplex A[],
cuComplex B[],
cuComplex beta,
cuComplex C[])
{
for (int i = 0; i < n; ++i)
{
for (int j = 0; j < n; ++j)
{
cuComplex prod = cuComplex.cuCmplx(0, 0);
for (int k = 0; k < n; ++k)
{
cuComplex ab =
cuComplex.cuCmul(A[k * n + i], B[j * n + k]);
prod = cuComplex.cuCadd(prod, ab);
}
cuComplex ap = cuComplex.cuCmul(alpha, prod);
cuComplex bc = cuComplex.cuCmul(beta, C[j * n + i]);
C[j * n + i] = cuComplex.cuCadd(ap, bc);
}
}
}
private static cuComplex[] createRandomComplexRawArray(
int n, Random random)
{
cuComplex c[] = new cuComplex[n];
for (int i = 0; i < n; i++)
{
float real = random.nextFloat();
float imag = random.nextFloat();
c[i] = cuComplex.cuCmplx(real, imag);
}
return c;
}
private static cuComplexHostArray createRandomComplexHostArray(
int n, Random random)
{
cuComplexHostArray c = new DefaultCuComplexHostArray(n);
for (int i = 0; i < n; i++)
{
float real = random.nextFloat();
float imag = random.nextFloat();
c.setReal(i, real);
c.setImag(i, imag);
}
return c;
}
private static boolean isCorrectResult(
cuComplexHostArray result, cuComplex reference[])
{
float errorNormX = 0;
float errorNormY = 0;
float refNormX = 0;
float refNormY = 0;
for (int i = 0; i < result.size(); i++)
{
float diffX = reference[i].x - result.getReal(i);
float diffY = reference[i].y - result.getImag(i);
errorNormX += diffX * diffX;
errorNormY += diffY * diffY;
refNormX += reference[i].x * result.getReal(i);
refNormY += reference[i].y * result.getImag(i);
}
errorNormX = (float) Math.sqrt(errorNormX);
errorNormY = (float) Math.sqrt(errorNormY);
refNormX = (float) Math.sqrt(refNormX);
refNormY = (float) Math.sqrt(refNormY);
if (Math.abs(refNormX) < 1e-6)
{
return false;
}
if (Math.abs(refNormY) < 1e-6)
{
return false;
}
return
(errorNormX / refNormX < 1e-6f) &&
(errorNormY / refNormY < 1e-6f);
}
}

(顺便说一句:我可能会把这个答案的一部分扩展为JCuda的样本和/或"如何…"页面。提供这样的信息的任务已经在我的"待办事项"列表中列出了很长一段时间)。

最新更新