我想重建一个我首先用scikit-learn的MLPRegressor和tflearn实现的MLP。
sklearn.neural_network。MLPR出口器实现:
train_data = pd.read_csv('train_data.csv', delimiter = ';', decimal = ',', header = 0)
test_data = pd.read_csv('test_data.csv', delimiter = ';', decimal = ',', header = 0)
X_train = np.array(train_data.drop(['output'], 1))
X_scaler = StandardScaler()
X_scaler.fit(X_train)
X_train = X_scaler.transform(X_train)
Y_train = np.array(train_data['output'])
clf = MLPRegressor(activation = 'tanh', solver='lbfgs', alpha=0.0001, hidden_layer_sizes=(3))
clf.fit(X_train, Y_train)
prediction = clf.predict(X_train)
该模型有效,我得到了0.85
的准确性。现在我想用tflearn构建一个类似的MLP。我从以下代码开始:
train_data = pd.read_csv('train_data.csv', delimiter = ';', decimal = ',', header = 0)
test_data = pd.read_csv('test_data.csv', delimiter = ';', decimal = ',', header = 0)
X_train = np.array(train_data.drop(['output'], 1))
X_scaler = StandardScaler()
X_scaler.fit(X_train)
X_train = X_scaler.transform(X_train)
Y_train = np.array(train_data['output'])
Y_scaler = StandardScaler()
Y_scaler.fit(Y_train)
Y_train = Y_scaler.transform(Y_train.reshape((-1,1)))
net = tfl.input_data(shape=[None, 6])
net = tfl.fully_connected(net, 3, activation='tanh')
net = tfl.fully_connected(net, 1, activation='sigmoid')
net = tfl.regression(net, optimizer='sgd', loss='mean_square', learning_rate=3.)
clf = tfl.DNN(net)
clf.fit(X_train, Y_train, n_epoch=200, show_metric=True)
prediction = clf.predict(X_train)
在某些时候,我肯定以错误的方式配置了一些东西,因为预测太差了。Y_train的范围在 20
到 88
之间,预测显示数字约为 0.005
。在 tflearn 文档中,我刚刚找到了分类示例。
更新 1:
我意识到回归层默认使用'categorical_crossentropy'
作为用于分类问题的损失函数。所以我选择了'mean_square'
。我也试图使Y_train
正常化.预测甚至仍然与Y_train
的范围不匹配.有什么想法吗?
最后更新:
看看接受的答案。
一步应该是不缩放输出。我也在研究回归问题,我只缩放输入,它在某些神经网络上工作正常。虽然如果我使用 tflearn,我会得到错误的预测。
我犯了几个非常愚蠢的错误。
首先,我将输出调用到间隔0
到1
但在输出层中使用激活函数tanh
,它将值从-1
传递到1
。所以我不得不使用一个激活函数来输出 0
和 1
之间的值(例如 sigmoid
( 或未应用任何缩放比例的linear
。
其次,也是最重要的一点,对于我的数据,我为learning rate
和n_epoch
选择了一个非常糟糕的组合。我没有指定任何学习率,默认的学习率是0.1
,我认为。无论如何,它太小了(我最终使用了3.0
(。同时纪元计数(10
(也太小了,200
工作正常。
我还明确选择了sgd
作为optimizer
(默认:adam
(,结果证明效果更好。
我更新了问题中的代码。