训练后如何保存gpt-2-simple模型?



我训练了gpt-2-simple聊天机器人模型,但我无法保存它。对我来说,从colab下载训练好的模型是很重要的,否则我每次都必须下载355M模型(见下面的代码)。

我尝试了各种方法来保存训练好的模型(如gpt2.saveload.save_gpt2()),但没有一个有效,我没有任何更多的想法。

我的培训代码:

%tensorflow_version 2.x
!pip install gpt-2-simple
import gpt_2_simple as gpt2
import json
gpt2.download_gpt2(model_name="355M")
raw_data = '/content/drive/My Drive/data.json'
with open(raw_data, 'r') as f:
df =json.load(f)
data = []
for x in df:
for y in range(len(x['dialog'])-1):
a = '[BOT] : ' + x['dialog'][y+1]['text']
q = '[YOU] : ' + x['dialog'][y]['text']
data.append(q)
data.append(a)
with open('chatbot.txt', 'w') as f:
for line in data:
try:
f.write(line)
f.write('n')
except:
pass
file_name = "/content/chatbot.txt"
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
dataset=file_name,
model_name='355M',
steps=500,
restore_from='fresh',
run_name='run1',
print_every=10,
sample_every=100,
save_every=100
)
while True:
ques = input("Question : ")
inp = '[YOU] : '+ques+'n'+'[BOT] :'
x = gpt2.generate(sess,
length=20,
temperature = 0.6,
include_prefix=False,
prefix=inp,
nsamples=1,
)

gpt-2-simple存储库README。md链接了一个示例Colab笔记本,其中说明如下:

gpt2.finetune的其他可选但有用的参数:

  • restore_from:设置为fresh从基础GPT-2开始训练,或设置为最新到restart从现有检查点训练。
  • run_name:检查点内的子文件夹保存模型。如果您想使用多个模型(在加载模型时还需要指定run_name),这很有用
  • overwrite:如果您想继续微调现有模型(w/restore_from='latest')而不创建副本,则设置为True

README。md还指出,模型检查点默认存储在/checkpoint/run1中,如果您想在检查点文件夹中存储/加载多个模型,则可以将run_name参数传递给finetuneload_gpt2

把这些放在一起,你应该能够从保存的模型中工作,而不是每次重新下载:

import gpt_2_simple as gpt2
sess = gpt2.start_tf_sess()
# To load existing model in default checkpoint dir from "run1"
gpt2.load_gpt2(sess)
# Or, to finetune existing model in default checkpoint dir from "run1"
gpt2.finetune(sess,
dataset=file_name,
model_name='355M',
steps=500,
restore_from='latest',
run_name='run1',
overwrite=True,
print_every=10,
sample_every=100,
save_every=500
)

请参阅load_gpt2()和finetune()函数的源代码,了解更多细节。