使用 k.clear_session() 和 tf.reset_default_graph() 在后续模型之间清除图形



随后加载多个模型时,我似乎无法正确清除图形。

k.clear_session()   
tf.reset_default_graph()

只是在加载第一个模型后关闭 Python 中的程序。如果删除上述行,我可以加载后续模型,但随后我遇到了内存泄漏。

>>> import keras
Using TensorFlow backend.
>>> keras.__version__
'2.2.4'
>>> import tensorflow as tf
>>> tf.__version__
'1.8.0'
>>> 


def evaluate_models(models_path_dir):
    models_paths = [os.path.join(models_path_dir, model) for model in os.listdir(models_path_dir) if model.endswith(".hdf5")]
    models_pairs = get_model_key(models_paths, global_model_keys)
    print(len(model_pairs)) #15
    for model_pair in models_pairs:
        model_path,model_key = model_pair
        img_height, img_width = 480, 480
        evaluate_validation_data(model_path, model_key)

def evaluate_validation_data(model_path,model_key):
    preprocess =  model_key
    valid_datagen = ImageDataGenerator(preprocessing_function = preprocess)
    valid_generator = valid_datagen.flow_from_directory(
    validation_data_dir,
    target_size = (img_height, img_width),
    batch_size = 30, 
    class_mode = 'categorical',
    shuffle = False)
    model = load_model(model_path)
    print("model path",model_path)
    print("image size", (img_height, img_width))
    print( model.evaluate_generator(valid_generator))
    k.clear_session()
    tf.reset_default_graph()

我对 k.clear_session() 和 tf.reset_default() 的用法不正确吗?

谢谢。

更新:

我尝试按如下方式更改我的代码,但我仍然遇到同样的问题:

def evaluate_validation_data(model_path,model_key):
        preprocess =  model_key
        valid_datagen = ImageDataGenerator(preprocessing_function = preprocess)
        valid_generator = valid_datagen.flow_from_directory(
        validation_data_dir,
        target_size = (img_height, img_width),
        batch_size = 10, 
        class_mode = 'categorical',
        shuffle = False)
        model = load_model(model_path)
        print("model path",model_path)
        print("image size", (img_height, img_width))
        print( model.evaluate_generator(valid_generator))
        k.clear_session()
        #tf.reset_default_graph()


>>> import keras
Using TensorFlow backend.
>>> keras.__version__
'2.2.4'
>>> import tensorflow as tf
>>> tf.__version__
'1.8.0'
>>> 

以下是程序执行时发生的情况:

39
Found 374 images belonging to 5 classes.
loaded model
model path E:USERTESTmodel.hdf5
image size (480, 480)
[0.5056040882665843, 0.8609625604700915]
Found 374 images belonging to 5 classes.

然后关闭

似乎

Keras 高于 2.2 和 tf 1.8 存在错误?

https://github.com/keras-team/keras/issues/10399

我需要将 Keras 降级到 2.1?

编辑:

刚刚测试过。降级它 2.1 可以处理错误。

松开tf.reset_default_graph(),你应该很好。至于内存泄漏,请确保您运行的是 Keras 2.2.4(最好是 tensorflow>=1.10 具有更好的 keras 集成),我在依次加载多个模型时遇到了类似的 Keras 2.2.2 崩溃问题,并且在我更新到 Keras 2.2.4 后它消失了。

最新更新