我有一组大约180K的句子嵌入。我使用faiss indexxivflat索引对它们进行了索引,并使用faiss k-means聚类功能对它们进行了聚类。我有20个簇。现在我想确定集群的大小,即每个集群包含多少个元素。
我还想对集群的每个元素进行分类,所以基本上我需要:
- 决定集群的大小
- 访问集群中的每个元素并执行分类。
到目前为止,我只设法查找最接近质心的元素。下面是我的代码:
niter = 10
verbose = True
d = sentence_embeddings.shape[1]
kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose, gpu=True)
kmeans.train(sentence_embeddings)
nlist = 20 # how many cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
index.train(sentence_embeddings)
index.add (sentence_embeddings)
D, I = index.search (kmeans.centroids, 10)
一旦你训练了你的kmeans,你就可以获得句子嵌入中每个元素最接近的质心。你可以这样做:
# I contains nearest centroid to each embedding
_, I = kmeans.index.search(sentence_embeddings, 1)
# flattening the result
I_flat = [i[0] for i in I]
不确定你的第二个问题是什么意思,但是你的每个嵌入现在都已经聚集了,标签是I_flat
中相应的条目