错误的是,在我的代码中,错误随着梯度下降的每次迭代而不断增加



下面的代码读取csv(Andrew NG ML课程ex1多元线性回归练习数据文件),然后尝试使用学习率alpha=0.01将线性模型拟合到数据集。梯度下降是对参数(θ矢量)进行400次递减(问题陈述中给出了α和num_of_editions值)。我尝试了一种矢量化实现来获得参数的最佳值,但下降并不收敛——误差不断增加。

# Imports

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
```
# Model Preparation
## Gradient descent

```python
def gradient_descent(m, theta, alpha, num_of_iterations, X, Y):
#     print(m, theta, alpha, num_of_iterations)
for i in range(num_of_iterations):
htheta_vector = np.dot(X,theta)
#         print(X.shape, theta.shape, htheta_vector.shape)
error_vector = htheta_vector - Y
gradient_vector = (1/m) * (np.dot(X.T, error_vector)) # each element in gradient_vector corresponds to each theta
theta = theta - alpha * gradient_vector
return theta
```
# Main

```python
def main():
df = pd.read_csv('data2.csv', header = None) #loading data
data = df.values # converting dataframe to numpy array
X = data[:, 0:2]
#     print(X.shape)
Y = data[:, -1]
m = (X.shape)[0] # number of training examples
Y = Y.reshape(m, 1)
ones = np.ones(shape = (m,1))
X_with_bias = np.concatenate([ones, X], axis = 1)
theta = np.zeros(shape = (3,1)) # two features, so three parameters
alpha = 0.001
num_of_iterations = 400
theta = gradient_descent(m, theta, alpha, num_of_iterations, X_with_bias, Y) # calling gradient descent
#     print('Parameters learned: ' + str(theta))
if __name__ == '__main__':
main()
```

错误:

/home/krish-thorcode/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:8: RuntimeWarning: invalid value encountered in subtract

不同迭代的错误值:

迭代1[-399900.][-32900.][-369000.][-232000.][-539900.][-29900.][-31490.][-198999.][-212000.][-242500.][-239999.][-34700.][-32999.][699900.][-25900.][449900.][-299900.][-199900.][499998.][599000.][252900.][-255000.][-242900.][-25900.][-573900.][-24900.][-46450.][-46900.][-475000.][-29900.][-34900.][-16900.][-31490.][-579900.][-285900.][-24900.][-22900.][-345000.][549000.][-28700.][-368500.][-32900.][-314000.][-299000.][-17900.][-29900.][-239500.]]

迭代2[[1.60749981e+09][1.22240841e+09][1.83373661e+09][1.08189071e+09][2.29209231e+09][1.51666004e+09][11.17198560e+09][1.0303113e+09][1.05440030e+09][11.1448964e+09]+1.48233053e+09][1.52807496e+09][1.4402895e+09][3.42143452e+09][9.68760976e+08][1.75723592e+09][1.00845873e+09][9.44366284e+08][1.99332644e+09][2.31572369e+09][1.35010833e+09]+1.44257442e+09][12.25555224e+09][1.49912323e+09][2.97220331e+09][8.40383843e+08][111375611e+09][1.92992696e+09][1.68078878e+09][2.01492327e+09]+1.40503327e+09][7.64040689e+08][1.55867654e+09][2.39674784e+09][1.38370165e+09][1.09792232e+09][9.46628911e+08][1.62895368e+09][3.22059730e+09][1.65193796e+09][1.27127807e+09]+1.70997383e+09][19.96141565e+09][9.16755655e+08][6.50928858e+08][1.41502023e+09][9.19107783e+08]]

迭代3[[-7.42664624e+12][-5.64764378e+12][-8.47145714e+12][-4.99816153e+12][-1.05893224e+13][-7.00660901e+12][-5.41467917e+12][-5.03699402e+12][-4.87109500e+12][-5.27348843e+12][-6.84776945e+12][-7.05955046e+12][-6.67127611e+12][-1.58063228e+13][-4.47576119e+12][-8.11848565e+12][-4.65930400e+12][-4.36280860e+12][-9.20918360e+12][-1.06987452e+13][-6.23711474e+12][-6.66421140e+12][-5.66176276e+12][-6.92542434e+12][-1.37308096e+13][-3.88276038e+12][-5.14641706e+12][-8.91620784e+12][-7.76550392e[12][-9.30801176e+12][-6.49125293e+12][-3.52977344e+12][-7.20074619e+12][-1.10728954e+13][-6.39242960e+12][-5.07229174e+12][-4.37339793e+12][-7.52548475e+12][-1.48779889e+13][-7.63137769e+12][-5.87354379e+12][-7.89963490e+12][-9.06093321e+12][-4.23573710e+12]]-3.00737309e+12][-6.53715005e+12][-4.24632634e+12]]

迭代4[[3.43099835e+16][2.60912608e+16][3.91368523e+16][2.30907512e+16][4.89210695e+16][3.23694753e+16][2.50149995e+16][2.32701516e+16][2.25037231e+16][2.43627199e+16][3.16356608e+16][3.26140566e+16][3.08202877e+16][7.30228235e+16][2.06773403e+16][3.75061770e+16][2.15252802e+16][2.01555166e+16][4.25540367e+16][4.94265862e+16][2.88145280e+16][3.07876502e+16]/2.61564888e+16][3.199944145e+16][6.34342666e+16][1.79377661e+16][2.37756683e+16][4.11915330e+16][3.58754545e+16][4.3001608e+16][2.99886077e+16][1.63070200e+16][3.32663597e+16][5.11551035e+16][2.95320591e+16][2.34332215e+16][2.02044376e+16][3.47666027e+16][6.87340617e+16][3.52558124e+16][2.71348846e+16][3.64951201e+16][4.18601431e+16][1.95684650e+16][1.38936092e+16][3.02006457e+16][1.96173860e+16]]

迭代5[-1.58506940e+20][-1.20537683e+20][-180806345e+20][-1.06675782e+20][-2.26007951e+20]]-1.49542086e+20][-1.15565519e+20][-1.07504585e+20][-103963801e+20][-1.12552086e+20][-1.46151974e+20][-1.50672014e+20][-1.42385073e+20][-3.37354413e+20][-9.55261885e+19][-1.73272871e+20][-9.94435428e+19][-9.31154420e+19][-1.96551642e+20][-2.28343362e+20][-133118767e+20][-1.42234293e+20][-1.20839027e+20]]-1.47809362e+20][-2.93056729e+20][-8.28697695e+19][-1.09839999e+20][-1.90298660e+20][-1.65739180e+20][-1.98660937e+20][-1.38542837e+20][-7.53359691e+19][-1.53685556e+20][-2.36328850e+20][-136433652e+20][-1.08257943e+20][-9.33414495e+19][-1.60616452e+20][-3.17540981e+20][-1.62876527e+20][-1.25359067e+20][-1.68601941e+20][-1.93387537e+20][-9.04033523e+19][-6.41863754e+19][-1.39522421e+20][-9.06293597e+19]

迭代83[[-1.09904300e+306][-8.35774743e+305][-1.25366087e+306][-7.39660179e+305][-1.56707622e+306][-1.03688320e+306][-8.01299137e+305][-7.45406868e+305][-7.20856058e+305][-7.80404831e+305][-1.01337710e+306][-1.04471781e+306][-9.87258464e+305][-2.33912159e+306][-6.62352000e+305][-1.20142586e+306][-6.89513844e+305][-6.45636555e+305][-1.36283437e+306][-1.58326931e+306][-9.23008472e+305][-9.86212994e+305][-8.37864174e+305][-1.02486897e+306][-2.03197378e+306][-5.74595914e+305][-7.61599955e+305][-1.31947793e+306][-1.144918934e+306][-1.37745963e+306][-9.60617469e+305][-5.22358639e+305][-1.06561287e+306][-1.63863846e+306][-9.45992963e+305][-7.50630445e+305][-6.47203628e+305][-1.11366977e+306][-2.20174077e+306][-1.12934050e+306][-8.69204879e+305][-1.16903893e+306][-1.34089535e+306][-6.26831680e+305][-4.45050460e+305][-9.67409627e+305][-6.283987753e+305]]

迭代84[inf][inf][inf][info][inf][inf][inf][interf][inf]
[inf][inf][if][inf][inf][info][inf][inf][inf][inf][inf][inf][inf]]

请尝试特性规范化来解决此问题。只是特征值是大数字,并且当值大时,成本函数(平方误差)以快速的速度增加。通常,当您试图最小化非线性成本函数时,请执行平均归一化和特征缩放。

进行特征规范化。这是你的数据集,X的第一个维度是千,第二个维度是十,Y是十万。使用sklearn.preprocessing.scale使所有列的数据和目标为[0,1],或者您可以使用脏规范化:

X[:,0] = X[:,0] / np.max( X[:,0])
X[:,1] = X[:,1] / np.max( X[:,1])
Y = Y / np.max(Y)

我用这些规范器重新编译了你的代码。Theta收敛于[ 0.81705857], [ 0.98398577], [ 0.98398577]

尝试提供数据文件或Panda数据框架摘要的链接,以供将来提问。

最新更新