是否有方法生成nn.有效地嵌入使用for循环?



我是Pytorch新手,我想知道我是否可以生成nn。使用for循环有效嵌入。

class Example(nn.Module):
def __init__(self):
self.A_embed_dim = 3
self.B_embed_dim = 3
self.C_embed_dim = 5
self.A_embedding = nn.Embedding(
df.A.max() + 1, self.A_embed_dim
)
self.B_embedding = nn.Embedding(
df.B.max() + 1, self.B.embed_dim
)
self.C_embedding = nn.Embedding(
df.C.max() + 1, self.C.embed_dim
)

在这种情况下,只存在3列,并且很容易生成嵌入。但是如果数据框中有更多的列(例如,从A到P有16列),代码就很长,看起来不干净。有没有办法创建多个nn。嵌入使用for循环?

是的,您可以使用模块列表或模块字典来这样做。

ModuleList:

self.embeddings = nn.ModuleList(
[
nn.Embedding(vocab_size, dim) 
for vocab_size, dim in embedding_args 
]
)
# embedding_args = [(5,10), (2, 8)]

最新更新