我正试图为PyTorch编写一个与设备无关的库,但我偶然发现PyTorch使用与我的计算设备不兼容的数据类型的问题:
import scipy.signal
import torch
raw_window = scipy.signal.windows.cosine(128)
print(raw_window.dtype) # float64
device = torch.device("cpu")
window = torch.as_tensor(raw_window, device=device)
print(window.device) # cpu
print(window.dtype) # torch.float64
device = torch.device("cuda")
window = torch.as_tensor(raw_window, device=device)
print(window.device) # cuda:0
print(window.dtype) # torch.float64
正如您在最后一行中看到的,torch指定了数据类型torch.float64
,尽管我的CUDA设备无法处理双精度浮点值。
有没有办法让PyTorch使用最合适的设备数据类型而不是输入数据数据类型?还是我完全误解了设置设备和数据类型的整个概念?
否,正如您所注意到的,PyTorch仅从input
数据推断dtype
。
在您的情况下,由于numpy
的默认设置为np.float64
(无论系统和体系结构如何(,PyTorch将推断它类似于torch.float64
,因此从numpy
开始更麻烦(并且您不能设置不同的默认dtype
(。
在pytorch
中,您通常选择torch.float32
(这是默认值(,最终选择torch.float16
以获得混合精度(不包括量化(。通常这样的精度就足够了(但不确定您的确切用例(。
因此,您最好的选择是cast
numpy
阵列到np.float32
,如下所示:
raw_window = raw_window.astype(np.float32)
(或者通过CCD_ 15进行投射,或者甚至更好地从一开始就创建CCD_,该AFAIK应该在所有/大多数设备上可用,无论是CPU还是GPU。
对于CPU/GPU不可知的方法,您应该进行显式设备创建,如下所示:
if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device('cuda')
else:
args.device = torch.device('cpu')
(如果需要显式指定cuda
或类似内容,则进行扩展(,并在创建tensor
时使用它。