无法反转重塑的 numpy 数组

can't reverse reshaped numpy array

我想通过在数组上再次调用 reshape 将其重塑为原始尺寸来反转重塑的 numpy。

我有一个维度为 (x, y, z) 的数组 trian_x 然后我重塑 train_x

train_X_1 = train_X.reshape(train_X.shape[0], train_X.shape[1] * train_X.shape[2])

然后我想反转重塑的

train_X_2 = train_X_1.reshape((train_X.shape[0], train_X.shape[1], train_X.shape[2])

当我比较时

print((train_X_2 == train_X).all())

我得到 错误

我的代码有什么问题?谢谢

听起来你想展平,然后反转,然后重塑。

从数组开始:

import numpy as np 
arr = np.arange(6).reshape((2,3)) #[[0, 1, 2,], [3, 4, 5]]

我们可以使用 ravel

展平为一维数组
arr = arr.ravel() #[0,1,2,3,4,5]

然后我们可以颠倒顺序

arr = arr[::-1] #[5,4,3,2,1,0]

然后我们重塑它

arr.reshape(2,3) #[[5, 4, 3], [2, 1, 0]]

一共:

import numpy as np 
arr = np.arange(6).reshape((2,3))
arr = arr.ravel()[::-1].reshape(2,3)
print(arr)

你只是想试试这个吗:

In [184]: x = np.arange(24).reshape(2,3,4)                                                             
In [185]: x1 = x.reshape(2,12)                                                                         
In [186]: x2 = x1.reshape(2,3,4)                                                                       
In [187]: np.allclose(x,x2)                                                                            
Out[187]: True

你的 dtype 是什么? allclose 更适合花车。

In [218]: data = np.load('../Downloads/train_X.npy')                                                   
In [219]: data.shape                                                                                   
Out[219]: (97848, 20, 2)
In [220]: data.dtype                                                                                   
Out[220]: dtype('float64')
In [221]: data1 = data.reshape(data.shape[0], data.shape[1]*data.shape[2])                             
In [222]: data1.shape                                                                                  
Out[222]: (97848, 40)
In [223]: data2 = data1.reshape(data.shape)                                                            
In [224]: data2.shape                                                                                  
Out[224]: (97848, 20, 2)
In [225]: np.allclose(data, data2)                                                                     
Out[225]: False
In [226]: np.max(np.abs(data - data2))                                                                 
Out[226]: nan

In [247]: np.isnan(data).sum()                                                                         
Out[247]: 2514
In [248]: np.isnan(data2).sum()                                                                        
Out[248]: 2514

这是你的问题 - 数组包含 nan,不测试 ==。让我们在没有那些 nan:

的情况下进行比较
In [251]: np.allclose(np.nan_to_num(data),np.nan_to_num(data2))                                        
Out[251]: True