在Bert Embeddings中搜索最近的邻居



我使用bert嵌入来使用这种方法生成类似的单词:https://gist.github.com/avidale/c6b19687d333655da483421880441950

它适用于较小的数据集,但对于较大的数据集有问题,我得到的错误是:内存错误:numpy.core._exceptions.MemoryError: 无法为形状为 (819827, 768)、数据类型为 float32 的数组分配 2.35 GiB

当处理具有超过20000个句子的较大数据集时。有人能提出一个好的方法来做到这一点,同时保存索引和数据,这样下次就可以轻松加载,而无需重新计算!主要代码是(您可以从上面的链接中获得完整的代码以供参考):

from sklearn.neighbors import KDTree
import numpy as np

class ContextNeighborStorage:
    """Token-level nearest-neighbour search over contextual (BERT) embeddings.

    Embeds every token of every sentence with ``model``, L2-normalises the
    embeddings (so that Euclidean distance on the index is equivalent to
    cosine distance), builds a KDTree over them, and answers "which tokens
    in the corpus are closest to this word in this sentence" queries.
    """

    def __init__(self, sentences, model):
        # sentences: list of raw sentence strings.
        # model: callable mapping a list of sentences to a list of
        #   (tokens, embeddings) pairs, one pair per sentence.
        self.sentences = sentences
        self.model = model

    def process_sentences(self):
        """Embed all tokens of all sentences and store normalised vectors.

        Populates ``sentence_ids``, ``token_ids`` and ``all_tokens``
        (parallel flat lists, one entry per corpus token) and
        ``normed_embeddings`` (2-D float32 array, one L2-normalised row
        per token).
        """
        # tqdm is only a progress bar; degrade gracefully when it is not
        # installed instead of raising NameError (the original snippet
        # used it without ever importing it).
        try:
            from tqdm.auto import tqdm
        except ImportError:
            def tqdm(iterable, **kwargs):
                return iterable

        result = self.model(self.sentences)
        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        all_embeddings = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)
        # Keep float32: it halves the footprint relative to float64, and
        # the reported MemoryError came from a (~820k, 768) float32
        # allocation — avoid any accidental upcast on top of that.
        all_embeddings = np.stack(all_embeddings).astype(np.float32, copy=False)
        # Normalise rows so that Euclidean distance == cosine distance.
        norms = np.linalg.norm(all_embeddings, axis=1, keepdims=True)
        self.normed_embeddings = all_embeddings / norms

    def build_search_index(self):
        """Build the KDTree over the normalised embeddings.

        This is the slow step; use :meth:`save` afterwards so it does not
        have to be repeated on the next run.
        """
        self.indexer = KDTree(self.normed_embeddings)

    def save(self, path):
        """Persist the processed data and the built search index to *path*.

        After :meth:`load`, neither :meth:`process_sentences` nor
        :meth:`build_search_index` needs to be re-run before querying.
        The raw embedding matrix is not saved separately: the KDTree
        already holds the data needed by :meth:`query`.
        """
        import pickle
        state = {
            'sentences': self.sentences,
            'sentence_ids': self.sentence_ids,
            'token_ids': self.token_ids,
            'all_tokens': self.all_tokens,
            'indexer': self.indexer,  # sklearn KDTree objects are picklable
        }
        with open(path, 'wb') as f:
            pickle.dump(state, f)

    def load(self, path):
        """Restore state previously written by :meth:`save`."""
        import pickle
        with open(path, 'rb') as f:
            state = pickle.load(f)
        self.sentences = state['sentences']
        self.sentence_ids = state['sentence_ids']
        self.token_ids = state['token_ids']
        self.all_tokens = state['all_tokens']
        self.indexer = state['indexer']

    def query(self, query_sent, query_word, k=10, filter_same_word=False):
        """Return the ``k`` corpus tokens nearest to ``query_word`` as it
        appears in ``query_sent``.

        Returns a tuple ``(distances, neighbor_tokens, neighbor_sentences)``.
        Raises ValueError if ``query_word`` is not a single token of the
        embedded query sentence.
        """
        toks, embs = self.model([query_sent])[0]
        found = False
        for tok, emb in zip(toks, embs):
            if tok == query_word:
                found = True
                break
        if not found:
            raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
        # Normalise the query vector the same way as the index rows.
        emb = emb / np.linalg.norm(emb)
        # Over-fetch candidates when filtering, so that enough survive.
        initial_k = max(k, 100) if filter_same_word else k
        di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
        distances = []
        neighbors = []
        contexts = []
        for i, index in enumerate(idx.ravel()):
            token = self.all_tokens[index]
            # Optionally skip tokens that are sub- or super-strings of the
            # query word (e.g. inflected forms of the same word).
            if filter_same_word and (query_word in token or token in query_word):
                continue
            distances.append(di.ravel()[i])
            neighbors.append(token)
            contexts.append(self.sentences[self.sentence_ids[index]])
            if len(distances) == k:
                break
        return distances, neighbors, contexts

也许可以尝试使用局部敏感哈希(LSH)。我认为近似最近邻搜索正是LSH的一个重要用例。

最新更新