使用 Flask 在 Heroku 上部署 BERT PyTorch 模型时出错:_pickle.UnpicklingError: invalid load key, 'v'



试图在 Heroku 上部署 BERT 模型。

import torch
import transformers
import numpy as np
from flask import Flask, render_template, request
from model import DISTILBERTBaseUncased
MAX_LEN = 320
TOKENIZER = transformers.DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased", do_lower_case=True
)
DEVICE = "cpu"
# Fine-tuned weights, loaded below with MODEL.load_state_dict.
WEIGHTS_PATH = "weight.bin"
# Every Git LFS pointer file starts with this line. If a repository tracks the
# weights with git-lfs but the checkout did not fetch the real object, the file
# on disk is this small text stub. torch.load then fails with
# "_pickle.UnpicklingError: invalid load key, 'v'" (the 'v' of "version").
_GIT_LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/"


def _load_state_dict(path, device):
    """Load a state dict saved with ``torch.save`` from *path* onto *device*.

    Args:
        path: Filesystem path of the saved state dict.
        device: Target device passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever ``torch.load`` returns for *path*. For this app, that is the
        state dict given to ``MODEL.load_state_dict``.

    Raises:
        RuntimeError: If *path* is a Git LFS pointer instead of real weights.
    """
    with open(path, "rb") as fh:
        head = fh.read(len(_GIT_LFS_POINTER_PREFIX))
    if head == _GIT_LFS_POINTER_PREFIX:
        raise RuntimeError(
            f"{path!r} is a Git LFS pointer file, not serialized model "
            "weights. Install git-lfs and fetch the real file (e.g. "
            "'git lfs pull') before starting the app."
        )
    # map_location lets weights that were saved from a GPU session be
    # deserialized on a CPU-only machine.
    return torch.load(path, map_location=device)


MODEL = DISTILBERTBaseUncased()
MODEL.load_state_dict(_load_state_dict(WEIGHTS_PATH, DEVICE))
MODEL.to(DEVICE)
MODEL.eval()
app = Flask(__name__)

def sentence_prediction(sentence):
    """Return the model's sigmoid score for a single piece of text.

    The text is whitespace-normalised and tokenised with the module-level
    ``TOKENIZER``, padded to ``MAX_LEN``. It is then run through ``MODEL`` on
    ``DEVICE``, and a sigmoid is applied to the raw output.

    Args:
        sentence: Input text. Non-string values are converted with ``str()``.

    Returns:
        The ``[0][0]`` element of ``sigmoid(MODEL(...))`` as a NumPy scalar.
    """
    # Collapse runs of whitespace (including newlines) into single spaces.
    comment = " ".join(str(sentence).split())
    inputs = TOKENIZER.encode_plus(
        comment,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # NOTE(review): newer transformers releases deprecate this in favour of
        # padding="max_length"; kept for the version this app was built with.
        pad_to_max_length=True,
    )
    # unsqueeze(0) adds a batch dimension of size 1.
    ids = torch.tensor(inputs["input_ids"], dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).unsqueeze(0)
    ids = ids.to(DEVICE, dtype=torch.long)
    mask = mask.to(DEVICE, dtype=torch.long)
    # Inference only: disabling autograd avoids building a computation graph
    # for every request, which reduces memory use and latency.
    with torch.no_grad():
        outputs = MODEL(ids=ids, mask=mask)
    outputs = torch.sigmoid(outputs).cpu().detach().numpy()
    return outputs[0][0]

@app.route("/")
def index_page():
    """Render the ``index.html`` template for the site root."""
    template_name = "index.html"
    return render_template(template_name)

@app.route("/model")
def models():
    """Render the ``model.html`` template at ``/model``."""
    page = "model.html"
    return render_template(page)

@app.route("/predict", methods=["POST", "GET"])
def predict():
    """Render ``index.html``, adding a score when text is POSTed.

    On a POST carrying a non-blank ``text`` form field, the field is scored
    with ``sentence_prediction``. The score is shown as a percentage rounded
    to two decimals. Otherwise (GET, missing field, or blank text), the page
    is rendered with an empty ``prediction_text``.
    """
    if request.method == "POST":
        sentence = request.form.get("text")
        # Without this guard, a missing field would be scored as the literal
        # string "None" (via str() in sentence_prediction). Blank input would
        # be scored as an empty sequence. Neither result is meaningful.
        if sentence is not None and sentence.strip():
            toxic_prediction = sentence_prediction(sentence)
            return render_template(
                "index.html", prediction_text=np.round((toxic_prediction * 100), 2)
            )
    return render_template("index.html", prediction_text="")

if __name__ == "__main__":
    import os

    # Heroku passes the port the web process must listen on via $PORT, and
    # that port has to be reachable from outside the container.
    port = os.environ.get("PORT")
    if port is None:
        # Local development: Flask's defaults (127.0.0.1:5000) with debugger.
        app.run(debug=True)
    else:
        # Public binding: never expose the Werkzeug interactive debugger,
        # which allows arbitrary code execution.
        app.run(host="0.0.0.0", port=int(port), debug=False)

错误

MODEL.load_state_dict(torch.load("weight.bin"))

2020-05-18T06:32:32.134536+00:00 app[web.1]: File "/app/.heroku/python/lib/python3.7/site-packages/torch/serialization.py", line 593, in load

2020-05-18T06:32:32.134536+00:00 app[web.1]: return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

2020-05-18T06:32:32.134536+00:00 app[web.1]: File "/app/.heroku/python/lib/python3.7/site-packages/torch/serialization.py", line 763, in _legacy_load

2020-05-18T06:32:32.134537+00:00 app[web.1]: magic_number = pickle_module.load(f, **pickle_load_args)

2020-05-18T06:32:32.134537+00:00 app[web.1]: _pickle.UnpicklingError: invalid load key, 'v'.

  1. 代码在本地运行良好。
  2. Heroku 的部署方式是 GitHub 集成。
  3. weight.bin 的大小为 255 MB。
  4. Flask API 在 localhost 上工作正常。

检查错误 1. MODEL.load_state_dict(torch.load("weight.bin")) --> 请确认这一行的拼写、文件名和括号都正确,标准写法如下:

model.load_state_dict(torch.load(model_state_dict))

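2. _pickle.UnpicklingError: invalid load key, 'v'. --> 我认为您的环境中没有安装 git-lfs。在这种情况下,仓库里的 weight.bin 只是一个 Git LFS 指针文件,即以 "version https://git-lfs.github.com/spec/v1" 开头的小文本文件,而不是真正的模型权重。torch.load 读到的第一个字节 'v' 不是合法的 pickle 数据,因此报错。请安装 git-lfs 并拉取真实的权重文件,然后重试。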

最新更新