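I'm trying to learn some of the functions in the PyTorch framework, and I'm stuck on the following error when normalizing a simple integer tensor. Can anyone help?
Here is sample code that reproduces the error: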
import torch
import torch.nn as nn
# Integer type tensor
test_int_input = torch.randint(low=1, high=9, size=(3, 5))
# BatchNorm1D object
batchnorm1D = nn.BatchNorm1d(num_features=5)
test_output = batchnorm1D(test_int_input)
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-38-6c672cd731fa> in <module>
1 batchnorm1D = nn.BatchNorm1d(num_features=5)
----> 2 test_output = batchnorm1D(test_input)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py in forward(self, input)
105 input, self.running_mean, self.running_var, self.weight, self.bias,
106 self.training or not self.track_running_stats,
--> 107 exponential_average_factor, self.eps)
108
109
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
1668 return torch.batch_norm(
1669 input, weight, bias, running_mean, running_var,
-> 1670 training, momentum, eps, torch.backends.cudnn.enabled
1671 )
1672
RuntimeError: "batch_norm" not implemented for 'Long'
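The 'Long' in the message is PyTorch's name for torch.int64, which is the dtype torch.randint returns when no dtype argument is passed. A quick check, reusing the question's tensor shape and value range:

```python
import torch

# torch.randint returns torch.int64 (shown as 'Long' in error messages)
# unless an explicit dtype is given
test_int_input = torch.randint(low=1, high=9, size=(3, 5))

print(test_int_input.dtype)                # torch.int64
print(test_int_input.is_floating_point())  # False
```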
However, if I apply the same module to a non-integer (floating-point) tensor, it works. Here is an example:
import torch
import torch.nn as nn
# Float type tensor
test_input = torch.randn(size=[3, 5])
# BatchNorm1D object
batchnorm1D = nn.BatchNorm1d(num_features=5)
test_output = batchnorm1D(test_input)
test_output
Output:
tensor([[ 0.4311, -1.1987, 0.9059, 1.1424, 1.2174],
[-1.3820, 1.2492, -1.3934, 0.1508, 0.0146],
[ 0.9509, -0.0505, 0.4875, -1.2931, -1.2320]],
grad_fn=<NativeBatchNormBackward>)
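Your input tensor needs to be floating point. Batch normalization computes a per-feature mean and variance and applies a learnable floating-point weight and bias, and the batch_norm kernel is not implemented for integer dtypes such as 'Long' (torch.int64). Convert the input with .float():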
>>> batchnorm1D(test_int_input.float())
tensor([[-5.9605e-08, -1.3887e+00, -9.8058e-01, 2.6726e-01, 1.4142e+00],
[-1.2247e+00, 4.6291e-01, 1.3728e+00, -1.3363e+00, -7.0711e-01],
[ 1.2247e+00, 9.2582e-01, -3.9223e-01, 1.0690e+00, -7.0711e-01]],
grad_fn=<NativeBatchNormBackward0>)
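A minimal sketch of two ways to get a floating-point input, plus a check of what the module actually computes. It assumes a freshly constructed module, which is in training mode with weight = 1 and bias = 0; the names float_input, cast_input and manual are illustrative only. In training mode, BatchNorm1d normalizes each of the 5 features (columns) with the batch mean and the biased batch variance:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # for reproducibility only

# A freshly constructed module is in training mode (weight=1, bias=0)
batchnorm1D = nn.BatchNorm1d(num_features=5)

# Option 1: create the tensor as floating point from the start
float_input = torch.randint(low=1, high=9, size=(3, 5), dtype=torch.float32)
out_float = batchnorm1D(float_input)

# Option 2: cast an existing integer tensor
int_input = torch.randint(low=1, high=9, size=(3, 5))
cast_input = int_input.float()  # equivalent to int_input.to(torch.float32)
out = batchnorm1D(cast_input)

# Reproduce the result by hand: per-column mean and biased variance
mean = cast_input.mean(dim=0)
var = cast_input.var(dim=0, unbiased=False)
manual = (cast_input - mean) / torch.sqrt(var + batchnorm1D.eps)

print(torch.allclose(out, manual, atol=1e-5))  # True
```

Passing dtype=torch.float32 to torch.randint avoids the cast entirely; .float() is the simplest fix when the integer tensor already exists.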