如何在sklearn的k-means中检查给定向量的聚类细节
How to check the cluster details of a given vector in k-means in sklearn
我正在使用以下代码使用 k 均值聚类算法对我的词向量进行聚类。
from sklearn import cluster
model = word2vec.Word2Vec.load("word2vec_model")
X = model[model.wv.vocab]
clusterer = cluster.KMeans (n_clusters=6)
preds = clusterer.fit_predict(X)
centers = clusterer.cluster_centers_
给定 word2vec 词汇表中的一个词(例如,word_vector = model['jeep']
)我想得到它的簇 ID 和到它的簇中心的余弦距离。
我尝试了以下方法。
for i,j in enumerate(set(preds)):
positions = X[np.where(preds == i)]
print(positions)
但是,它 returns 每个簇 ID 中的所有向量,而不是我正在寻找的。
如果需要,我很乐意提供更多详细信息。
聚类后,您将获得所有输入数据的 labels_
(与输入数据的顺序相同),即 clusterer.labels_[model.wv.vocab['jeep'].index]
将为您提供 jeep
所属的聚类属于。
您可以使用 scipy.spatial.distance.cosine
计算余弦距离
cluster_index = clusterer.labels_[model.wv.vocab['jeep'].index]
print(distance.cosine(model['jeep'], centers[cluster_index]))
>> 0.6935321390628815
完整代码
我不知道你的模型是什么样的,但让我们使用 GoogleNews-vectors-negative300.bin
。
from gensim.models import KeyedVectors
from sklearn import cluster
from scipy.spatial import distance
model = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)
# let's use a subset to accelerate clustering
X = model[model.wv.vocab][:40000]
clusterer = cluster.KMeans (n_clusters=6)
preds = clusterer.fit_predict(X)
centers = clusterer.cluster_centers_
cluster_index = clusterer.labels_[model.wv.vocab['jeep'].index]
print(cluster_index, distance.cosine(model['jeep'], centers[cluster_index]))
这是我的尝试!
from gensim.test.utils import common_texts
from gensim.models import Word2Vec
model = Word2Vec(common_texts, size=100, window=5, min_count=1, workers=4)
from sklearn.cluster import KMeans
clustering_model = KMeans(n_clusters=2)
preds = clustering_model.fit_predict([model.wv.get_vector(w) for w in model.wv.vocab])
获取集群 ID 的预测
>>> clustering_model.predict([model.wv.get_vector('computer')])
# array([1], dtype=int32)
获取给定单词和聚类中心之间的余弦相似度
>>> from sklearn.metrics.pairwise import cosine_similarity
>>> cosine_similarity(clustering_model.cluster_centers_, [model.wv.get_vector('computer')])
# array([[-0.07410881],
[ 0.34881588]])
我正在使用以下代码使用 k 均值聚类算法对我的词向量进行聚类。
from sklearn import cluster
model = word2vec.Word2Vec.load("word2vec_model")
X = model[model.wv.vocab]
clusterer = cluster.KMeans (n_clusters=6)
preds = clusterer.fit_predict(X)
centers = clusterer.cluster_centers_
给定 word2vec 词汇表中的一个词(例如,word_vector = model['jeep']
)我想得到它的簇 ID 和到它的簇中心的余弦距离。
我尝试了以下方法。
for i,j in enumerate(set(preds)):
positions = X[np.where(preds == i)]
print(positions)
但是,它 returns 每个簇 ID 中的所有向量,而不是我正在寻找的。
如果需要,我很乐意提供更多详细信息。
聚类后,您将获得所有输入数据的 labels_
(与输入数据的顺序相同),即 clusterer.labels_[model.wv.vocab['jeep'].index]
将为您提供 jeep
所属的聚类属于。
您可以使用 scipy.spatial.distance.cosine
cluster_index = clusterer.labels_[model.wv.vocab['jeep'].index]
print(distance.cosine(model['jeep'], centers[cluster_index]))
>> 0.6935321390628815
完整代码
我不知道你的模型是什么样的,但让我们使用 GoogleNews-vectors-negative300.bin
。
from gensim.models import KeyedVectors
from sklearn import cluster
from scipy.spatial import distance
model = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)
# let's use a subset to accelerate clustering
X = model[model.wv.vocab][:40000]
clusterer = cluster.KMeans (n_clusters=6)
preds = clusterer.fit_predict(X)
centers = clusterer.cluster_centers_
cluster_index = clusterer.labels_[model.wv.vocab['jeep'].index]
print(cluster_index, distance.cosine(model['jeep'], centers[cluster_index]))
这是我的尝试!
from gensim.test.utils import common_texts
from gensim.models import Word2Vec
model = Word2Vec(common_texts, size=100, window=5, min_count=1, workers=4)
from sklearn.cluster import KMeans
clustering_model = KMeans(n_clusters=2)
preds = clustering_model.fit_predict([model.wv.get_vector(w) for w in model.wv.vocab])
获取集群 ID 的预测
>>> clustering_model.predict([model.wv.get_vector('computer')])
# array([1], dtype=int32)
获取给定单词和聚类中心之间的余弦相似度
>>> from sklearn.metrics.pairwise import cosine_similarity
>>> cosine_similarity(clustering_model.cluster_centers_, [model.wv.get_vector('computer')])
# array([[-0.07410881],
[ 0.34881588]])