GradientTape 的 loss 函数返回 NaN/None


def ml_1(epochs, lay1, lay2):
    """Train a small binary classifier with a manual GradientTape loop.

    Parameters
    ----------
    epochs : int
        Number of full passes over the 99 data slices.
    lay1, lay2 : int
        Widths of the two hidden Dense layers.

    Notes
    -----
    NaN losses with BinaryCrossentropy usually mean the inputs contain
    NaN/inf values or the labels are not in {0, 1} -- inspect the output
    of get_data() before blaming the loss function.
    """
    # 3-feature input -> two ReLU hidden layers -> single sigmoid unit,
    # matching BinaryCrossentropy's expectation of probabilities in [0, 1].
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(3,)),
        tf.keras.layers.Dense(lay1, activation='relu'),
        tf.keras.layers.Dense(lay2, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
    loss_fn = tf.keras.losses.BinaryCrossentropy()
    for epoch in range(epochs):
        for i in range(1, 100):
            # get_data is defined elsewhere; per the question it returns an
            # (n x 3) DataFrame of digits for X and a boolean Series for y,
            # both converted to tensors via tf.convert_to_tensor() --
            # TODO confirm shapes/dtypes, this is the likely NaN source.
            X_train, X_test, y_train, y_test = get_data(i)

            # Record the forward pass so gradients can be taken w.r.t.
            # the trainable weights.
            with tf.GradientTape() as tape:
                logits = model(X_train, training=True)
                loss_value = loss_fn(y_train, logits)
                print(loss_value)
            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

我有这个应该训练二元分类神经网络模型的函数。每次我用新的 i 调用 get_data 时,它都会返回一组新的 (X_train, X_test, y_train, y_test)。但它不起作用:print(loss_value) 每次都打印 NaN。我做错了什么?我选择的损失函数正确吗?

NaN 损失值很可能是由您提供给模型的数据引起的。下面是一个简单的可运行示例:

import tensorflow as tf
# Minimal working version of the question's training loop, using random
# data in place of get_data() to show the loss comes out as a finite number.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(3,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
epochs = 5
for epoch in range(epochs):
    for i in range(1, 2):
        # Well-formed stand-in data: finite floats for X, {0, 1} labels
        # for y -- the properties the real get_data() output must have.
        X_train = tf.random.normal((100, 3))
        y_train = tf.random.uniform((100, ), maxval=2, dtype=tf.int32)
        with tf.GradientTape() as tape:
            logits = model(X_train, training=True)
            loss_value = loss_fn(y_train, logits)
            print(loss_value)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
tf.Tensor(0.69102806, shape=(), dtype=float32)
tf.Tensor(0.70286894, shape=(), dtype=float32)
tf.Tensor(0.68930304, shape=(), dtype=float32)
tf.Tensor(0.70442116, shape=(), dtype=float32)
tf.Tensor(0.69840324, shape=(), dtype=float32)

因此,可以尝试打印 X_train 和 y_train,检查它们是否具有正确的形状,以及是否包含 NaN 值或超出 {0, 1} 范围的标签。

最新更新