我正在尝试制作一个类似GAN的模型。但我不知道如何仅为一个模型正确地将trainable设置为False。似乎所有使用子模型的模型都受到了影响。
代码:
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense
print(tf.__version__)  # record the TF version — the trainable/compile behavior below is version-sensitive
def build_submodel():
    """Build the shared functional submodel: input (None, 3) -> Dense(5).

    Returns:
        A `tf.keras.Model` instance. The single instance returned here is
        reused inside both model A and model B, so its `trainable` flag is
        shared state between them.
    """
    inp = tf.keras.Input(shape=(3,))
    x = Dense(5)(inp)
    model = Model(inputs=inp, outputs=x)
    return model
def build_model_A():
    """Build model A: input (None, 3) -> shared submodel -> Dense(7).

    NOTE(review): this reads the module-level ``submodel`` (late binding),
    so ``submodel = build_submodel()`` must run before this is called.

    Returns:
        A `tf.keras.Model` instance wrapping the shared submodel.
    """
    inp = tf.keras.Input(shape=(3,))
    x = submodel(inp)
    x = Dense(7)(x)
    model = Model(inputs=inp, outputs=x)
    return model
def build_model_B():
    """Build model B: input (None, 11) -> Dense(3) -> shared submodel.

    NOTE(review): like ``build_model_A``, this reads the module-level
    ``submodel`` — the same object model A wraps, so freezing it affects
    both models.

    Returns:
        A `tf.keras.Model` instance ending in the shared submodel.
    """
    inp = tf.keras.Input(shape=(11,))
    x = Dense(3)(inp)
    x = submodel(x)
    model = Model(inputs=inp, outputs=x)
    return model
# Single shared submodel instance — model_A and model_B both wrap this same
# object, so any later change to its `trainable` flag is visible through
# every model that contains it.
submodel = build_submodel()
model_A = build_model_A()
model_A.compile("adam", "mse")
model_A.summary()
# Freeze the shared submodel AFTER model_A was compiled.
# NOTE(review): `summary()` reports the flag's *current* value, while
# `compile()` snapshots the set of trainable weights — this is why
# model_A's second summary (below) shows non-trainable params even though
# model_A was compiled before the freeze. Recompile a model after changing
# `trainable` if you want the change to apply to its training.
submodel.trainable = False
# same result with freezing layers individually
# (fixed: the original comment said `True`, contradicting "freezing"):
# for layer in submodel.layers:
#     layer.trainable = False
model_B = build_model_B()
model_B.compile("adam", "mse")
model_B.summary()
model_A.summary()
输出:
Model: "model_10"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_11 (InputLayer) [(None, 3)] 0
_________________________________________________________________
model_9 (Model) (None, 5) 20
_________________________________________________________________
dense_10 (Dense) (None, 7) 42
=================================================================
Total params: 62
Trainable params: 62
Non-trainable params: 0
_________________________________________________________________
Model: "model_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_12 (InputLayer) [(None, 11)] 0
_________________________________________________________________
dense_11 (Dense) (None, 3) 36
_________________________________________________________________
model_9 (Model) (None, 5) 20
=================================================================
Total params: 56
Trainable params: 36
Non-trainable params: 20
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_11 (InputLayer) [(None, 3)] 0
_________________________________________________________________
model_9 (Model) (None, 5) 20
_________________________________________________________________
dense_10 (Dense) (None, 7) 42
=================================================================
Total params: 62
Trainable params: 42
Non-trainable params: 20
_________________________________________________________________
起初,模型A没有不可训练的权重;但在构建模型B之后,模型A却出现了一些不可训练的权重。
此外,该摘要没有显示哪些层是不可训练的,只是显示不可训练参数的总数。有没有更好的方法来检查模型中哪些层被冻结?
您可以使用此函数来显示哪些层是可训练的或不是
def print_params(model):
    """Print each layer's `trainable` flag plus the parameter totals.

    Unlike ``model.summary()``, which only reports the aggregate
    non-trainable parameter count, this shows *which* layers are frozen.

    Args:
        model: A `tf.keras.Model` whose layers/weights will be inspected.
    """
    # Local import: the snippet as posted used `K` without ever importing
    # it; bind the Keras backend here so the helper is self-contained.
    from tensorflow.keras import backend as K

    def count_params(weights):
        """Count the total number of scalars composing the weights.

        Deduplicates weight objects by `id()` so shared weights are
        counted once.

        # Arguments
            weights: An iterable containing the weights on which to compute params
        # Returns
            The total number of scalars composing the weights
        """
        weight_ids = set()
        total = 0
        for w in weights:
            if id(w) not in weight_ids:
                weight_ids.add(id(w))
                total += int(K.count_params(w))
        return total

    trainable_count = count_params(model.trainable_weights)
    non_trainable_count = count_params(model.non_trainable_weights)
    # '\t' escapes were mangled to bare 't' in the original paste; restored
    # to match the sample output ("id  trainable : layer name").
    print('id\ttrainable : layer name')
    print('-------------------------------')
    for i, layer in enumerate(model.layers):
        print(i, '\t', layer.trainable, '\t :', layer.name)
    print('-------------------------------')
    print('Total params: {:,}'.format(trainable_count + non_trainable_count))
    print('Trainable params: {:,}'.format(trainable_count))
    print('Non-trainable params: {:,}'.format(non_trainable_count))
它会产生类似下面这样的输出:
id trainable : layer name
-------------------------------
0 False : input_1
1 False : block1_conv1
2 False : block1_conv2
3 False : block1_pool
4 False : block2_conv1
5 False : block2_conv2
6 False : block2_pool
7 False : block3_conv1
8 False : block3_conv2
9 False : block3_conv3
10 False : block3_pool
11 False : block4_conv1
12 False : block4_conv2
13 False : block4_conv3
14 False : block4_pool
15 False : block5_conv1
16 False : block5_conv2
17 False : block5_conv3
18 False : block5_pool
19 True : global_average_pooling2d
20 True : dense
21 True : dense_1
22 True : dense_2
-------------------------------
Total params: 15,245,130
Trainable params: 530,442
Non-trainable params: 14,714,688