我是Keras的新手。当我完成鸢尾分类教程时,我只是对此感到困惑,因为我们编码了这3种鸢尾花,例如,one-hot编码。我们应该得到3个正交向量,对吧?
setosa [1 0 0]
versicolor [0 1 0]
virginica [0 0 1]
my model is same as tutorial:
http://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/
和我的问题是,虽然我得到了结果:
Baseline: 95.33% (4.27%)
但是当我调用训练好的深度网络模型时:
prediction = baseline_model().predict(X)
其中X是训练网络时的原始输入
我得到了一个非常奇怪的预测,这样:
print prediction
0,0,0
0,0,0
0,0,0
0,0,0
都是零向量,我应该得到一些单热编码的结果,对吧?来确定花应该属于哪个类。
那么我如何利用我训练过的Keras模型,而我输入相同的输入X并获得分类结果来绘制图??
你需要训练你的网络,只有这样你才能使用它进行预测。使用教程中的符号,您可以这样做:
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0)
X_train, X_test, Y_train, Y_test = train_test_split(X, dummy_y, test_size=0.33, random_state=seed)
estimator.fit(X_train, Y_train)
predictions = estimator.predict(X)