PyTorch:从设备能力推断数据类型,而不是输入数据



我正试图为PyTorch编写一个与设备无关的库,但我偶然发现PyTorch使用与我的计算设备不兼容的数据类型的问题:

import scipy.signal
import torch
raw_window = scipy.signal.windows.cosine(128)
print(raw_window.dtype)  # float64
device = torch.device("cpu")
window = torch.as_tensor(raw_window, device=device)
print(window.device)  # cpu
print(window.dtype)   # torch.float64
device = torch.device("cuda")
window = torch.as_tensor(raw_window, device=device)
print(window.device)  # cuda:0
print(window.dtype)   # torch.float64

正如您在最后一行中看到的,torch指定了数据类型torch.float64,尽管我的CUDA设备无法处理双精度浮点值。

有没有办法让PyTorch使用最合适的设备数据类型而不是输入数据数据类型?还是我完全误解了设置设备和数据类型的整个概念?

否,正如您所注意到的,PyTorch仅从input数据推断dtype

在您的情况下,由于numpy的默认设置为np.float64(无论系统和体系结构如何(,PyTorch将推断它类似于torch.float64,因此从numpy开始更麻烦(并且您不能设置不同的默认dtype(。

pytorch中,您通常选择torch.float32(这是默认值(,最终选择torch.float16以获得混合精度(不包括量化(。通常这样的精度就足够了(但不确定您的确切用例(。

因此,您最好的选择是castnumpy阵列到np.float32,如下所示:

raw_window = raw_window.astype(np.float32)

(或者通过CCD_ 15进行投射,或者甚至更好地从一开始就创建CCD_,该AFAIK应该在所有/大多数设备上可用,无论是CPU还是GPU。

对于CPU/GPU不可知的方法,您应该进行显式设备创建,如下所示:

if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device('cuda')
else:
args.device = torch.device('cpu')

(如果需要显式指定cuda或类似内容,则进行扩展(,并在创建tensor时使用它。

相关内容

最新更新