运行时错误:无法推断生成器的数据类型



我使用pytorch为模型构建训练数据。

def shufflerow(tensor1, tensor2, axis):
row_perm = torch.rand(tensor1.shape[:axis+1]).argsort(axis)  # get permutation indices
for _ in range(tensor1.ndim-axis-1): row_perm.unsqueeze_(-1)
row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor1.shape[axis+1:]))  # reformat this for the gather operation
return tensor1.gather(axis, row_perm),tensor2.gather(axis, row_perm)
class Dataset:
def __init__(self, observation, next_observation):
self.data =(observation, next_observation)
indices = torch.randperm(observation.shape[0])
self.train_samples = (observation[indices ,:], next_observation[indices ,:])
self.test_samples = shufflerow(observation, next_observation, 0)

我还有这个功能,可以检查数据是否转换为torch.ttensor并设置设备

def to_tensor(x, device):
if torch.is_tensor(x):
return x
elif isinstance(x, np.ndarray):
return torch.from_numpy(x).to(device=device, dtype=torch.float32)
elif isinstance(x, list):
if all(isinstance(item, np.ndarray) for item in x):
return [torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x]
elif isinstance(x, tuple):
return (torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x)
else:
print(f"X:{x} and X's type{type(x)}") 
return torch.tensor(x).to(device=device, dtype=torch.float32)

但是通过Dataset类传递基本上看起来像这样的输入数据data=数据集(s1,s2(打印(data.train_samples(

(tensor([[-0.3121, -0.9500,  1.4518],
[-0.9903, -0.1391, -4.4141],
[-0.9645, -0.2642,  5.0233],
[-0.6413,  0.7673, -4.5495],
[-0.3073,  0.9516, -1.0128],
[-0.5495,  0.8355,  3.4044],
[-0.5710, -0.8209, -3.2716],
[-0.9388,  0.3445,  3.9225],
[-0.8402, -0.5423, -4.0820]]), tensor([[-0.2723, -0.9622,  0.8342],
[-0.9958,  0.0912, -4.6186],
[-0.8747, -0.4847,  4.7741],
[-0.5495,  0.8355,  3.4044],
[-0.7146,  0.6996,  4.2841],
[-0.7128, -0.7014, -3.7148],
[-0.9915,  0.1303,  4.4200],
[-0.9358, -0.3526, -4.2585]]))

我收到这个错误消息

-> 1725         self._target_samples = to_tensor(true_samples)
1726         self._steps = []

/content/data_gen.py in to_tensor(x)
1368     else:
1369         print(f"X:{x} and X's type{type(x)}")
-> 1370         return torch.tensor(x).to(device=device, dtype=torch.float32)

X:<generator object to_tensor.<locals>.<genexpr> at 0x7f380235d6d0> and X's type<class 'generator'>
RuntimeError: Could not infer dtype of generator

有什么建议吗,为什么我会出现这个错误?

表达式(torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x)不是在创建元组,而是一个生成器表达式。由于这是在测试元组的情况下,我怀疑您想要的是元组而不是生成器。尝试:

elif isinstance(x, tuple):
return tuple(torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x)

最新更新