Problems with Python tensorflow



我是一个编程专家,曾尝试学习机器学习。我在Python中使用了tensorflow。这是使用官方tensorflow指南编写的代码(但不是100%复制(https://www.tensorflow.org/guide/basics)。我看不到训练后的最终结果。我尝试过两种训练方法,但都有相同的问题。有人能帮我吗?

import matplotlib as mp
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as pl
mp.rcParams["figure.figsize"] = [20, 10]
precision = 500
x = tf.linspace(-10.0, 10.0, precision)

def y(x): return 4 * np.sin(x - 1) + 3
newY = y(x) + tf.random.normal(shape=[precision])
class Model(tf.keras.Model):
def __init__(self, units):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units = units, activation = tf.nn.relu, kernel_initializer=tf.random.normal, bias_initializer=tf.random.normal)
self.dense2 = tf.keras.layers.Dense(1)

def __call__(self, x, training = True):
x = x[:, tf.newaxis]
x = self.dense1(x)
x = self.dense2(x)
return tf.squeeze(x, axis=1)
model = Model(164)
pl.plot(x, y(x), label = "origin")
pl.plot(x, newY, ".", label = "corrupted")
pl.plot(x, model(x), label = "before training")
"""                                                     The first method
vars = model.variables
optimizer = tf.optimizers.SGD(learning_rate = 0.01)
for i in range(1000):
with tf.GradientTape() as tape:
prediction = model(x)
error = (newY-prediction)**2
mean_error = tf.reduce_mean(error)
gradient = tape.gradient(mean_error, vars)
optimizer.apply_gradients(zip(gradient, vars))
"""
model.compile(loss = tf.keras.losses.MSE, optimizer = tf.optimizers.SGD(learning_rate = 0.01))
model.fit(x, newY, epochs=100,batch_size=32,verbose=0)
pl.plot(x, model(x), label = "after training")
pl.legend()
pl.show()

我复制了你的代码并对其进行了研究。你的模型在训练过程中返回了NaN损失,我删除了内核和偏置初始化器,它正常工作。现在我不知道你的初始化出了什么问题。似乎有些权重是用NaN初始化的,然后使预测变成NaN,因此无法绘制它们。

更新:使用初始化器模块(如tensorflow.initializerstensorflow.keras.initializers,而不是tensorflow.random(。例如,使用kernel_initializer=tf.initializers.random_normal而不是现有的。

正如我所看到的,您的第三个图和第四个图是相同的。它们是pl.plot(x, model(x), label = "before training")pl.plot(x, model(x), label = "after training")您可以计算出两个图的x轴和y轴数据是相同的。

希望我的回答对你有帮助!

相关内容

  • 没有找到相关文章

最新更新