sklearn KMeans Clustering - which time series is in which cluster?

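I am using sklearn's KMeans clustering and everything works; I just want to know which time series is in which cluster. Does anyone have experience with this? For example, my clusters are shown in the attached picture, and I want to know which time series is in which cluster (I have 143 time series). My time series are stored in the list mySeries_2019_Jan, which holds 143 np.arrays, so its elements look like this: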

mySeries_2019_Jan[0]
Out[119]: 
array([0.14117647, 0.13936652, 0.14298643, 0.14570136, 0.14298643,
       0.14751131, 0.15475113, 0.160181  , 0.15384615, 0.1438914 ,
       0.15384615, 0.14660633, 0.1520362 , 0.18914027, 0.20769231,
...

So I would like some simple logic, just to find out which series is in which cluster:

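# Pseudocode: the membership test is the part I don't know how to write
cluster1_names = []
for i in range(len(mySeries_2019_Jan)):
    if mySeries_2019_Jan[i] in cluster 1:
        cluster1_names.append("series" + str(i))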

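This is the code I use (mostly copy-pasted from the documentation; just in case it matters, I edited the kmeans source code so that I can pass the DTW Sakoe radius directly):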

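import math

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingClassifier
# TimeSeriesKMeans comes from tslearn, not sklearn
from tslearn.clustering import TimeSeriesKMeans

# Number of clusters: ceil(sqrt(number of series))
n_clusters = math.ceil(math.sqrt(len(mySeries_2019_Jan)))

# global_constraint / sakoe_chiba_radius are passed directly thanks to my edited source
km = TimeSeriesKMeans(n_clusters=n_clusters, metric="dtw", max_iter=5, random_state=0,
                      global_constraint="sakoe_chiba", sakoe_chiba_radius=2, verbose=1)

# labels[i] is the cluster assigned to mySeries_2019_Jan[i]
labels = km.fit_predict(mySeries_2019_Jan)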

To get the picture of the clusters I also copy-pasted some code that I don't fully understand:

import matplotlib.pyplot as plt
import numpy as np

som_x = som_y = math.ceil(math.sqrt(math.sqrt(len(mySeries_2019_Jan))))
# Side length of a square grid with one subplot per cluster
plot_count = math.ceil(math.sqrt(n_clusters))

fig, axs = plt.subplots(plot_count, plot_count, figsize=(25, 25))
fig.suptitle('Clusters')
row_i = 0
column_j = 0
# For each label there is,
# plots every series with that label
for label in set(labels):
    cluster = []
    for i in range(len(labels)):
        if labels[i] == label:
            axs[row_i, column_j].plot(mySeries_2019_Jan[i], c="gray", alpha=0.4)
            cluster.append(mySeries_2019_Jan[i])
    if len(cluster) > 0:
        # Red line: point-wise average of all series in the cluster
        axs[row_i, column_j].plot(np.average(np.vstack(cluster), axis=0), c="red")
    axs[row_i, column_j].set_title("Cluster " + str(row_i*som_y + column_j))
    # Move to the next grid cell, wrapping to the next row when needed
    column_j += 1
    if column_j % plot_count == 0:
        row_i += 1
        column_j = 0

plt.show()

Picture of Clusters

__

How can I now get the information about which time series is in which cluster?

You can save the cluster groupings at the point where you retrieve them for plotting.

plot_count = math.ceil(math.sqrt(n_clusters))

fig, axs = plt.subplots(plot_count, plot_count, figsize=(25, 25))
fig.suptitle('Clusters')
row_i = 0
column_j = 0
# For each label there is,
# plots every series with that label
cluster_names = dict()
for label in set(labels):
    cluster = []
    for i in range(len(labels)):
        if labels[i] == label:
            axs[row_i, column_j].plot(mySeries_2019_Jan[i], c="gray", alpha=0.4)
            cluster.append(mySeries_2019_Jan[i])
    if len(cluster) > 0:
        axs[row_i, column_j].plot(np.average(np.vstack(cluster), axis=0), c="red")
    # Title and dict key both use the label itself (not the grid position),
    # so cluster_names["1"] always refers to the cluster with label 1
    axs[row_i, column_j].set_title("Cluster " + str(label))
    cluster_names[str(label)] = cluster
    column_j += 1
    if column_j % plot_count == 0:
        row_i += 1
        column_j = 0

plt.show()

Then

cluster_names["1"]

is the list of time series with label 1.
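
If the plot itself is not needed, the same grouping can be built directly from labels. The following is only a minimal sketch, assuming labels holds one entry per series in the same order as the input list (which is what fit_predict returns); the toy series and labels values and the helper name group_series_by_label are made up for illustration.

```python
from collections import defaultdict

import numpy as np


def group_series_by_label(series, labels):
    """Return {label as str: [series, ...]}, keyed like cluster_names above."""
    groups = defaultdict(list)
    for one_series, label in zip(series, labels):
        groups[str(label)].append(one_series)
    return dict(groups)


# Toy data standing in for mySeries_2019_Jan and the fit_predict output
series = [np.array([0.1, 0.2]), np.array([0.9, 0.8]), np.array([0.15, 0.25])]
labels = np.array([0, 1, 0])

cluster_names = group_series_by_label(series, labels)
print(len(cluster_names["0"]))  # number of series in cluster 0
```

The keys are strings here only to match the cluster_names["1"] lookup used above; integer keys would work just as well.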

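To store the series names instead of the series themselves, change the loop to:

for label in set(labels):
    cluster = []
    # Use the same string key here as in the append below (avoids a KeyError)
    cluster_names[str(label)] = []
    for i in range(len(labels)):
        if labels[i] == label:
            axs[row_i, column_j].plot(mySeries_2019_Jan[i], c="gray", alpha=0.4)
            cluster.append(mySeries_2019_Jan[i])
            cluster_names[str(label)].append("series" + str(i))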

Then

cluster_names["1"]

is the list of names ("series" followed by the index) of the time series with label 1.
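
Since labels[i] is the cluster of the i-th series, the indices can also be read off with NumPy alone, in both directions. This is a rough sketch, assuming labels is a 1-D integer array such as the one fit_predict returns; the toy labels and the helper name indices_by_cluster are illustrative, not part of any library.

```python
import numpy as np


def indices_by_cluster(labels):
    """Map each cluster label to the indices of the series assigned to it."""
    labels = np.asarray(labels)
    return {int(k): np.flatnonzero(labels == k).tolist() for k in np.unique(labels)}


labels = np.array([2, 0, 1, 0, 2, 2])  # toy stand-in for km.fit_predict(...)
members = indices_by_cluster(labels)

print(members[2])  # indices of the series in cluster 2
print(labels[3])   # cluster of series 3
```

The indices can then be turned into names with something like ["series" + str(i) for i in members[2]].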