In a Python NumPy array, how can I tell which object in an image is nearer to a given point?

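I have a NumPy array that represents an image. The image has 3 colors: orange (background), blue (object 1) and green (object 2). I use 3 values (0, 1 and 2) to represent the 3 colors in the NumPy array. The two objects do not overlap.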

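My question is: how can I tell which object is nearer to the center of the image (the red dot)? (Here, "nearer" means that the smallest distance from one object to the image center is less than the smallest distance from the other object to the image center.)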

My code is as follows:

import numpy as np
from scipy import spatial
import time

sub_image1 = np.ones((30, 30, 30))
sub_image2 = np.ones((20, 10, 15))

# pad the two sub_images to the same shape (1200, 1200, 1200) to simulate my 3D medical data
img_1 = np.pad(sub_image1, ((1100, 70), (1100, 70), (1100, 70)))
# use sub_image2 here: its (20, 10, 15) shape plus this padding gives (1200, 1200, 1200)
img_2 = np.pad(sub_image2, ((1100, 80), (1130, 60), (1170, 15)))

def nerest_dis_to_center(img):
    position = np.where(img > 0)
    coordinates = np.transpose(np.array(position))  # coordinates of the non-zero voxels
    cposition = np.array(img.shape) / 2  # center point position/coordinate
    distance, index = spatial.KDTree(coordinates).query(cposition)
    return distance

t1 = time.time()
d1 = nerest_dis_to_center(img_1)
d2 = nerest_dis_to_center(img_2)

if d1 > d2:
    print("img2 object is nearer")
elif d2 > d1:
    print("img1 object is nearer")
else:
    print("Both objects are equally far")
t2 = time.time()
print("used time: ", t2-t1)
# 30 seconds

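The code above works, but it is slow and needs a lot of memory (about 30 GB). If you want to reproduce it on your own machine, you can use a smaller shape than (1200, 1200, 1200). Is there a more efficient way to achieve my goal?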

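Note: my images are actually 3D CT medical images, which are too large to upload. The objects in the images are arbitrary and may or may not be convex, which is why my implementation is slow. To make the question clear, I use a 2D image here to illustrate my method.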

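This may not be the final solution, or the best one with respect to time; it has to be tested with real data. To get my idea across, I chose a smaller matrix size and only the 2D case.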

import numpy as np
import matplotlib.pyplot as plt


sub_image1 = np.ones((30, 30))  # 1st object
sub_image2 = np.ones((20, 10)) * 2  # 2nd object

# pad the two sub_images to the same shape (200, 200)
img_1 = np.pad(sub_image1, ((110, 60), (60, 110)))
img_2 = np.pad(sub_image2, ((100, 80), (130, 60)))

final_image = img_1 + img_2  # creating final image with both objects in a background of zeros

image_center = (np.array([final_image.shape[0], final_image.shape[1]]) / 2).astype(int)  # np.int was removed in NumPy 1.24

# mark the center
final_image[image_center[0], image_center[1]] = 10

# find the coordinates of where the objects are
first_obj_coords = np.argwhere(final_image == 1)   # could be the most time-consuming operation
second_obj_coords = np.argwhere(final_image == 2)  # could be the most time-consuming operation

# find their centers
first_obj_ctr = np.mean(first_obj_coords, axis=0)
second_obj_ctr = np.mean(second_obj_coords, axis=0)

# turn the centers to int for using them to index
first_obj_ctr = np.floor(first_obj_ctr).astype(int)
second_obj_ctr = np.floor(second_obj_ctr).astype(int)

# mark the centers of the objects
final_image[first_obj_ctr[0], first_obj_ctr[1]] = 10
final_image[second_obj_ctr[0], second_obj_ctr[1]] = 10

# calculate the distances from center to the object center
print('Distance to first object: ', np.linalg.norm(image_center - first_obj_ctr))
print('Distance to second object: ', np.linalg.norm(image_center - second_obj_ctr))

plt.imshow(final_image)
plt.show()

Output:

Distance to first object:  35.38361202590826
Distance to second object:  35.17101079013795
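
Note that the distances above are measured to each object's centroid, while the question defines "nearer" by the closest point of each object. The two definitions can rank the objects differently. The following is only a minimal sketch of the closest-point variant on the same 2D layout; the helper name `nearest_distance_to_point` is made up for illustration, and it assumes the labels 1 and 2 mark the two objects as in the question.

```python
import numpy as np

def nearest_distance_to_point(label_image, label, point):
    """Smallest Euclidean distance from `point` to any pixel equal to `label`."""
    coords = np.argwhere(label_image == label)  # (n_pixels, ndim) integer coordinates
    if coords.size == 0:
        return np.inf  # label not present in the image
    diffs = coords - np.asarray(point, dtype=float)
    return np.sqrt((diffs ** 2).sum(axis=1).min())

# Same 2D layout as above: a (200, 200) image with labels 1 and 2
image = np.zeros((200, 200), dtype=np.uint8)
image[110:140, 60:90] = 1    # first object
image[100:120, 130:140] = 2  # second object

center = np.array(image.shape) / 2  # (100.0, 100.0)
d1 = nearest_distance_to_point(image, 1, center)
d2 = nearest_distance_to_point(image, 2, center)
print("Nearest distance to first object: ", d1)
print("Nearest distance to second object:", d2)
```

With this layout the first object comes out nearer by closest point (about 14.87 versus 30.0), whereas the centroid distances above favor the second object. For a single query point, taking the minimum over the object's coordinates directly also avoids building a KD-tree.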

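I solved this problem.

The two 3D arrays are too large, so I first downsampled them to a smaller size with nearest-neighbor interpolation and then continued as before: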

import numpy as np
from scipy import ndimage, spatial
import time

sub_image1 = np.ones((30, 30, 30))
sub_image2 = np.ones((20, 10, 15))

# pad the two sub_images to the same shape (1200, 1200, 1200) to simulate my 3D medical data
img_1 = np.pad(sub_image1, ((1100, 70), (1100, 70), (1100, 70)))
# use sub_image2 here: its (20, 10, 15) shape plus this padding gives (1200, 1200, 1200)
img_2 = np.pad(sub_image2, ((1100, 80), (1130, 60), (1170, 15)))

t0 = time.time()  # start timing the downsampling step
ori_sz = np.array(img_1.shape)
trgt_sz = ori_sz / 4
zoom_seq = np.array(trgt_sz, dtype='float') / np.array(ori_sz, dtype='float')  # 0.25 per axis
# order=0 is nearest-neighbor; ndimage.zoom replaces the deprecated ndimage.interpolation.zoom
img_1 = ndimage.zoom(img_1, zoom_seq, order=0, prefilter=False)
img_2 = ndimage.zoom(img_2, zoom_seq, order=0, prefilter=False)
print("Time in seconds to downsample both images: " + str(time.time() - t0))  # 0.8 seconds


def nerest_dis_to_center(img):
    position = np.where(img > 0)
    coordinates = np.transpose(np.array(position))  # coordinates of the non-zero voxels
    cposition = np.array(img.shape) / 2  # center point position/coordinate
    distance, index = spatial.KDTree(coordinates).query(cposition)
    return distance

t1 = time.time()
d1 = nerest_dis_to_center(img_1)
d2 = nerest_dis_to_center(img_2)

if d1 > d2:
    print("img2 object is nearer")
elif d2 > d1:
    print("img1 object is nearer")
else:
    print("Both objects are equally far")
t2 = time.time()
print("used time: ", t2-t1)
# 1.1 seconds
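
A possible variation on the same idea, offered only as a sketch: instead of `ndimage.zoom`, take every 4th voxel with strided slicing, which gives a view of the original array rather than a resampled copy. This is a different technique from the interpolation used above, and the function name `approx_nearest_distance` and its `step` parameter are made up for illustration. The coordinates are scaled back by `step`, so the result is in original voxel units; it is still approximate, and an object smaller than the step can be missed entirely.

```python
import numpy as np

def approx_nearest_distance(img, step=4):
    """Approximate nearest distance from the image center to a non-zero voxel,
    in original voxel units, using only every `step`-th voxel along each axis."""
    small = img[(slice(None, None, step),) * img.ndim]  # strided view, no copy
    coords = np.argwhere(small > 0) * step  # map back to original coordinates
    if coords.size == 0:
        return np.inf  # object missed by the subsampling, or no object at all
    center = np.array(img.shape) / 2
    return np.sqrt(((coords - center) ** 2).sum(axis=1).min())

# Small stand-in volume (120^3, uint8) so the sketch runs on modest hardware
img = np.zeros((120, 120, 120), dtype=np.uint8)
img[80:110, 80:110, 80:110] = 1
print(approx_nearest_distance(img, step=4))
```

Storing the label volume as `uint8` rather than the `float64` produced by `np.ones` also cuts memory by a factor of 8, which may matter more than the search method at the (1200, 1200, 1200) size.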