图像未使用 DBSCAN 正确分割

Image not segmenting properly using DBSCAN

我正在尝试使用 scikitlearn 中的 DBSCAN 根据颜色分割图像。我得到的结果是 。如您所见,有 3 个集群。我的目标是将图片中的浮标分成不同的簇。但很明显,它们显示为同一个集群。我已经尝试了各种 eps 值和 min_samples 但这两个东西总是聚集在一起。我的代码是:

img= cv2.imread("buoy1.jpg) 
labimg = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

n = 0
while(n<4):
    labimg = cv2.pyrDown(labimg)
    n = n+1

feature_image=np.reshape(labimg, [-1, 3])
rows, cols, chs = labimg.shape

db = DBSCAN(eps=5, min_samples=50, metric = 'euclidean',algorithm ='auto')
db.fit(feature_image)
labels = db.labels_

plt.figure(2)
plt.subplot(2, 1, 1)
plt.imshow(img)
plt.axis('off')
plt.subplot(2, 1, 2)
plt.imshow(np.reshape(labels, [rows, cols]))
plt.axis('off')
plt.show()

我假设这是采用欧几里得距离,因为它在实验室 space 欧几里得距离在不同颜色之间会有所不同。如果有人能给我指导,我将不胜感激。

更新: 以下答案有效。由于 DBSCAN 需要一个不超过 2 维的数组,我将列连接到原始图像并重新整形以生成 n x 5 矩阵,其中 n 是 x 维乘以 y 维。这似乎对我有用。

indices = np.dstack(np.indices(img.shape[:2]))
xycolors = np.concatenate((img, indices), axis=-1) 
np.reshape(xycolors, [-1,5])

您需要同时使用颜色和位置

目前,您只使用了颜色。

Could you please add the enitre code in the answer? Im not able to understand where do I add the those 3 lines which have worked for you – user8306074 Sep 4 at 8:58

我来为您解答,完整版代码如下:

import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

img= cv2.imread('your image') 
labimg = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

n = 0
while(n<4):
    labimg = cv2.pyrDown(labimg)
    n = n+1

feature_image=np.reshape(labimg, [-1, 3])
rows, cols, chs = labimg.shape

db = DBSCAN(eps=5, min_samples=50, metric = 'euclidean',algorithm ='auto')
db.fit(feature_image)
labels = db.labels_

indices = np.dstack(np.indices(labimg.shape[:2]))
xycolors = np.concatenate((labimg, indices), axis=-1) 
feature_image2 = np.reshape(xycolors, [-1,5])
db.fit(feature_image2)
labels2 = db.labels_

plt.figure(2)
plt.subplot(2, 1, 1)
plt.imshow(img)
plt.axis('off')

# plt.subplot(2, 1, 2)
# plt.imshow(np.reshape(labels, [rows, cols]))
# plt.axis('off')

plt.subplot(2, 1, 2)
plt.imshow(np.reshape(labels2, [rows, cols]))
plt.axis('off')
plt.show()