将字符串转换为字节以用于 pytorch 加载器



下载pytorch模型路径的方法不在我的控制范围内,我正在尝试找出一种将下载的字符串数据转换为字节数据的方法。 下面的代码从 Dropbox 下载我保存的模型,并使用带有 utf-8 编码的字节对字符串进行编码。问题是当我将torch.load与BytesIO一起使用时,我得到一个带有无效加载键"<"的UnpicklingError。

data = bytes(self.Download("https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"), 'utf-8')
self.agent.local.load_state_dict(torch.load(BytesIO(data ), map_location=lambda storage, loc: storage))

下面的代码运行良好,直到请求被禁用,我现在正在尝试使用上面的方法。

dropbox_url = "https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"
data = requests.get(dropbox_url )
self.agent.local.load_state_dict(torch.load(BytesIO(data.content), map_location=lambda storage, loc: storage))

我只需要找出一种以正确方式将字符串转换为字节数据的方法。

我必须将字节数据转换为 base64 并以该格式保存文件。 一旦我上传到 Dropbox 并使用内置方法下载,我将 base64 文件转换回字节,它就起作用了!

import base64
from io import BytesIO
with open("checkpoint.pth", "rb") as f:
byte = f.read(1)
# Base64 Encode the bytes
data_e = base64.b64encode(byte)
filename ='base64_checkpoint.pth'
with open(filename, "wb") as output:
output.write(data_e)
# Save file to Dropbox
# Download file on server
b64_str= self.Download('url')
# String Encode to bytes
byte_data = b64_str.encode("UTF-8")
# Decoding the Base64 bytes
str_decoded = base64.b64decode(byte_data)
# String Encode to bytes
byte_decoded = str_decoded.encode("UTF-8")
# Decoding the Base64 bytes
decoded = base64.b64decode(byte_decoded)
torch.load(BytesIO(decoded))

最新更新