如何在Tensorflow.js中训练卷积生成对抗网络



我在https://www.tensorflow.org/tutorials/generative/dcgan.
虽然本教程是用python编写的,但我正试图在node.js上使用tensorflow.js来实现它。
我已经学会了如何翻译大多数使用的方法,除了实际设置以下训练步骤。

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

显然,这里并不是所有的东西都可以翻译成tensorflow.js。
到目前为止,我还不知道如何获得梯度并将其应用于优化器
我尝试过使用tf.grad&tf.grads起作用,但无效
到目前为止,我拥有的是:

function trainStep(images) {
const noise = tf.randomNormal([BATCH_SIZE, noiseDim]);
const generated = gen.apply(noise, { training: true });
const realOut = dis.apply(images, { training: true });
const genOut = dis.apply(generated, { training: true });
const genLoss = generator.loss(genOut);
const disLoss = discriminator.loss(realOut, genOut);
// now what?
}

在tensorflow.js中有比指南显示的更好的方法吗
如果有人能为我指明正确的方向,我将不胜感激。

试试TensorFlow.js:的官方代码实验室

https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html

这是针对MNIST的,但一旦您了解了这一点,就可以应用于您自己的数据集。

最新更新