今天,在运行model.fit()
时,我突然莫名其妙地出现了这个错误。这以前也用过,我使用的是TF2.3.0,更具体地说是它的Keras模块。该函数是在生成器内部进行验证时调用的,生成器被输入到model.predict()
中。
基本上,我加载一个检查点,继续训练网络,并对验证进行预测。
即使从头开始训练模型并擦除所有相关数据,错误也会不断发生。这就像是在某个地方对某个东西进行了硬编码,因为直到几个小时前我还能够运行model.fit()
。
我看到了几种这样的解决方案,但这些变体都不适合我,因为它们会导致更棘手的错误消息。
我甚至尝试安装一个不同版本的TF,认为这是由于一些旧版本造成的,但错误仍然存在。
我会回答我自己的问题,因为这个问题特别棘手,我在互联网上找到的解决方案都不适合我,可能是因为过时了。
我只写下相关的部分添加到代码中,可以随意添加更多的技术解释。我喜欢使用args
传递变量,但它可以在没有的情况下工作
from tensorflow.python.keras.backend import set_session
from tensorflow.keras.models import load_model
import generator # custom generator
def main(args):
# open new session and define TF graph
args.sess = tf.compat.v1.Session()
args.graph = tf.compat.v1.get_default_graph()
set_session(args.sess)
# define training generator
train_generator = generator(args.train_data)
# load model
args.model = load_model(args.model_path)
args.model.fit(train_generator)
然后,在模型预测函数中:
# In my specific case, the predict_output() function is
# called inside the generator function
def predict_output(args, x):
with args.graph.as_default():
set_session(args.sess)
y = model.predict(x)
return y