Why doesn't the NN loss change?



First, while trying to implement a GAN, I found that the loss in the code below does not change. Here is the code together with its corresponding input:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_data
import os

def sample_z(m, n):
    # Uniform noise in [-1, 1] as generator input
    return np.random.uniform(-1, 1, size=[m, n])

def generator(z, reuse=False):
    with tf.variable_scope("generator", reuse=reuse):
        w_init = tf.contrib.layers.xavier_initializer()
        dense1 = tf.layers.dense(z, 128, activation=tf.nn.relu, kernel_initializer=w_init)
        o = tf.layers.dense(dense1, 784, activation=tf.nn.tanh, kernel_initializer=w_init)
        return o

def discriminator(x, reuse=False):
    with tf.variable_scope("discriminator", reuse=reuse):
        w_init = tf.contrib.layers.xavier_initializer()
        dense1 = tf.layers.dense(x, 128, activation=tf.nn.relu, kernel_initializer=w_init)
        dense2 = tf.layers.dense(dense1, 1, activation=tf.nn.relu, kernel_initializer=w_init)  # ReLU on the logits layer
        o = tf.nn.sigmoid(dense2)
        return o, dense2

def plot(samples):
    ...

z = tf.placeholder(tf.float32, shape=[None, 100], name='z')
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
isTrain = tf.placeholder(dtype=tf.bool)
G_sample = generator(z, isTrain)  # passes the isTrain placeholder as the reuse argument
D_real, D_logit_real = discriminator(x)
D_fake, D_logit_fake = discriminator(G_sample, reuse=True)

# Standard GAN losses, computed on the discriminator logits
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# Optimize each network only over its own variables
T_vars = tf.trainable_variables()
D_vars = [var for var in T_vars if var.name.startswith('discriminator')]
G_vars = [var for var in T_vars if var.name.startswith('generator')]
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_vars)

batch_size = 128
z_dim = 100
mnist = input_data.read_data_sets('mnist/', one_hot=True)
i = 0
if not os.path.exists('output/GAN/'):
    os.makedirs('output/GAN/')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(100000):
        # Alternate one discriminator step and one generator step
        x_batch, _ = mnist.train.next_batch(batch_size)
        _, D_loss_curr = sess.run(
            [D_solver, D_loss],
            feed_dict={x: x_batch, z: sample_z(batch_size, z_dim), isTrain: True}
        )
        _, G_loss_curr = sess.run(
            [G_solver, G_loss],
            feed_dict={z: sample_z(batch_size, z_dim), isTrain: True}
        )
        if it % 1000 == 0:
            samples = sess.run(G_sample, feed_dict={z: sample_z(16, z_dim)})
            fig = plot(samples)
            plt.savefig('output/GAN/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
            plt.close(fig)
            i += 1
            print('Iter: {:4} D_loss: {:.4} G_loss: {:.4}'.format(
                it, D_loss_curr, G_loss_curr))

The output looks like this:

Iter:    0 D_loss: 1.379 G_loss: 0.6931
Iter: 1000 D_loss: 0.6937 G_loss: 0.6931
Iter: 2000 D_loss: 0.6933 G_loss: 0.6931
Iter: 3000 D_loss: 1.386 G_loss: 0.6931
Iter: 4000 D_loss: 1.386 G_loss: 0.6931
Iter: 5000 D_loss: 1.386 G_loss: 0.6931
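
Note that the stuck values are telling: 0.6931 is ln 2 and 1.386 is 2 ln 2, which is exactly what sigmoid cross-entropy gives for a logit of 0. A quick NumPy check of TensorFlow's documented stable formula, max(x, 0) - x*z + log(1 + exp(-|x|)) (this check is my addition, not part of the original post):

import numpy as np

def sigmoid_xent(logit, label):
    # Stable form used by tf.nn.sigmoid_cross_entropy_with_logits
    return max(logit, 0.0) - logit * label + np.log1p(np.exp(-abs(logit)))

print(sigmoid_xent(0.0, 1.0))  # 0.6931... = ln 2, matches the stuck G_loss
print(sigmoid_xent(0.0, 0.0))  # 0.6931...; D_loss = 0.6931 + 0.6931 = 1.386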

In particular, when I changed the generator and discriminator functions to the versions below, the loss started to change.

def generator(z, isTrain, reuse=False):
    with tf.variable_scope("generator", reuse=reuse):
        w_init = tf.contrib.layers.xavier_initializer()
        dense1 = tf.layers.dense(z, 128, kernel_initializer=w_init)
        relu1 = tf.nn.relu(dense1)
        dense2 = tf.layers.dense(relu1, 784, kernel_initializer=w_init)
        o = tf.nn.tanh(dense2)
        return o

def discriminator(x, isTrain, reuse=False):
    with tf.variable_scope("discriminator", reuse=reuse):
        w_init = tf.contrib.layers.xavier_initializer()
        dense1 = tf.layers.dense(x, 128, kernel_initializer=w_init)
        lrelu1 = lrelu(dense1, 0.2)  # leaky-ReLU helper, not defined in the post (see sketch below)
        dense2 = tf.layers.dense(lrelu1, 1, kernel_initializer=w_init)  # note: no activation on the logits layer here
        o = tf.nn.sigmoid(dense2)
        return o, dense2
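
The revised discriminator calls lrelu, which is not defined anywhere in the post. A standard leaky-ReLU helper in this style would look as follows (an assumption on my part, not the author's code):

def lrelu(x, alpha=0.2):
    # Leaky ReLU: identity for positive inputs, slope alpha for negative ones
    return tf.maximum(alpha * x, x)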

Could someone tell me why this happens? Many thanks.

  1. The reuse argument of tf.variable_scope must be True, False, or None, so G_sample = generator(z, isTrain), which passes the isTrain placeholder as reuse, will inevitably raise an error.
  2. Since there are no dropout, batch-normalization, or similar layers in the model, isTrain is not needed.
  3. We reuse the discriminator network on different inputs via tf.variable_scope: when reuse is True, the existing variables with the same names are shared instead of being created again (see the sketch after this list). Cf. Intro_to_GANs_Solution.ipynb
  4. It makes no difference whether the activation is added implicitly in the dense layer or applied explicitly with tf.nn.relu / tf.nn.tanh. What does matter is that the original discriminator puts a ReLU on its final one-unit layer, so the logits fed to sigmoid_cross_entropy_with_logits are clamped to be non-negative and are mostly exactly 0, which pins the losses at ln 2 ≈ 0.6931 and 2 ln 2 ≈ 1.386 and kills the gradients; the revised version leaves dense2 linear, which is why the loss starts to move.
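
A minimal sketch of the reuse mechanism from point 3, assuming TF1-style variable scopes as in the question (the scope and layer names here are illustrative, not from the post):

import tensorflow as tf

def net(x, reuse=False):
    # First call creates d/fc/kernel and d/fc/bias; with reuse=True the
    # same variables are looked up by name instead of being recreated.
    with tf.variable_scope("d", reuse=reuse):
        return tf.layers.dense(x, 1, name="fc")

out_a = net(tf.zeros([3, 4]))
out_b = net(tf.ones([3, 4]), reuse=True)
print([v.name for v in tf.trainable_variables()])  # one set only: d/fc/kernel, d/fc/bias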
