对MNIST数据使用shift().得到奇怪的结果



我正在尝试在MNIST图像上使用shift()函数。

然而,当我查看原始数据和移位的数据时,看起来移位的值正好是零,而不是零,变成了非常小的非零值。举个例子,在移动之前,这个值是0,而在移动之后,这个值就变成了##########e-18。因此,所有其他的值都变成了像##########e+02这样的东西。

这是我正在运行的代码。

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')
x, y = mnist['data'], mnist['target']
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
import numpy as np
shuffle_index = np.random.permutation(60000)
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]
image = x_train[99]
reshaped = image.reshape(28,28)
reshaped_2 = reshaped.reshape(784,)
from scipy.ndimage.interpolation import shift
print(reshaped[7:10,:])
print(shift(reshaped, [1,0], cval=0)[8:11,:])

这里是输出

[[  0.   0.   0.   0.   0.   0.   0.   0.  32. 109. 109. 110. 109. 109.
109. 255. 253. 253. 253. 255. 211. 109.  47.   0.   0.   0.   0.   0.]
[  0.   0.   0.  32.  73.  73. 155. 217. 227. 252. 252. 253. 252. 252.
252. 253. 252. 252. 252. 253. 252. 252. 108.   0.   0.   0.   0.   0.]
[  0.   0.   0. 109. 252. 252. 252. 236. 226. 252. 231. 217. 215. 195.
71.  72.  71.  71. 154. 253. 252. 252. 108.   0.   0.   0.   0.   0.]]
[[-1.45736740e-17  2.08908499e-18  1.97425281e-17  1.32870826e-14
2.88143171e-14  2.90612090e-14  2.63726515e-14  2.89883698e-14
3.20000000e+01  1.09000000e+02  1.09000000e+02  1.10000000e+02
1.09000000e+02  1.09000000e+02  1.09000000e+02  2.55000000e+02
2.53000000e+02  2.53000000e+02  2.53000000e+02  2.55000000e+02
2.11000000e+02  1.09000000e+02  4.70000000e+01  8.06113136e-16
-1.58946559e-16 -9.39990682e-17  2.66688532e-17 -5.77791548e-17]
[-5.61019971e-16  2.32169340e-15  7.43877530e-15  3.20000000e+01
7.30000000e+01  7.30000000e+01  1.55000000e+02  2.17000000e+02
2.27000000e+02  2.52000000e+02  2.52000000e+02  2.53000000e+02
2.52000000e+02  2.52000000e+02  2.52000000e+02  2.53000000e+02
2.52000000e+02  2.52000000e+02  2.52000000e+02  2.53000000e+02
2.52000000e+02  2.52000000e+02  1.08000000e+02  3.29017268e-16
-6.57046610e-16 -1.22504799e-16  2.64344390e-17 -1.25480283e-16]
[-2.16877621e-15  7.92064171e-15  2.39544414e-14  1.09000000e+02
2.52000000e+02  2.52000000e+02  2.52000000e+02  2.36000000e+02
2.26000000e+02  2.52000000e+02  2.31000000e+02  2.17000000e+02
2.15000000e+02  1.95000000e+02  7.10000000e+01  7.20000000e+01
7.10000000e+01  7.10000000e+01  1.54000000e+02  2.53000000e+02
2.52000000e+02  2.52000000e+02  1.08000000e+02  3.04124747e-15
3.67217141e-17 -2.67076835e-16 -1.16801314e-16 -1.39584861e-16]]

是什么导致了这种行为?这是MNIST数据集的特性吗?是我的代码出错了吗?

的答案是有可能使用矢量方法来移动存储在numpy数组中的图像进行数据增强?解决了如何更有效地进行移位操作,但它没有回答我的其他问题。

根据shift文档(重点是我的):

数组移位使用样条插值所请求的订单

order:int,可选

样条插值的阶数,默认为3。顺序必须在0-5范围内。

我不会假装知道这个插值是如何发生的,但它肯定会影响移位的值;所以,我发现设置order=0会禁用这个插值,确实如此。在您的代码中进行以下更改:

np.random.seed(42) # for reproducibility
# rest of your code as-is
print(reshaped[7:10,:])
print(shift(reshaped, [1,0], cval=0, order=0)[8:11,:])  # order=0

结果确实是相同的(在移动过程中没有发生插值):

[[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0. 168. 253.
200.   8.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  16. 235. 253.
80.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  65. 254. 169.
23.   0.   0.   0.  10.  14.   0.   0.   0.   0.   0.   0.   0.   0.]]
[[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0. 168. 253.
200.   8.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  16. 235. 253.
80.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  65. 254. 169.
23.   0.   0.   0.  10.  14.   0.   0.   0.   0.   0.   0.   0.   0.]]

np.all(reshaped[7:10,:] == shift(reshaped, [1,0], cval=0, order=0)[8:11,:])
# True

相关内容

最新更新