DeepNetts 1.3: NotSerializableException in setEarlyStopping and writeToFile for any network using the ADAM optimizer



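Here is my setup. I am trying to check for overfitting every 5 epochs while training with the ADAM optimizer, and then save the trained network so I can use it later. It seems that whatever settings I use, training with ADAM itself works, but it fails as soon as the network has to be serialized. I tried momentum backpropagation and that works, but I want to use ADAM.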

System.out.println("LOAD TRAINING DATA.");
DataSet<MLDataItem> trainingSet = new TabularDataSet(d.inputs[0].length, d.outputs[0].length);
trainingSet.setColumnNames(d.headers);
for (int i = 0; i < d.inputs.length; i++) {
    trainingSet.add(new TabularDataSet.Item(d.inputs[i], d.outputs[i]));
}
System.out.println("NORMALIZING TRAINING DATA.");
DataSets.normalizeMax(trainingSet);
//trainingSet.getColumnNames();
System.out.println("CREATING NETWORK.");
neuralNet = FeedForwardNetwork.builder()
        .addInputLayer(d.inputs[0].length)
        .addFullyConnectedLayer(d.inputs[0].length, ActivationType.SIGMOID)
        .addFullyConnectedLayer((int) d.inputs[0].length / 2, ActivationType.SIGMOID)
        .addFullyConnectedLayer((int) d.inputs[0].length / 4, ActivationType.SIGMOID)
        .addOutputLayer(d.outputs[0].length, ActivationType.SIGMOID)
        .lossFunction(LossType.MEAN_SQUARED_ERROR)
        .randomSeed(123)
        .build();
System.out.println("TRAINING CONFIGURATIONS.");
neuralNet.setLabel("TRAINING DATA");
BackpropagationTrainer trainer = neuralNet.getTrainer();
trainer.setBatchMode(false);
trainer.setEarlyStopping(true);
trainer.setEarlyStoppingMinLossChange(0.00000001F);
trainer.setEarlyStoppingPatience(5);
trainer.setLearningRate(0.001F);
trainer.setMaxEpochs(100);
trainer.setMaxError(0.0001F);
trainer.setMomentum(0F);
trainer.setTrainingSnapshots(true);
trainer.setOptimizer(OptimizerType.ADAM);
System.out.println("TRAINING...");
neuralNet.train(trainingSet);
neuralNetFile = "neuralNetwork_" + timeStamp + ".dnet";
System.out.println("SAVING NETWORK INTO " + neuralNetFile);
FileIO.writeToFile(neuralNet, neuralNetFile);
System.out.println("DONE!!!");
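A note on setEarlyStoppingPatience(5): in the usual formulation of early stopping, patience is not a check interval ("every 5 epochs") but the number of consecutive epochs without sufficient improvement after which training stops. The library-independent sketch below illustrates that general technique; it is an assumption about how such a setting typically behaves, not DeepNetts' actual code, and the class and method names are hypothetical.

```java
// Hypothetical, library-independent sketch of patience-based early stopping.
// This is NOT DeepNetts' implementation; names are illustrative only.
class EarlyStoppingSketch {

    /**
     * Returns the 1-based epoch at which training would stop, given per-epoch losses.
     * Training stops once `patience` consecutive epochs pass without the loss
     * improving on the best value seen so far by more than `minDelta`.
     * Returns losses.length if early stopping never triggers.
     */
    static int stopEpoch(float[] losses, int patience, float minDelta) {
        float best = Float.POSITIVE_INFINITY;
        int epochsWithoutImprovement = 0;
        for (int epoch = 0; epoch < losses.length; epoch++) {
            if (best - losses[epoch] > minDelta) {
                // Improvement: remember the new best loss. A real trainer would
                // typically also keep a snapshot of the network's weights here.
                best = losses[epoch];
                epochsWithoutImprovement = 0;
            } else if (++epochsWithoutImprovement >= patience) {
                return epoch + 1; // patience exhausted
            }
        }
        return losses.length;
    }

    public static void main(String[] args) {
        float[] losses = {0.50f, 0.40f, 0.35f, 0.36f, 0.37f, 0.35f, 0.38f, 0.39f, 0.30f};
        System.out.println(stopEpoch(losses, 5, 1e-8f)); // prints 8
    }
}
```

With patience 5, epochs 4 to 8 bring no improvement over 0.35, so training stops at epoch 8 even though epoch 9 would have improved again.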

NotSerializableException during early stopping:

at DeepNettsLearning.train(DeepNettsLearning.java:100) [classes/:?]
at DeepNettsLearning.main(DeepNettsLearning.java:36) [classes/:?]
Epoch:20, Time:2ms, TrainError:0.053135615, TrainErrorChange:-0.0012555942, TrainAccuracy: 0.9285714
Epoch:21, Time:4ms, TrainError:0.0519301, TrainErrorChange:-0.0012055151, TrainAccuracy: 0.9285714
Epoch:22, Time:4ms, TrainError:0.050770074, TrainErrorChange:-0.0011600256, TrainAccuracy: 0.9285714
Epoch:23, Time:3ms, TrainError:0.049651828, TrainErrorChange:-0.0011182465, TrainAccuracy: 0.9285714
Epoch:24, Time:2ms, TrainError:0.048572194, TrainErrorChange:-0.0010796338, TrainAccuracy: 0.9285714
Catching
java.io.NotSerializableException: deepnetts.net.train.opt.AdamOptimizer
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1193) ~[?:?]
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579) ~[?:?]
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1536) ~[?:?]
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1444) ~[?:?]
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187) ~[?:?]
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579) ~[?:?]
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1536) ~[?:?]
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1444) ~[?:?]
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187) ~[?:?]
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579) ~[?:?]
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1536) ~[?:?]
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1444) ~[?:?]
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187) ~[?:?]
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579) ~[?:?]
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1536) ~[?:?]

And when saving to a file:

Total Training Time: 595ms
------------------------------------------------------------------------
SAVING NETWORK INTO neuralNetwork_2021.03.19.17.31.37.dnet
Exception in thread "main" java.io.NotSerializableException: deepnetts.net.train.opt.AdamOptimizer
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1193)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1536)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1444)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1579)

Can anyone help me? Thanks in advance.
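Both traces are thrown from java.io.ObjectOutputStream, which suggests that the early-stopping code snapshots the network with standard Java serialization (as FileIO.writeToFile evidently does too), and that the network's object graph holds a reference to deepnetts.net.train.opt.AdamOptimizer, a class that does not implement java.io.Serializable. That is an inference from the stack traces, not something confirmed from DeepNetts' source. The minimal sketch below reproduces the same exception with the JDK alone; Layer and Optimizer are hypothetical stand-ins, not DeepNetts classes.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Minimal reproduction with hypothetical classes (not DeepNetts code):
// a Serializable object that references a non-Serializable one.
class NotSerializableRepro {

    // Stand-in for an optimizer that does NOT implement Serializable.
    static class Optimizer {
        float[] firstMoment = new float[4];
    }

    // Stand-in for a layer: Serializable itself, but it holds an Optimizer.
    static class Layer implements Serializable {
        private static final long serialVersionUID = 1L;
        float[] weights = new float[4];
        Optimizer optimizer = new Optimizer();
    }

    /** Serializes obj in memory; returns null on success, else the offending class name. */
    static String trySerialize(Object obj) {
        try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
            out.writeObject(obj);
            return null;
        } catch (NotSerializableException e) {
            return e.getMessage(); // ObjectOutputStream puts the class name in the message
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(trySerialize(new Layer())); // prints NotSerializableRepro$Optimizer
    }
}
```

Calling trySerialize on the network before writing it to disk is a cheap way to see which class breaks serialization, instead of discovering it after a full training run.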

It looks like there is a bug in DeepNetts' code. I checked their GitHub and there is no Adam optimizer there, so I assume you are using the Pro version.
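If it is a library bug, the fix has to happen inside DeepNetts: either the optimizer class implements Serializable, or the field that references it is declared transient and the optimizer is re-created when training resumes. The sketch below shows the transient variant with hypothetical Layer and Optimizer classes (not DeepNetts' own); it is an illustration of the Java mechanism, not a patch you can apply to the library from the outside.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

// Hypothetical sketch: excluding a non-Serializable optimizer via `transient`.
class TransientOptimizerSketch {

    static class Optimizer { // still not Serializable
        float learningRate = 0.001f;
    }

    static class Layer implements Serializable {
        private static final long serialVersionUID = 1L;
        float[] weights = {0.1f, -0.2f, 0.3f};
        // transient: skipped by Java serialization, so it is null after deserialization
        transient Optimizer optimizer = new Optimizer();
    }

    /** Round-trips a Layer through Java serialization in memory. */
    static Layer roundTrip(Layer layer) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(layer);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (Layer) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("round trip failed", e);
        }
    }

    public static void main(String[] args) {
        Layer copy = roundTrip(new Layer());
        System.out.println(Arrays.toString(copy.weights)); // prints [0.1, -0.2, 0.3]
        System.out.println(copy.optimizer);                // prints null
    }
}
```

The trade-off is that Adam's internal state is not saved, so a reloaded network would start a fresh optimizer if training continued; for saving a finished model that is usually acceptable.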

Could you contact their support about this issue?
