我目前使用tfjs 3.8在客户端加载细分模型(加载为tf.GraphModel
)。为了创建输入Tensor
,我调用browser.fromPixels(imageData)
,它从也在CPU上的ImageData
对象创建CPU上的Tensor
。由于我使用的是tfjs的webgl
后端,因此在调用model.predict(tensor)
函数时将数据发送到GPU。所有这些都工作得很好,除了我的ImageData
对象是从具有WebGLRenderingContext
的画布上的图像创建的,这意味着它来自GPU。这个GPU- CPU- GPU数据传输减慢了我的进程,我正在努力优化。
我简要地搜索了tfjs,无法找到在GPU上创建Tensor
以防止GPU- CPU数据传输的方法。有没有办法让我把数据保存在GPU上?
关于这个主题的详细对话在一个线程https://github.com/tensorflow/tfjs/issues/5765
解决方案是简单地为browser.fromPixels(canvas)
调用提供带有webgl上下文的画布。这将直接在GPU上创建张量。