Using shift() on MNIST data. Getting strange results
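I'm trying to use the shift() function on MNIST images.
However, when I compare the original and the shifted data, values that should be exactly zero after the shift come out as very small non-zero numbers instead. For example, a value that was zero before the shift is something like ##########e-18 afterwards. As a result, all the other values now show up as something like ##########e+02.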
Here is the code I'm running:
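from sklearn.datasets import fetch_openml
from scipy.ndimage import shift  # the scipy.ndimage.interpolation namespace is deprecated
import numpy as np
# as_frame=False returns NumPy arrays; recent scikit-learn returns a pandas
# DataFrame by default, which breaks the row indexing and .reshape() below
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
x, y = mnist['data'], mnist['target']
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]
image = x_train[99]
reshaped = image.reshape(28, 28)     # flat 784-pixel vector -> 28x28 image
reshaped_2 = reshaped.reshape(784,)  # back to a flat vector (not used below)
print(reshaped[7:10, :])
# Shift down by one row: rows 8:11 of the result should match rows 7:10 of the input
print(shift(reshaped, [1, 0], cval=0)[8:11, :])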
Here is the output:
[[ 0. 0. 0. 0. 0. 0. 0. 0. 32. 109. 109. 110. 109. 109.
109. 255. 253. 253. 253. 255. 211. 109. 47. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 32. 73. 73. 155. 217. 227. 252. 252. 253. 252. 252.
252. 253. 252. 252. 252. 253. 252. 252. 108. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 109. 252. 252. 252. 236. 226. 252. 231. 217. 215. 195.
71. 72. 71. 71. 154. 253. 252. 252. 108. 0. 0. 0. 0. 0.]]
[[-1.45736740e-17 2.08908499e-18 1.97425281e-17 1.32870826e-14
2.88143171e-14 2.90612090e-14 2.63726515e-14 2.89883698e-14
3.20000000e+01 1.09000000e+02 1.09000000e+02 1.10000000e+02
1.09000000e+02 1.09000000e+02 1.09000000e+02 2.55000000e+02
2.53000000e+02 2.53000000e+02 2.53000000e+02 2.55000000e+02
2.11000000e+02 1.09000000e+02 4.70000000e+01 8.06113136e-16
-1.58946559e-16 -9.39990682e-17 2.66688532e-17 -5.77791548e-17]
[-5.61019971e-16 2.32169340e-15 7.43877530e-15 3.20000000e+01
7.30000000e+01 7.30000000e+01 1.55000000e+02 2.17000000e+02
2.27000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 1.08000000e+02 3.29017268e-16
-6.57046610e-16 -1.22504799e-16 2.64344390e-17 -1.25480283e-16]
[-2.16877621e-15 7.92064171e-15 2.39544414e-14 1.09000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.36000000e+02
2.26000000e+02 2.52000000e+02 2.31000000e+02 2.17000000e+02
2.15000000e+02 1.95000000e+02 7.10000000e+01 7.20000000e+01
7.10000000e+01 7.10000000e+01 1.54000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 1.08000000e+02 3.04124747e-15
3.67217141e-17 -2.67076835e-16 -1.16801314e-16 -1.39584861e-16]]
What is causing this behavior? Is it a quirk of the MNIST dataset, or is there a mistake in my code?
The answer to a related question addresses how to perform the shift more efficiently, but it doesn't answer my other questions.
According to the shift documentation (emphasis mine):
The array is shifted using spline interpolation of the requested order
and
order : int, optional
The order of the spline interpolation, default is 3. The order has to be in the range 0-5.
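To see what the `order` parameter does without downloading MNIST, here is a minimal sketch on a hypothetical 28×28 test image (`img`, zero everywhere except a block of random integer pixel values; it is not MNIST data). It shifts the image down by one row with each allowed spline order and prints the largest deviation from the exact one-row shift. Because the shift is a whole pixel, the spline is evaluated exactly at the original grid points, so in exact arithmetic every order would reproduce the input. For `order > 1`, however, SciPy first runs a spline prefilter in float64, and the leftover deviations should be floating-point round-off of the same kind as the `e-14` to `e-18` values in the question.

```python
import numpy as np
from scipy.ndimage import shift

# Hypothetical stand-in for an MNIST digit: zeros with an integer-valued block.
rng = np.random.default_rng(0)
img = np.zeros((28, 28))
img[8:20, 10:18] = rng.integers(0, 256, size=(12, 8))

# Exact result of moving the image down one row and filling the top row with 0.
expected = np.zeros_like(img)
expected[1:] = img[:-1]

for order in range(6):  # the documented range is 0-5
    out = shift(img, [1, 0], cval=0, order=order)
    err = np.abs(out - expected).max()
    # order=0 (nearest neighbour) should report exactly 0; higher orders
    # may show round-off-sized deviations
    print(f"order={order}: max deviation from exact shift = {err:.1e}")
```

When comparing images shifted with the default `order=3`, a tolerance-based check such as `np.allclose` is more appropriate than `==`.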
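I won't pretend to know exactly how this interpolation works, but it clearly affects the shifted values; so I figured that setting order=0 (a zero-order spline, i.e. nearest-neighbour sampling) should disable this interpolation, and indeed it does. Make the following changes in your code: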
np.random.seed(42) # for reproducibility
# rest of your code as-is
print(reshaped[7:10,:])
print(shift(reshaped, [1,0], cval=0, order=0)[8:11,:]) # order=0
The results are indeed identical (no interpolation happens during the shift):
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 168. 253.
200. 8. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 16. 235. 253.
80. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 65. 254. 169.
23. 0. 0. 0. 10. 14. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 168. 253.
200. 8. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 16. 235. 253.
80. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 65. 254. 169.
23. 0. 0. 0. 10. 14. 0. 0. 0. 0. 0. 0. 0. 0.]]
and
np.all(reshaped[7:10,:] == shift(reshaped, [1,0], cval=0, order=0)[8:11,:])
# True
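If you do want to keep the default cubic spline (for example, for sub-pixel shifts), the residuals are harmless, and there are two simple ways to deal with them. NumPy prints a whole array in scientific notation when its nonzero magnitudes span a very wide range (for example, when some are below 1e-4), which is why ordinary pixel values such as 109 appeared as `1.09000000e+02`; `np.printoptions(suppress=True)` switches back to fixed-point notation and shows the tiny residuals as zero. For whole-pixel shifts of an integer-valued image, `np.rint` removes the residuals from the data altogether. A minimal sketch, using a hypothetical test image rather than MNIST:

```python
import numpy as np
from scipy.ndimage import shift

# Hypothetical integer-valued test image (not MNIST): a bright square on zeros.
img = np.zeros((28, 28))
img[10:18, 10:18] = 200.0

# Whole-pixel shift down by one row with the default cubic spline (order=3).
shifted = shift(img, [1, 0], cval=0)

# Option 1 (display only): fixed-point printing, tiny residuals print as 0.
with np.printoptions(suppress=True):
    print(shifted[9:12, 8:20])

# Option 2 (fix the data): the residuals are far smaller than 0.5, so
# rounding to the nearest integer recovers the exact shifted pixels.
cleaned = np.rint(shifted)
print(np.array_equal(cleaned[1:], img[:-1]))  # True
```

Rounding only makes sense for whole-pixel shifts; after a sub-pixel shift the non-integer values are the actual interpolated result, and cubic splines can overshoot slightly below 0 or above 255 near sharp edges, so `np.clip` may be useful there instead.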