Efficient way to map a matrix of int values to str using numpy



I am computing the similarity between embedding vectors. My matrix has shape (16480, 300) --> vecs

vecs[0]
array([ 0.10071956,  3.8815327 ,  0.12835003, -0.31677222,  0.70524615,
0.65897983, -0.7154368 ,  4.49739   ,  0.77070695, -2.3327951 ,
-3.7463412 ,  0.8334273 ,  2.2104564 , -2.0296195 ,  0.6603169 ,
-3.0648541 , -1.9763994 ,  3.8416848 , -0.22661261,  0.4862857 ,....]

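I am using hnswlib for approximate similarity search. The output for n=10 is a matrix of shape (16480, 10) --> labels. Each row of the labels matrix holds the vectors most similar to the corresponding row of vecs, and each column entry is a row index into vecs.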

labels[0]
array([ 7791,  1593,  3561,  2280,  2920,  3588, 13151,  5673,  7562,
4148], dtype=uint64)
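
To make the layout of labels concrete, here is a minimal sketch that builds a matrix of the same kind on toy data. It is a brute-force NumPy stand-in using exact cosine similarity, not hnswlib, and the sizes (50 vectors, 8 dimensions, k = 3) are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-in for the (16480, 300) vecs matrix; sizes are assumptions
vecs = rng.normal(size=(50, 8)).astype(np.float32)
k = 3

# Exact cosine similarity between every pair of rows
unit = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
sim = unit @ unit.T

# For each row, the indices of the k most similar rows, most similar first
labels = np.argsort(-sim, axis=1)[:, :k]
print(labels.shape)  # (50, 3)
```

Each entry of labels is a row number into vecs, which is why it can later be used to look up the matching strings.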

I have a DataFrame, df, that stores the "str" values corresponding to the rows of the vector matrix, in df['ind']:

0        1.1000659
1         1.100087
2        1.1001568
3        1.1008761
4        1.1018004
16476     1.992905
16477     1.993998
16478     1.995835
16479      1.99836
16480     1.999198
Name: ind, Length: 16481, dtype: object

My goal is to map the labels matrix to a JSON of "str" values, so it can be written to MongoDB as:

{'1.1000659 ' : [{'1.00xxx','1.0xxx'...n10}]
'1.xx': ....n10}

The matrix is small right now but could grow to 500k rows, so the lookup will take much longer.

To get the string values for the indices, I am running the following code:

{df.iloc[i]['ind']:df.iloc[labels[i]]['ind'] for i in range(labels.shape[0])}

Run time is ~12 (presumably seconds) for 16,400 rows. Is there another, "vectorized" way to do the mapping? Thanks.

You can try indexing df.ind with labels, once a dimension has been added with None. I'm not sure of your exact expected output, but something like this:

# dummy input
import numpy as np
import pandas as pd

np.random.seed(16)
df = pd.DataFrame({'ind': ['1.001', '1.002', '1.003', '1.004', '1.005',
                           '1.006', '1.007', '1.008', '1.009', '1.010']})
labels = np.random.randint(0, 9, size=(10, 4))
# see what the indexing does (to_numpy() first: newer pandas no longer allows df.ind[:, None])
print(df.ind.to_numpy()[:, None][labels].reshape(labels.shape))
[['1.006' '1.002' '1.005' '1.005']
['1.001' '1.001' '1.009' '1.003']
['1.005' '1.001' '1.002' '1.003']
['1.005' '1.001' '1.006' '1.003']
['1.004' '1.009' '1.003' '1.006']
['1.005' '1.002' '1.009' '1.005']
['1.006' '1.007' '1.008' '1.006']
['1.009' '1.001' '1.007' '1.009']
['1.006' '1.003' '1.005' '1.003']
['1.002' '1.009' '1.008' '1.002']]
# create the result you want
d = {ind: val for ind, val in zip(df.ind, df.ind.to_numpy()[:, None][labels].reshape(labels.shape).tolist())}
print (d)
{'1.001': ['1.006', '1.002', '1.005', '1.005'],
'1.002': ['1.001', '1.001', '1.009', '1.003'],
'1.003': ['1.005', '1.001', '1.002', '1.003'],
'1.004': ['1.005', '1.001', '1.006', '1.003'],
'1.005': ['1.004', '1.009', '1.003', '1.006'],
'1.006': ['1.005', '1.002', '1.009', '1.005'],
'1.007': ['1.006', '1.007', '1.008', '1.006'],
'1.008': ['1.009', '1.001', '1.007', '1.009'],
'1.009': ['1.006', '1.003', '1.005', '1.003'],
'1.010': ['1.002', '1.009', '1.008', '1.002']}
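
As a variation on the above, here is a minimal sketch that pulls the column out as a NumPy array once and indexes it with the whole labels matrix, without the extra None axis. The toy values, the cast to np.intp, and the "_id" / "similar" field names are assumptions for illustration, not something taken from the question. One document per row is used because dots in keys such as '1.1000659' can be awkward as MongoDB field names.

```python
import numpy as np
import pandas as pd

# Toy data (hypothetical values, same layout as in the question)
df = pd.DataFrame({"ind": ["1.001", "1.002", "1.003", "1.004", "1.005"]})
labels = np.array([[1, 2], [0, 3], [4, 0], [2, 1], [3, 3]], dtype=np.uint64)

# Extract the column once, then index it with the full matrix in one step
keys = df["ind"].to_numpy()
neighbors = keys[labels.astype(np.intp)]  # same shape as labels

# One document per row; "_id" and "similar" are assumed field names
docs = [{"_id": key, "similar": row}
        for key, row in zip(keys.tolist(), neighbors.tolist())]
print(docs[0])  # {'_id': '1.001', 'similar': ['1.002', '1.003']}
```

The lookup itself is a single array operation, so the remaining Python-level work is only building the list of documents, which could then be passed to something like pymongo's insert_many.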
