计算金属核的平均值



谁知道一个正确的方法来计算与随机浮点数在金属内核的缓冲区的平均值?

计算命令编码器上的调度工作:

threadsPerGroup = MTLSizeMake(1, 1, inputTexture.arrayLength);
numThreadGroups = MTLSizeMake(1, 1, inputTexture.arrayLength / threadsPerGroup.depth);
[commandEncoder dispatchThreadgroups:numThreadGroups
               threadsPerThreadgroup:threadsPerGroup];
内核代码:

kernel void mean(texture2d_array<float, access::read> inTex [[ texture(0) ]],
             device float *means                            [[ buffer(1) ]],
             uint3 id                                       [[ thread_position_in_grid ]]) {
    if (id.x == 0 && id.y == 0) {
        float mean = 0.0;
        for (uint i = 0; i < inTex.get_width(); ++i) {
            for (uint j = 0; j < inTex.get_height(); ++j) {
                    mean += inTex.read(uint2(i, j), id.z)[0];
            }
        }
        float textureArea = inTex.get_width() * inTex.get_height();
        mean /= textureArea;
        out[id.z] = mean;
    }
}

该缓冲区以texture2d_array类型的纹理表示,采用R32Float像素格式。

如果您可以使用int数组(而不是float)作为数据源,我建议使用"原子获取和修改函数"(如金属着色语言规范中所述)自动写入缓冲区。

下面是一个内核函数的例子,它接受一个输入缓冲区(data:一个Float数组),并将缓冲区的和写入一个原子缓冲区(sum,一个指向int类型的指针):

kernel void sum(device uint *data [[ buffer(0) ]],
                volatile device atomic_uint *sum [[ buffer(1) ]],
                uint gid [[ thread_position_in_grid ]])
{
    atomic_fetch_add_explicit(sum, data[gid], memory_order_relaxed);
}

在你的swift文件中,你可以设置缓冲区:

...
let data: [UInt] = [1, 2, 3, 4]
let dataBuffer = device.makeBuffer(bytes: &data, length: (data.count * MemoryLayout<UInt>.size), options: [])
commandEncoder.setBuffer(dataBuffer, offset: 0, at: 0)
var sum:UInt = 0
let sumBuffer = device!.makeBuffer(bytes: &sum, length: MemoryLayout<UInt>.size, options: [])
commandEncoder.setBuffer(sumBuffer, offset: 0, at: 1)
commandEncoder.endEncoding()

提交,等待,然后从GPU获取数据:

commandBuffer.commit()
commandBuffer.waitUntilCompleted()
let nsData = NSData(bytesNoCopy: sumBuffer.contents(),
                        length: sumBuffer.length,
                        freeWhenDone: false)
nsData.getBytes(&sum, length:sumBuffer.length)
let mean = Float(sum/data.count)
print(mean)

或者,如果您的初始数据源必须是浮点数组,您可以使用加速框架的vDSP_meanv方法,该方法对于此类计算非常快。

我希望这有帮助,干杯!

最新更新