Tensorflow keras,使用MSE时没有梯度误差



我正在尝试从tensorflow的代码"从头开始编写训练循环";我自己做了一些改变。我将损失函数从SparseCategoricalCrossentropy更改为MeanSquaredError。我还通过添加一个新的Lambda层来计算损失,从而改变了模型的架构。但是,我有一个值错误,没有为变量提供梯度。是否有任何方法,我可以使代码运行与MSE?

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
final_outputs = layers.Lambda(lambda x: tf.math.argmax(x, axis = -1))(outputs)
model = keras.Model(inputs=inputs, outputs=final_outputs)
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.MeanSquaredError()
# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

epochs = 2
for epoch in range(epochs):
print("nStart of epoch %d" % (epoch,))
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits = model(x_batch_train, training=True)
loss_value = loss_fn(y_batch_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))

argmaxops不可导。要使用整数标签和MSE损失,您需要将标签y_trainy_val转换为整数。

y_train = np.argmax(y_train, axis=-1)
y_val = np.argmax(y_val, axis=-1)

并调整输出层为输出整数标签

outputs = layers.Dense(1, name="predictions")(x2)

相关内容

  • 没有找到相关文章

最新更新