Detectron2中用于对象检测的输入图像的类型



我正在使用Detectron2来训练更快的R-CNN模型来进行对象检测,我想训练模型动物园给出的模型,输入范围为[01],而不是[0255],所以我使用了一个颜色变换,它调用我的函数scale_transform

def scale_transform(img):
return img/255.

此函数接收一个numpy数组并按比例返回。但是,在列车时刻,这个错误出现在

RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same

有人知道我该怎么解决这个问题吗?或者用另一种方法缩放探测器2的图像?

感谢

我认为这里的相关单词是类型

也许要确保输入被定义为浮点值。尽管它在正确的范围(0-1(内,但可能会发现数据类型不正确,因此会在那里绊倒。

以下可能对它-

def scale_transform(img):
img = img/255
img = img.astype(np.float32)
return img

最新更新