我是JAX的新手,我想使用多个GPU。到目前为止,我的JAX可以看到两个GPU(0和1(。
import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print(jax.local_devices())
>>>
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
当我创建一个NumPy对象时,它将始终位于GPU设备0中,我认为这是默认的。
nmp = jax.numpy.ones(4)
print(nmp.device())
>>>
# Prints: gpu:0
如何将我的变量nmp
发送到另一个GPUgpu:1
中?
使用.device_put()
import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
devices = jax.local_devices()
print(devices) # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
nmp = jax.numpy.ones(4)
print(nmp.device()) # >>> gpu:0
nmp = jax.device_put(nmp, jax.devices()[1])
print(nmp.device()) # >>> gpu:1