正在读取PyTorch中的.pyth文件类型



存储库中有一个预训练的模型,它的文件类型是.pyth。我在网上搜索了一下这个文件类型以及哪种语言能够读取它,但我什么都找不到。既然我使用PyTorch,有可能在PyTorch中读取这样的文件吗?此外,通常情况下,如何读取和生成?

更清楚地说,在TimeSformer模型的存储库中,预训练的模型是这种文件类型的,例如,您可以在该存储库中找到以下命令:

import torch
from timesformer.models.vit import TimeSformer
model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time',  pretrained_model='/path/to/pretrained/model.pyth')
dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)
pred = model(dummy_video,) # (2, 400)

文件扩展名可以是任何东西,它不会更改文件内容。如果运行torch.load("file.pyth"),它将加载一个权重字典。您可以在包含的回购中的代码中找到这一点。他们使用以下代码保存模型:

path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
with PathManager.open(path_to_checkpoint, "wb") as f:
torch.save(checkpoint, f)

CCD_ 2函数可以在这里找到:

def get_path_to_checkpoint(path_to_job, epoch):
"""
Get the full path to a checkpoint file.
Args:
path_to_job (string): the path to the folder of the current job.
epoch (int): the number of epoch for the checkpoint.
"""
name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
return os.path.join(get_checkpoint_dir(path_to_job), name)

因此,他们只需将扩展名为.pyth的文件名传递给torch.save

最新更新