如何查找 keras 模型的参数数



对于前馈网络(FFN),计算参数的数量很容易。给定CNN,LSTM等,有没有一种快速的方法可以找到keras模型中的参数数量?

模型和层具有用于此目的的特殊方法:

model.count_params()

此外,要获得每个图层尺寸和参数的简短摘要,您可能会发现以下方法很有用

model.summary()
import keras.backend as K
def size(model): # Compute number of params in a model (the actual number of floats)
    return sum([np.prod(K.get_value(w).shape) for w in model.trainable_weights])

追溯print_summary()函数,Keras 开发人员计算给定model的可训练参数和non_trainable参数的数量,如下所示:

import keras.backend as K
import numpy as np
trainable_count = int(np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

鉴于K.count_params()被定义为np.prod(int_shape(x)),这个解决方案与Anuj Gupta的解决方案非常相似,除了使用set()和张量形状的检索方式。

创建网络后添加:model.summary
它将为您提供网络和参数数量的摘要。

相关内容

  • 没有找到相关文章

最新更新