从PyTorch nn检索原始数据.嵌入



我正在将一个具有5个类别(例如,汽车、公共汽车…(的数据帧传递到nn.Embedding中。

当我做embedding.parameters()时,我可以看到有5个传感器,但我如何知道哪个索引对应于原始输入(例如,汽车、公共汽车…(?

不能因为张量是未命名的(只有维度可以命名,请参阅PyTorch的命名张量(。您必须将名称保存在单独的数据容器中,例如(此处为4类别(:

import pandas as pd
import torch
df = pd.DataFrame(
{
"bus": [1.0, 2, 3, 4, 5],
"car": [6.0, 7, 8, 9, 10],
"bike": [11.0, 12, 13, 14, 15],
"train": [16.0, 17, 18, 19, 20],
}
)
df_data = df.to_numpy().T
df_names = list(df)
embedding = torch.nn.Embedding(df_data.shape[0], df_data.shape[1])
embedding.weight.data = torch.from_numpy(df_data)

现在你可以简单地将它与任何你想要的索引一起使用:

index = 1
embedding(torch.tensor(index)), df_names[index]

这将为您提供(tensor[6, 7, 8, 9, 10], "car"),因此数据和相应的列名。

最新更新