I'm computing the similarity between embedding vectors. My matrix has shape (16480, 300) --> vecs:
vecs[0]
array([ 0.10071956, 3.8815327 , 0.12835003, -0.31677222, 0.70524615,
0.65897983, -0.7154368 , 4.49739 , 0.77070695, -2.3327951 ,
-3.7463412 , 0.8334273 , 2.2104564 , -2.0296195 , 0.6603169 ,
-3.0648541 , -1.9763994 , 3.8416848 , -0.22661261, 0.4862857 ,....]
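I'm using hnswlib to compute the similarities approximately. With n=10, the output is a matrix of shape (16480, 10) --> labels. Each row of labels corresponds to one vector in vecs and lists its most similar vectors; each value is an index into the vecs matrix.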
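hnswlib returns approximate neighbors. To check the shape and meaning of labels on small data, an exact brute-force cosine search in NumPy produces the same kind of (n, k) index matrix. This is a swapped-in technique, not hnswlib itself; the function name exact_knn_labels and the cosine metric are assumptions. It builds a full n x n similarity matrix, which at 16480 rows is already about 1 GB in float32, so an approximate index like hnswlib remains the practical choice at 500k rows.

```python
import numpy as np

def exact_knn_labels(vecs: np.ndarray, k: int) -> np.ndarray:
    """Exact cosine k-NN by brute force (hypothetical helper, a stand-in for hnswlib).

    Returns an (n, k) array whose row i holds the row indices in `vecs`
    of the k vectors most similar to vecs[i], most similar first.
    """
    # Normalize rows so that a dot product equals cosine similarity.
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    unit = vecs / np.where(norms == 0, 1, norms)
    sims = unit @ unit.T  # (n, n) similarity matrix: memory grows as n**2
    # Unordered top-k per row, then sort those k by descending similarity.
    top = np.argpartition(-sims, k - 1, axis=1)[:, :k]
    order = np.argsort(-np.take_along_axis(sims, top, axis=1), axis=1)
    # Cast to uint64 to mirror the dtype of the labels shown above.
    return np.take_along_axis(top, order, axis=1).astype(np.uint64)

rng = np.random.default_rng(0)
vecs = rng.normal(size=(100, 300)).astype(np.float32)
labels = exact_knn_labels(vecs, k=10)
print(labels.shape)  # (100, 10)
```

Because each vector is queried against a set that contains itself, the first column is the vector's own index, just as it typically is when querying hnswlib with the same data it was built from.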
labels[0]
array([ 7791, 1593, 3561, 2280, 2920, 3588, 13151, 5673, 7562,
4148], dtype=uint64)
I have a DataFrame that stores the "str" values corresponding to the rows of the vector matrix, in DF['ind']:
0 1.1000659
1 1.100087
2 1.1001568
3 1.1008761
4 1.1018004
...
16476 1.992905
16477 1.993998
16478 1.995835
16479 1.99836
16480 1.999198
Name: ind, Length: 16481, dtype: object
My goal is to map the labels matrix to a JSON of these "str" values, so that I can write it to MongoDB as:
{'1.1000659': ['1.00xxx', '1.0xxx', ...n10],
 '1.xx': [...n10]}
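Since the keys contain dots, and MongoDB uses dots for dot-notation field paths, using these strings as field names can be awkward. A minimal sketch, assuming a hypothetical one-document-per-key shape {"_id": key, "neighbors": [...]} instead of the single object above; the helper name to_neighbor_docs is also hypothetical. The resulting list could then be passed to pymongo's Collection.insert_many, which is not shown or tested here.

```python
import json

import numpy as np
import pandas as pd

def to_neighbor_docs(keys: np.ndarray, labels: np.ndarray) -> list[dict]:
    """Build one document per row: {"_id": key, "neighbors": [k neighbor keys]}.

    The document shape is a hypothetical choice, not the format from the question.
    """
    neighbor_keys = keys[labels]  # (n, k) array of strings via fancy indexing
    # Slicing keys to len(labels) guards against keys having extra rows.
    return [{"_id": key, "neighbors": row}
            for key, row in zip(keys[: len(labels)].tolist(), neighbor_keys.tolist())]

df = pd.DataFrame({"ind": ["1.001", "1.002", "1.003", "1.004"]})
labels = np.array([[1, 2], [0, 3], [3, 0], [2, 1]], dtype=np.uint64)
docs = to_neighbor_docs(df["ind"].to_numpy(), labels)
print(json.dumps(docs[0]))  # {"_id": "1.001", "neighbors": ["1.002", "1.003"]}
```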
The matrix is small for now, but it can grow to 500k rows, so these lookups will take much longer.
To get the string values for the indices, I'm running the following code:
{df.iloc[i]['ind']:df.iloc[labels[i]]['ind'] for i in range(labels.shape[0])}
The runtime is ~12 for 16,400 rows. Is there another, "vectorized" way to do this mapping? Thanks.
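Part of the cost of the loop in the question is that every df.iloc call builds a pandas object. A minimal sketch comparing that loop with the same loop run over a NumPy array pulled out once; both function names are hypothetical, and .tolist() is added to the original expression only so the two results can be compared. This still loops in Python, so it is a partial improvement rather than full vectorization.

```python
import numpy as np
import pandas as pd

def neighbors_loop_iloc(df: pd.DataFrame, labels: np.ndarray) -> dict:
    # The question's approach: two .iloc lookups per row, each building a pandas object.
    return {df.iloc[i]['ind']: df.iloc[labels[i]]['ind'].tolist()
            for i in range(labels.shape[0])}

def neighbors_loop_array(df: pd.DataFrame, labels: np.ndarray) -> dict:
    # Same loop, but the column is extracted once as a NumPy array,
    # so each step is plain array indexing instead of DataFrame row access.
    names = df['ind'].to_numpy()
    return {names[i]: names[labels[i]].tolist() for i in range(labels.shape[0])}

rng = np.random.default_rng(1)
df = pd.DataFrame({'ind': [f'1.{j:03d}' for j in range(50)]})
labels = rng.integers(0, 50, size=(50, 10))
same = neighbors_loop_iloc(df, labels) == neighbors_loop_array(df, labels)
```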
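You can try indexing df.ind with labels directly, once you have added a dimension with None. I'm not sure about your exact expected output, but something like this: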
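import numpy as np
import pandas as pd

# dummy input
np.random.seed(16)
df = pd.DataFrame({'ind': ['1.001', '1.002', '1.003', '1.004', '1.005',
                           '1.006', '1.007', '1.008', '1.009', '1.010']})
labels = np.random.randint(0, 9, size=(10, 4))
# see what the indexing does: each label is replaced by the string at that position
# (to_numpy() is needed because multi-dimensional indexing on a Series was removed in pandas 2.0)
print(df.ind.to_numpy()[:, None][labels].reshape(labels.shape))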
[['1.006' '1.002' '1.005' '1.005']
['1.001' '1.001' '1.009' '1.003']
['1.005' '1.001' '1.002' '1.003']
['1.005' '1.001' '1.006' '1.003']
['1.004' '1.009' '1.003' '1.006']
['1.005' '1.002' '1.009' '1.005']
['1.006' '1.007' '1.008' '1.006']
['1.009' '1.001' '1.007' '1.009']
['1.006' '1.003' '1.005' '1.003']
['1.002' '1.009' '1.008' '1.002']]
# create the result you want
d = {ind: val for ind, val in zip(df.ind, df.ind.to_numpy()[:, None][labels].reshape(labels.shape).tolist())}
print (d)
{'1.001': ['1.006', '1.002', '1.005', '1.005'],
'1.002': ['1.001', '1.001', '1.009', '1.003'],
'1.003': ['1.005', '1.001', '1.002', '1.003'],
'1.004': ['1.005', '1.001', '1.006', '1.003'],
'1.005': ['1.004', '1.009', '1.003', '1.006'],
'1.006': ['1.005', '1.002', '1.009', '1.005'],
'1.007': ['1.006', '1.007', '1.008', '1.006'],
'1.008': ['1.009', '1.001', '1.007', '1.009'],
'1.009': ['1.006', '1.003', '1.005', '1.003'],
'1.010': ['1.002', '1.009', '1.008', '1.002']}
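The extra axis and reshape are not strictly needed: NumPy fancy indexing with a 2-D integer array already returns an array of the same shape as the index. A minimal sketch on the same dummy input, checking that the direct form matches the None/reshape form and building the same dictionary with dict(zip(...)). It is one indexing call over all rows, so its cost grows linearly with the number of rows times n.

```python
import numpy as np
import pandas as pd

np.random.seed(16)
df = pd.DataFrame({'ind': ['1.001', '1.002', '1.003', '1.004', '1.005',
                           '1.006', '1.007', '1.008', '1.009', '1.010']})
labels = np.random.randint(0, 9, size=(10, 4))

names = df['ind'].to_numpy()
# A 2-D integer index yields a (10, 4) array of strings directly.
direct = names[labels]
# The answer's form: add an axis, index, then drop the extra axis again.
via_none = names[:, None][labels].reshape(labels.shape)
assert (direct == via_none).all()

d = dict(zip(names.tolist(), direct.tolist()))
```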