在输入和输出整数数组的神经网络中,我应该使用什么 dtype 来定义 PyTorch 参数?



我目前正在 PyTorch 中构建一个神经网络,它接受整数张量并输出整数量。 只有少数正整数被"允许">(如 0、1、2、3 和 4(作为输入和输出张量的元素。

神经网络通常在连续空间中工作。 例如,层之间的非线性激活函数是连续的,并将整数映射到实数(包括非整数(。

是否最好在内部使用无符号整数(如torch.uint8(来处理网络的权重和偏差,以及一些将 int 映射到 int 的自定义激活函数?

还是我应该使用像torch.float32这样的高精度浮点数,然后通过将实数装箱到最接近的整数来最后进行舍入? 我认为第二种策略是要走的路,但也许我错过了一些可以很好地工作的东西。

在不太了解您的应用程序的情况下,我会选择四舍五入torch.float32。主要原因是,如果您使用 GPU 来计算神经网络,它将需要float32数据类型和数据。如果你不打算训练你的神经网络,你想在CPU上运行,那么像torch.uint8这样的数据类型可能会帮助你,因为你可以在每个时间间隔内实现更多的指令(即你的应用程序应该运行得更快(。如果这没有给您留下线索,那么请更具体地说明您的应用程序。

相关内容

最新更新