如何为GPU设备指定或设置变量



我是JAX的新手,我想使用多个GPU。到目前为止,我的JAX可以看到两个GPU(0和1(。

import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print(jax.local_devices())
>>>
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

当我创建一个NumPy对象时,它将始终位于GPU设备0中,我认为这是默认的。

nmp = jax.numpy.ones(4)
print(nmp.device())
>>>
# Prints: gpu:0

如何将我的变量nmp发送到另一个GPUgpu:1中?

使用.device_put()

import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
devices = jax.local_devices()
print(devices) # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
nmp = jax.numpy.ones(4)
print(nmp.device()) # >>> gpu:0
nmp = jax.device_put(nmp, jax.devices()[1])
print(nmp.device()) # >>> gpu:1

相关内容

  • 没有找到相关文章

最新更新