嗨,我做了一个项目,在自动微分中使用tensorflow。使用类似numpy的相当线性的数据集:
true_w, true_b = 7., 4.
def create_batch(batch_size=64):
x = np.random.randn(batch_size, 1)
y = np.random.randn(batch_size, 1) + true_w * x+true_b
return x, y
当我尝试用kaggle的任何其他"真实"数据集重复自动微分时,权重和偏差偏离了sklearn或numpy线性回归函数的截距和系数。甚至使用高度相关的特征。以下是使用来自Kaggles世界幸福指数2022的威士忌高威士忌低数据集。尝试过其他,但这两者之间的相关性非常高,我认为这将是最好的尝试。
X = np.array(df['Whisker-high']).reshape(-1,1)
y = np.array(df['Whisker-low'])
reg = LinearRegression(fit_intercept=True).fit(X,y)
intercept = np.round(reg.intercept_,4)
coef = np.round(reg.coef_[0],4)
iters = 100
lr = .01
w_history = []
b_history = []
true_w = coef
true_b = intercept
w = tf.Variable( 0.65)
b = tf.Variable(1.5)
for i in range(0, iters):
inds = np.random.choice(np.arange(0, len(df)), size=100, replace=True)
X = np.array(df.iloc[list(inds)]['Whisker-high']).reshape(-1,1)
y = np.array(df.iloc[list(inds)]['Whisker-low'])
x_batch = tf.convert_to_tensor(X, dtype=tf.float32)
y_batch = tf.convert_to_tensor(y, dtype=tf.float32)
with tf.GradientTape(persistent=True) as tape:
y = b + w *x_batch
loss = tf.reduce_mean(tf.square( y - y_batch))
dw = tape.gradient(loss, w)
db = tape.gradient(loss, b)
del tape
w.assign_sub(lr*dw)
b.assign_sub(lr*db)
w_history.append(w.numpy())
b_history.append(b.numpy())
if i %10==0:
print('iter{}, w={}, b={}'.format(i, w.numpy(), b.numpy()))
plt.plot(range(iters), w_history, label ='learned w')
plt.plot(range(iters), b_history, label ='learned b')
plt.plot(range(iters),[true_w] *iters, label='true w')
plt.plot(range(iters),[true_b] *iters, label='true b')
plt.legend()
plt.show()
尽管使用自动微分,权重和偏差似乎确实达到了最小值,但对数据进行简单的折线图显示,如果说它代表了数据集,那就太好了。
plt.figure(figsize=(6,6))
plt.scatter(df['speeding'], df['alcohol'])
xseq = np.linspace(0, 9, num=df.shape[0])
plt.plot(xseq, b_history[-1] + w_history[-1]*xseq, color='green')
plt.xlabel('speeding', fontsize=16)
plt.ylabel('alcohol', fontsize=16)
plt.show()
我找到了答案。这是一个数据类型问题:
X = np.array(df.iloc[list(inds)]['Whisker-high']).reshape(-1,1)
在这里使用numpy是导致问题的原因。我只需要坚持张量,不要在这里切换到数组。与它不能与GradientTape
一起工作有关。