使用Pytorch提取自编码器的隐藏层特征



我遵循这个教程来训练一个自动编码器。

训练进行得很顺利。接下来,我感兴趣的是从隐藏层(在编码器和解码器之间)提取特征。

我该怎么做呢?

最干净和最直接的方法是添加用于创建部分输出的方法——这甚至可以在训练好的模型上进行后验。

from torch import Tensor
class AE(nn.Module):
def __init__(self, **kwargs):
...
def encode(self, features: Tensor) -> Tensor:
h = torch.relu(self.encoder_hidden_layer(features))
return torch.relu(self.encoder_output_layer(h))
def decode(self, encoded: Tensor) -> Tensor:
h = torch.relu(self.decoder_hidden_layer(encoded))
return torch.relu(self.decoder_output_layer(h))
def forward(self, features: Tensor) -> Tensor:
encoded = self.encode(features)
return self.decode(encoded)

您现在可以通过简单地使用相应的输入张量调用encode来查询编码器隐藏状态的模型。

如果你不想在基类中添加任何方法(我不明白为什么),你可以选择编写一个外部函数:

def get_encoder_state(model: AE, features: Tensor) -> Tensor:
return torch.relu(model.encoder_output_layer(torch.relu(model.encoder_hidden_layer(features))))

最新更新