当使用keras bert进行分类时,损失为NaN



我使用keras bert进行分类。在一些数据集上,它运行良好并计算损失,而在其他数据集上损失为NaN

不同的数据集是相似的,因为它们是原始数据集的增强版本。使用keras bert,原始数据和数据的一些增强版本运行良好,而其他增强版本的数据运行不好。

当我在keras bert运行不好的数据的增强版本上使用常规的一层BiLSTM时,它运行良好,这意味着我可以排除数据有错误或包含可能影响损失计算方式的伪值的可能性。使用中的数据有三个类。

我使用的是基于伯特的无上限

!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip

有人能告诉我为什么输球是楠吗?

inputs = model.inputs[:2]
dense = model.layers[-3].output
outputs = keras.layers.Dense(3, activation='sigmoid', kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),name = 'real_output')(dense)
decay_steps, warmup_steps = calc_train_steps(train_y.shape[0], batch_size=BATCH_SIZE,epochs=EPOCHS,)
#(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR)
model = keras.models.Model(inputs, outputs)
model.compile(AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR), loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
sess = tf.compat.v1.keras.backend.get_session()
uninitialized_variables = set([i.decode('ascii') for i in sess.run(tf.compat.v1.report_uninitialized_variables ())])
init_op = tf.compat.v1.variables_initializer([v for v in tf.compat.v1.global_variables() if v.name.split(':')[0] in uninitialized_variables])
sess.run(init_op)
model.fit(train_x,train_y,epochs=EPOCHS,batch_size=BATCH_SIZE)
Train on 20342 samples
Epoch 1/10
20342/20342 [==============================] - 239s 12ms/sample - loss: nan - sparse_categorical_accuracy: 0.5572
Epoch 2/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 3/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2081
Epoch 4/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 5/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 6/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 7/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 8/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2081
Epoch 9/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 10/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
<tensorflow.python.keras.callbacks.History at 0x7f1caf9b0f90>

此外,我正在使用tensorflow 2.3.0keras 2.4.3在Google Colab上运行此程序

UPDATE

我再次查看了导致此问题的数据,发现其中一个目标标签不见了。我可能错误地编辑了它。一旦我修复了它,丢失的是NaN问题就消失了。然而,我会给我得到的答案打50分,因为这让我更好地思考我的代码。谢谢

我注意到您的代码中有一个问题,但我不确定这是否是主要原因;如果你能提供一些可复制的代码,那就更好了。

在上面的代码片段中,您在最后一层激活中使用unit < 1设置了sigmoid,这表明问题数据集可能是多标签,这就是为什么损失函数应该是binary_crossentropy,但您设置了sparse_categorical_crossentropy,这是典型的使用问题和整数标签的情况。

outputs = keras.layers.Dense(3, activation='sigmoid',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model = keras.models.Model(inputs, outputs)
model.compile(AdamWarmup(decay_steps=decay_steps, 
warmup_steps=warmup_steps, lr=LR),
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])

因此,如果您的问题数据集是带有最后一层unit = 3多标签,那么设置应该更像

outputs = keras.layers.Dense(3, activation='sigmoid',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model.compile(AdamWarmup(decay_steps=decay_steps, 
warmup_steps=warmup_steps, lr=LR),
loss='binary_crossentropy',
metrics=['accuracy'])

但是,如果问题集是多类问题,并且目标标签是整数(unit = 3(,则设置应该更像如下所示:

outputs = keras.layers.Dense(3, activation='softmax',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model.compile(AdamWarmup(decay_steps=decay_steps, 
warmup_steps=warmup_steps, lr=LR),
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])

最新更新