在 python ndarray 中查找重复行的索引
Find indices of duplicated rows in python ndarray
我编写了 for 循环以枚举包含 n 行 28x28 像素值的多维 ndarray。
我正在寻找重复的每一行的索引以及没有冗余的重复项的索引。
我找到了这段代码 here(感谢 unutbu)并对其进行了修改以读取 ndarray,它在 70% 的时间内有效,但在 30% 的时间内将错误的图像识别为重复图像。
如何改进以检测正确的行?
def overlap_same(arr):
seen = []
dups = collections.defaultdict(list)
for i, item in enumerate(arr):
for j, orig in enumerate(seen):
if np.array_equal(item, orig):
dups[j].append(i)
break
else:
seen.append(item)
return dups
例如return overlap_same(火车) returns:
defaultdict(<type 'list'>, {34: [1388], 35: [1815], 583: [3045], 3208:
[4426], 626: [824], 507: [4438], 188: [338, 431, 540, 757, 765, 806,
808, 834, 882, 1515, 1539, 1715, 1725, 1789, 1841, 2038, 2081, 2165,
2170, 2300, 2455, 2683, 2733, 2957, 3290, 3293, 3311, 3373, 3446, 3542,
3565, 3890, 4110, 4197, 4206, 4364, 4371, 4734, 4851]})
在 matplotlib 上绘制一些正确案例的样本给出:
fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[35])
a.set_title('train[35]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[1815])
a.set_title('train[1815]')
plt.show
哪个是正确的
但是:
fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[3208])
a.set_title('train[3208]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[4426])
a.set_title('train[4426]')
plt.show
不正确,因为它们不匹配
样本数据(火车[:3])
array([[[-0.5 , -0.5 , -0.5 , ..., 0.48823529,
0.5 , 0.17058824],
[-0.5 , -0.5 , -0.5 , ..., 0.48823529,
0.5 , -0.0372549 ],
[-0.5 , -0.5 , -0.5 , ..., 0.5 ,
0.47647059, -0.24509804],
...,
[-0.49215686, 0.34705883, 0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.31176472, 0.44901961, 0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.11176471, 0.5 , 0.49215686, ..., -0.5 ,
-0.5 , -0.5 ]],
[[-0.24509804, 0.2764706 , 0.5 , ..., 0.5 ,
0.25294119, -0.36666667],
[-0.5 , -0.47254902, -0.02941176, ..., 0.20196079,
-0.46862745, -0.5 ],
[-0.49215686, -0.5 , -0.5 , ..., -0.47647059,
-0.5 , -0.49607843],
...,
[-0.49215686, -0.49607843, -0.5 , ..., -0.5 ,
-0.5 , -0.49215686],
[-0.5 , -0.5 , -0.26862746, ..., 0.13137256,
-0.46470588, -0.5 ],
[-0.30000001, 0.11960784, 0.48823529, ..., 0.5 ,
0.28431374, -0.24117647]],
[[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
...,
[-0.5 , -0.5 , -0.5 , ..., 0.48431373,
0.5 , 0.31568629],
[-0.5 , -0.49215686, -0.5 , ..., 0.49215686,
0.5 , 0.04901961],
[-0.5 , -0.5 , -0.5 , ..., 0.04117647,
-0.17450981, -0.45686275]]], dtype=float32)
numpy_indexed 包有很多功能可以有效地解决这些类型的问题。
例如,(不同于 numpy 的内置独特)这将找到您的独特图像:
import numpy_indexed as npi
unique_training_images = npi.unique(train)
或者如果你想找到每个唯一组的所有索引,你可以使用:
indices = npi.group_by(train).split(np.arange(len(train)))
请注意,这些函数不像您原来的 post 那样具有二次时间复杂度,并且是完全矢量化的,因此很可能效率更高。此外,与 pandas 不同,它没有首选数据格式,并且完全支持 nd 数组,因此作用于形状为 [n_images, 28, 28] 'just works'.[= 的数组13=]
我编写了 for 循环以枚举包含 n 行 28x28 像素值的多维 ndarray。
我正在寻找重复的每一行的索引以及没有冗余的重复项的索引。
我找到了这段代码 here(感谢 unutbu)并对其进行了修改以读取 ndarray,它在 70% 的时间内有效,但在 30% 的时间内将错误的图像识别为重复图像。
如何改进以检测正确的行?
def overlap_same(arr):
seen = []
dups = collections.defaultdict(list)
for i, item in enumerate(arr):
for j, orig in enumerate(seen):
if np.array_equal(item, orig):
dups[j].append(i)
break
else:
seen.append(item)
return dups
例如return overlap_same(火车) returns:
defaultdict(<type 'list'>, {34: [1388], 35: [1815], 583: [3045], 3208:
[4426], 626: [824], 507: [4438], 188: [338, 431, 540, 757, 765, 806,
808, 834, 882, 1515, 1539, 1715, 1725, 1789, 1841, 2038, 2081, 2165,
2170, 2300, 2455, 2683, 2733, 2957, 3290, 3293, 3311, 3373, 3446, 3542,
3565, 3890, 4110, 4197, 4206, 4364, 4371, 4734, 4851]})
在 matplotlib 上绘制一些正确案例的样本给出:
fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[35])
a.set_title('train[35]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[1815])
a.set_title('train[1815]')
plt.show
哪个是正确的
但是:
fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[3208])
a.set_title('train[3208]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[4426])
a.set_title('train[4426]')
plt.show
不正确,因为它们不匹配
样本数据(火车[:3])
array([[[-0.5 , -0.5 , -0.5 , ..., 0.48823529,
0.5 , 0.17058824],
[-0.5 , -0.5 , -0.5 , ..., 0.48823529,
0.5 , -0.0372549 ],
[-0.5 , -0.5 , -0.5 , ..., 0.5 ,
0.47647059, -0.24509804],
...,
[-0.49215686, 0.34705883, 0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.31176472, 0.44901961, 0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.11176471, 0.5 , 0.49215686, ..., -0.5 ,
-0.5 , -0.5 ]],
[[-0.24509804, 0.2764706 , 0.5 , ..., 0.5 ,
0.25294119, -0.36666667],
[-0.5 , -0.47254902, -0.02941176, ..., 0.20196079,
-0.46862745, -0.5 ],
[-0.49215686, -0.5 , -0.5 , ..., -0.47647059,
-0.5 , -0.49607843],
...,
[-0.49215686, -0.49607843, -0.5 , ..., -0.5 ,
-0.5 , -0.49215686],
[-0.5 , -0.5 , -0.26862746, ..., 0.13137256,
-0.46470588, -0.5 ],
[-0.30000001, 0.11960784, 0.48823529, ..., 0.5 ,
0.28431374, -0.24117647]],
[[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
[-0.5 , -0.5 , -0.5 , ..., -0.5 ,
-0.5 , -0.5 ],
...,
[-0.5 , -0.5 , -0.5 , ..., 0.48431373,
0.5 , 0.31568629],
[-0.5 , -0.49215686, -0.5 , ..., 0.49215686,
0.5 , 0.04901961],
[-0.5 , -0.5 , -0.5 , ..., 0.04117647,
-0.17450981, -0.45686275]]], dtype=float32)
numpy_indexed 包有很多功能可以有效地解决这些类型的问题。
例如,(不同于 numpy 的内置独特)这将找到您的独特图像:
import numpy_indexed as npi
unique_training_images = npi.unique(train)
或者如果你想找到每个唯一组的所有索引,你可以使用:
indices = npi.group_by(train).split(np.arange(len(train)))
请注意,这些函数不像您原来的 post 那样具有二次时间复杂度,并且是完全矢量化的,因此很可能效率更高。此外,与 pandas 不同,它没有首选数据格式,并且完全支持 nd 数组,因此作用于形状为 [n_images, 28, 28] 'just works'.[= 的数组13=]