我正在阅读JAX.local_devices上的JAX文档,其中写道:
类似jax.devices((,但只返回给定进程的本地设备。
在jax.devices((中,它被写为:
返回给定后端的所有设备的列表。
我不知道这些本地和非本地设备到底是什么。你能详细说明一下它们之间的区别吗?
这在JAX的文档《在多主机和多进程环境中使用JAX:》中进行了讨论
进程的本地设备是它可以直接寻址和启动计算的设备。例如,在GPU集群上,每个主机只能在直接连接的GPU上启动计算。在Cloud TPU吊舱上,每个主机只能在直接连接到该主机的8个TPU核心上启动计算(有关更多详细信息,请参阅Cloud TPU系统架构文档(。您可以通过
jax.local_devices()
查看进程的本地设备。全局设备是指所有进程中的设备只要每个进程在其本地设备上启动计算,计算就可以跨越进程之间的设备,并通过设备之间的直接通信链路执行集体操作。您可以通过
jax.devices()
查看所有可用的全局设备。进程的本地设备始终是全局设备的子集。