我要做的是在 pickle 对象中加载一个机器学习模型以生成摘要,这样当我将代码部署到我的 Web 应用程序时,它就不会一遍又一遍地手动加载。这需要相当多的时间,我不能让用户在模型加载时等待 10-15 分钟,然后生成摘要。
import cPickle as pickle
from skip_thoughts import configuration
from skip_thoughts import encoder_manager
import en_coref_md
def load_models():
VOCAB_FILE = "skip_thoughts_uni/vocab.txt"
EMBEDDING_MATRIX_FILE = "skip_thoughts_uni/embeddings.npy"
CHECKPOINT_PATH = "skip_thoughts_uni/model.ckpt-501424"
encoder = encoder_manager.EncoderManager()
print "loading skip model"
encoder.load_model(configuration.model_config(),
vocabulary_file=VOCAB_FILE,
embedding_matrix_file=EMBEDDING_MATRIX_FILE,
checkpoint_path=CHECKPOINT_PATH)
print "loaded"
return encoder
encoder= load_models()
print "Starting cPickle dumping"
pickle.dump(encoder, open('/path_to_loaded_model/loaded_model.pkl', "wb"))
print "pickle.dump executed"
print "starting cpickle loading"
loaded_model = pickle.load(open('loaded_model.pkl', 'r'))
print "pickle load done"
cPickle最初是pickle,但他们都没有在足够的时间内这样做。我第一次尝试这样做时,正在创建的泡菜文件是 11.2GB,我认为这太大了。与此同时,它花了一个多小时使我的电脑变得无用。而且代码没有完成执行,我强制重新启动我的电脑,因为它花费的时间太长。
如果有人能帮忙,我将不胜感激。
我建议检查存储到 hdf5 中是否提高了性能:
写入 hdf5:
with h5py.File('model.hdf5', 'w') as f:
for var in tf.trainable_variables():
key = var.name.replace('/', ' ')
value = session.run(var)
f.create_dataset(key, data=value)
从 hdf5 读取:
with h5py.File('model.hdf5', 'r') as f:
for (name, val) in f.items()
name = name.replace(' ', '/')
val = np.array(val)
session.run(param_setters[name][0], { param_setters[name][1]: val })
来源:
https://www.tensorflow.org/tutorials/keras/save_and_restore_models
https://geekyisawesome.blogspot.com/2018/06/savingloading-tensorflow-model-using.html