我有一个使用tf.keras编写的自动编码器,它处理2D图像。为了训练自动编码器,我使用了一个自定义的损失函数。为了改进损失函数,我想添加两个与训练样本相关的参数。然而,这些数据对于每个样本都是不同的。因此,我的数据如下:
- 图像_1,(a_1,b_1(
- 图像_2,(a_2,b_2(
- 图像_n,(a_n,b_n(
如何将这些参数传递给自定义损失函数有诀窍吗?我试图使用两个输入和一个输出,然而,我不知道如何引用图像和参数。
提前谢谢。
如果您的数据集由样本组成:Image_1, (a_1, b_1)
。。。等等,您可以使用自定义训练循环,您将拥有所需的所有灵活性。这里有一个随机自定义损失函数和数据集的例子,因为我不知道你的项目的细节:
import tensorflow as tf
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
image_size=(28, 28),
batch_size=32)
normalization_layer = tf.keras.layers.Rescaling(1./255)
def change_inputs(images, _):
x = tf.image.resize(normalization_layer(images),[28, 28], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return x, x
def custom_loss(x, x_hat, a, b):
return tf.reduce_mean(tf.math.squared_difference(x, x_hat)) * tf.reduce_mean((a-b))
a = tf.random.normal((3670,))
b = tf.random.normal((3670,))
extra_ds = tf.data.Dataset.from_tensor_slices((a, b)).batch(32)
train_ds = train_ds.map(change_inputs)
train_dataset = tf.data.Dataset.zip((train_ds, extra_ds))
input_img = tf.keras.Input(shape=(28, 28, 3))
x = tf.keras.layers.Flatten()(input_img)
x = tf.keras.layers.Dense(28 * 28 * 3, activation='relu')(x)
output = tf.keras.layers.Reshape(target_shape=(28, 28 ,3))(x)
autoencoder = tf.keras.Model(input_img, output)
optimizer = tf.keras.optimizers.Adam()
epochs = 2
for epoch in range(epochs):
print("nStart of epoch %d" % (epoch,))
for step, x_batch_train in enumerate(train_dataset):
x, _ = x_batch_train[0]
a, b = x_batch_train[1]
with tf.GradientTape() as tape:
x_hat = autoencoder(x, training=True)
loss_value = custom_loss(x, x_hat, a, b)
grads = tape.gradient(loss_value, autoencoder.trainable_weights)
optimizer.apply_gradients(zip(grads, autoencoder.trainable_weights))
# Log every 200 batches.
if step % 200 == 0:
print("Training loss (for one batch) at step %d: %.4f"% (step, float(loss_value)))
print(loss_value.numpy())
print("Seen so far: %s samples" % ((step + 1) * batch_size))