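I'm trying to generate confidence intervals for my Keras model's predictions, following the logic outlined here: How to calculate prediction uncertainty using Keras?, and here: https://medium.com/hal24k-techblog/how-to-generate-neural-network-confidence-intervals-with-keras-e4c0b78ebbdf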
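Judging from the code below, the approach being followed is Monte Carlo dropout: run the same inputs through the network many times with dropout left on, then read an interval off the spread of the outputs. The last step can be sketched in numpy as follows. The function name `mc_interval` and the default 2.5/97.5 percentile bounds (a 95% band) are illustrative choices of mine, not taken from either link.

```python
import numpy as np


def mc_interval(samples, lower=2.5, upper=97.5):
    """Summarise Monte Carlo dropout samples.

    samples: array of shape (num_samples, num_iter), one column per
    stochastic forward pass. Returns the per-sample mean and the lower
    and upper empirical percentiles (a 95% band by default).
    """
    samples = np.asarray(samples, dtype=float)
    mean = samples.mean(axis=1)
    lo = np.percentile(samples, lower, axis=1)
    hi = np.percentile(samples, upper, axis=1)
    return mean, lo, hi


# Toy numbers standing in for real model outputs: 5 inputs x 20 passes.
rng = np.random.default_rng(0)
fake_preds = rng.normal(loc=1.0, scale=0.1, size=(5, 20))
mean, lo, hi = mc_interval(fake_preds)
print(mean.shape, lo.shape, hi.shape)
```

Empirical percentiles are used here because they make no assumption about the shape of the predictive distribution. A mean plus or minus a multiple of the standard deviation is a common alternative if you are willing to assume it is roughly Gaussian.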
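I've searched the Keras GitHub issues for similar problems. They suggested that I was using the wrong method to instantiate a new model from a predefined config, but even after changing my code from Model.from_config()
to tf.keras.Sequential.from_config()
the problem persists.
Here is the code: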
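For reference, here is a sketch of rebuilding a `Sequential` model from an edited config and copying the trained weights over in TF 2.x. `clone_with_dropout_rate` and the toy `base` model are hypothetical. The dummy forward pass is a safeguard: depending on the Keras version and on whether the first layer declares its input shape, the rebuilt model may not have created its weights yet, and `set_weights` needs them to exist.

```python
import numpy as np
import tensorflow as tf


def clone_with_dropout_rate(model, rate, n_features):
    """Rebuild a Sequential model with every Dropout layer set to `rate`,
    then copy the original model's weights into the new one."""
    conf = model.get_config()  # a fresh dict; editing it leaves `model` alone
    for layer in conf["layers"]:
        if layer["class_name"] == "Dropout":
            layer["config"]["rate"] = rate
    new_model = tf.keras.Sequential.from_config(conf)
    # Call once on a dummy batch so the weights exist before copying them
    # (harmless if from_config has already built the model).
    new_model(np.zeros((1, n_features), dtype="float32"))
    new_model.set_weights(model.get_weights())
    return new_model


base = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(16, activation="elu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1),
])
mc_model = clone_with_dropout_rate(base, 0.5, n_features=3)
print([l.rate for l in mc_model.layers if isinstance(l, tf.keras.layers.Dropout)])
```

Dropout layers carry no weights, so changing their rate does not change the weight list. That is why `set_weights` can take the original model's weights unchanged.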
import numpy as np
import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K

# Not shown in this snippet: opt, root_mean_squared_error, train_batch,
# test_batch, early_stop, model_checkpoint, tensorboard_callback,
# forecast_ahead_df and forecast_batch.

# Model
def mlp_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(3, input_dim = 3,
                              kernel_initializer = 'glorot_uniform',
                              activation = 'elu'),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.001)),
        tf.keras.layers.GaussianNoise(0.3),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.003)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.003)),
        tf.keras.layers.Dense(160,
                              activation = 'relu',
                              kernel_regularizer = regularizers.l2(0.004)), #128 for unscaled
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer = opt,
                  loss = root_mean_squared_error,
                  metrics=['mean_squared_error',
                           'mean_absolute_error',
                           root_mean_squared_error])
    return model

# fit model
model = mlp_model()
model_history = model.fit(train_batch,
                          validation_data = test_batch,
                          epochs = 100,
                          shuffle = False,
                          callbacks = [early_stop,
                                       model_checkpoint,
                                       tensorboard_callback])

# Create dropout func
def create_dropout_predict_function(model, dropout):
    # Load the config of the original model
    conf = model.get_config()
    # Set the specified rate on every Dropout layer
    for layer in conf['layers']:
        # Dropout layers
        if layer["class_name"]=="Dropout":
            layer["config"]["rate"] = dropout
    # Rebuild the model from the modified config (Sequential API)
    model_dropout = tf.keras.Sequential.from_config(conf)
    model_dropout.set_weights(model.get_weights())
    # Predict with dropout
    predict_with_dropout = K.function(model_dropout.inputs+[K.learning_phase()], model_dropout.outputs)
    return predict_with_dropout

# Create preds with dropout
dropout = 0.5
num_iter = 20
num_samples = len(forecast_ahead_df)
predict_with_dropout = create_dropout_predict_function(model, dropout)
predictions = np.zeros((num_samples, num_iter))
for i in range(num_iter):
    predictions[:, i] = predict_with_dropout(forecast_batch)
Any help is much appreciated, and many thanks to everyone who has helped so far.
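In TF 2.x the model runs eagerly, and the `K.function(... + [K.learning_phase()], ...)` pattern from the TF1-era tutorial may not behave as the article describes. A commonly used alternative is to call the model directly with `training=True`, which keeps `Dropout` (and, in this model, `GaussianNoise`) active during prediction. It is sketched below as an assumption, not a confirmed diagnosis of the problem in the question. `mc_dropout_predict`, the toy model and `x_new` are hypothetical stand-ins for your own `model` and `forecast_batch`.

```python
import numpy as np
import tensorflow as tf


def mc_dropout_predict(model, x, num_iter=20):
    """Run `num_iter` stochastic forward passes with training-only layers
    (Dropout, GaussianNoise, ...) switched on.

    Returns an array of shape (num_samples, num_iter), one column per pass.
    """
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    passes = [model(x, training=True).numpy().reshape(-1) for _ in range(num_iter)]
    return np.stack(passes, axis=1)


model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(32, activation="elu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
x_new = np.random.rand(10, 3).astype("float32")  # stand-in for forecast data
preds = mc_dropout_predict(model, x_new, num_iter=20)
pred_mean = preds.mean(axis=1)  # point estimate per input
pred_std = preds.std(axis=1)    # spread across the dropout passes
print(preds.shape, pred_mean.shape, pred_std.shape)
```

With this approach the dropout rate is whatever the model was built with (0.2 in the question's model). To sample at a different rate, the config rebuild in `create_dropout_predict_function` can still be used to produce the model that gets passed in. Note that `training=True` also switches any `BatchNormalization` layers to batch statistics, although the question's model has none.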
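I'm using TensorFlow 2.0.0 without a GPU, but hopefully this still solves the problem. It looks like you only provide the network's input data (X) but no response data (Y), so the model has nothing to learn. The only change I made here was adding the response data Y to model.fit().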
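The point here is that `fit` needs targets as well as inputs. Below is a quick sketch of the two usual ways to supply them in TF 2.x: separate `X`, `Y` arrays, or a `tf.data.Dataset` whose elements are `(features, targets)` pairs. The toy data and `make_model` are hypothetical. `train_batch` only mirrors the question's variable name, and the snippet does not show whether the question's own `train_batch` already yields pairs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy regression data: 64 rows, 3 features, 1 target.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3)).astype("float32")
Y = (X.sum(axis=1, keepdims=True) + 0.1 * rng.normal(size=(64, 1))).astype("float32")


def make_model():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(3,)),
        tf.keras.layers.Dense(16, activation="elu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Option 1: inputs and targets as separate arrays.
hist_arrays = make_model().fit(X, Y, epochs=2, batch_size=16, verbose=0)

# Option 2: a tf.data pipeline whose elements are (features, targets) pairs;
# no separate y argument is passed in this case.
train_batch = tf.data.Dataset.from_tensor_slices((X, Y)).batch(16)
hist_dataset = make_model().fit(train_batch, epochs=2, verbose=0)

print(len(hist_arrays.history["loss"]), len(hist_dataset.history["loss"]))
```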
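I took your code and trained it on this dataset (the Pima Indians diabetes dataset). The following code works for me in TensorFlow 2.0.0; see the comments for the changes: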
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import regularizers
import numpy as np

# load pima indians dataset
dataset = np.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]

# Changed the input dimension to input_dim = 8
def mlp_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(3, input_dim = 8,
                              kernel_initializer = 'glorot_uniform',
                              activation = 'elu'),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.001)),
        tf.keras.layers.GaussianNoise(0.3),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.003)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(160,
                              activation = 'elu',
                              kernel_regularizer = regularizers.l2(0.003)),
        tf.keras.layers.Dense(160,
                              activation = 'relu',
                              kernel_regularizer = regularizers.l2(0.004)), #128 for unscaled
        tf.keras.layers.Dense(1)
    ])
    # Removed the functions that aren't defined here and used the optimizer,
    # loss and metrics built into TensorFlow
    model.compile(optimizer = 'adam',
                  loss = 'binary_crossentropy',
                  metrics=['mean_squared_error',
                           'mean_absolute_error'])
    return model

# fit model
model = mlp_model()
# Added the response variable Y
model_history = model.fit(X, Y,
                          epochs = 100)
Output:
2.0.0
Train on 768 samples
Epoch 1/100
768/768 [==============================] - 1s 1ms/sample - loss: 8.1430 - mean_squared_error: 105.6280 - mean_absolute_error: 8.0280
Epoch 2/100
768/768 [==============================] - 0s 90us/sample - loss: 11.3845 - mean_squared_error: 819.4331 - mean_absolute_error: 25.4116
Epoch 3/100
768/768 [==============================] - 0s 96us/sample - loss: 11.2765 - mean_squared_error: 1185.4595 - mean_absolute_error: 30.9923
Epoch 4/100
768/768 [==============================] - 0s 90us/sample - loss: 11.1805 - mean_squared_error: 1230.1932 - mean_absolute_error: 31.5183
Epoch 5/100
768/768 [==============================] - 0s 88us/sample - loss: 11.0962 - mean_squared_error: 1161.8567 - mean_absolute_error: 30.6245
Epoch 6/100
768/768 [==============================] - 0s 88us/sample - loss: 11.0217 - mean_squared_error: 1129.4146 - mean_absolute_error: 30.2203
Epoch 7/100
768/768 [==============================] - 0s 91us/sample - loss: 10.9550 - mean_squared_error: 1097.7333 - mean_absolute_error: 29.7042
Epoch 8/100
768/768 [==============================] - 0s 91us/sample - loss: 10.8947 - mean_squared_error: 1066.2788 - mean_absolute_error: 29.1740
Epoch 9/100
768/768 [==============================] - 0s 89us/sample - loss: 10.8396 - mean_squared_error: 1035.1652 - mean_absolute_error: 28.6741
Epoch 10/100
768/768 [==============================] - 0s 86us/sample - loss: 10.7892 - mean_squared_error: 987.1348 - mean_absolute_error: 28.0417
Epoch 11/100
768/768 [==============================] - 0s 90us/sample - loss: 10.7427 - mean_squared_error: 947.6363 - mean_absolute_error: 27.4314
Epoch 12/100
768/768 [==============================] - 0s 88us/sample - loss: 10.6997 - mean_squared_error: 891.5195 - mean_absolute_error: 26.5769
Epoch 13/100
768/768 [==============================] - 0s 88us/sample - loss: 10.6599 - mean_squared_error: 868.7969 - mean_absolute_error: 26.1633
Epoch 14/100
768/768 [==============================] - 0s 94us/sample - loss: 10.6230 - mean_squared_error: 818.0642 - mean_absolute_error: 25.2891
Epoch 15/100
768/768 [==============================] - 0s 92us/sample - loss: 10.5886 - mean_squared_error: 772.1309 - mean_absolute_error: 24.7540
Epoch 16/100
768/768 [==============================] - 0s 87us/sample - loss: 10.5567 - mean_squared_error: 745.1671 - mean_absolute_error: 24.1730
Epoch 17/100
768/768 [==============================] - 0s 85us/sample - loss: 10.5269 - mean_squared_error: 691.4273 - mean_absolute_error: 23.3115
Epoch 18/100
768/768 [==============================] - 0s 109us/sample - loss: 10.4991 - mean_squared_error: 651.1308 - mean_absolute_error: 22.6068
Epoch 19/100
768/768 [==============================] - 0s 87us/sample - loss: 10.4732 - mean_squared_error: 634.8137 - mean_absolute_error: 22.2203
Epoch 20/100
768/768 [==============================] - 0s 85us/sample - loss: 10.4489 - mean_squared_error: 600.7214 - mean_absolute_error: 21.5552
Epoch 21/100
768/768 [==============================] - 0s 87us/sample - loss: 10.4262 - mean_squared_error: 557.6275 - mean_absolute_error: 20.9041
Epoch 22/100
768/768 [==============================] - 0s 92us/sample - loss: 10.4050 - mean_squared_error: 533.1421 - mean_absolute_error: 20.3297
Epoch 23/100
768/768 [==============================] - 0s 89us/sample - loss: 10.3850 - mean_squared_error: 528.4443 - mean_absolute_error: 19.9225
Epoch 24/100
768/768 [==============================] - 0s 84us/sample - loss: 10.3663 - mean_squared_error: 488.9399 - mean_absolute_error: 19.3189
Epoch 25/100
768/768 [==============================] - 0s 89us/sample - loss: 10.3487 - mean_squared_error: 448.7306 - mean_absolute_error: 18.5765
Epoch 26/100
768/768 [==============================] - 0s 87us/sample - loss: 10.3322 - mean_squared_error: 425.1384 - mean_absolute_error: 18.0467
Epoch 27/100
768/768 [==============================] - 0s 89us/sample - loss: 10.3166 - mean_squared_error: 390.8499 - mean_absolute_error: 17.3933
Epoch 28/100
768/768 [==============================] - 0s 91us/sample - loss: 10.3019 - mean_squared_error: 373.9980 - mean_absolute_error: 16.8805
Epoch 29/100
768/768 [==============================] - 0s 88us/sample - loss: 10.2881 - mean_squared_error: 358.0975 - mean_absolute_error: 16.4838
Epoch 30/100
768/768 [==============================] - 0s 89us/sample - loss: 10.2750 - mean_squared_error: 330.0729 - mean_absolute_error: 15.8146
Epoch 31/100
768/768 [==============================] - 0s 94us/sample - loss: 10.2627 - mean_squared_error: 304.5206 - mean_absolute_error: 15.1873
Epoch 32/100
768/768 [==============================] - 0s 96us/sample - loss: 10.2510 - mean_squared_error: 284.8553 - mean_absolute_error: 14.7007
Epoch 33/100
768/768 [==============================] - 0s 85us/sample - loss: 10.2400 - mean_squared_error: 269.9680 - mean_absolute_error: 14.2868
Epoch 34/100
768/768 [==============================] - 0s 88us/sample - loss: 10.2295 - mean_squared_error: 258.1902 - mean_absolute_error: 13.8228
Epoch 35/100
768/768 [==============================] - 0s 88us/sample - loss: 10.2196 - mean_squared_error: 233.8299 - mean_absolute_error: 13.2536
Epoch 36/100
768/768 [==============================] - 0s 90us/sample - loss: 10.2102 - mean_squared_error: 226.0575 - mean_absolute_error: 12.8807
Epoch 37/100
768/768 [==============================] - 0s 90us/sample - loss: 10.2013 - mean_squared_error: 208.6756 - mean_absolute_error: 12.4301
Epoch 38/100
768/768 [==============================] - 0s 85us/sample - loss: 10.1928 - mean_squared_error: 190.1977 - mean_absolute_error: 11.8795
Epoch 39/100
768/768 [==============================] - 0s 92us/sample - loss: 10.1847 - mean_squared_error: 179.3409 - mean_absolute_error: 11.5069
Epoch 40/100
768/768 [==============================] - 0s 89us/sample - loss: 10.1770 - mean_squared_error: 164.0748 - mean_absolute_error: 11.0245
Epoch 41/100
768/768 [==============================] - 0s 89us/sample - loss: 10.1697 - mean_squared_error: 154.7774 - mean_absolute_error: 10.5896
Epoch 42/100
768/768 [==============================] - 0s 88us/sample - loss: 10.1627 - mean_squared_error: 138.8940 - mean_absolute_error: 10.1659
Epoch 43/100
768/768 [==============================] - 0s 88us/sample - loss: 10.1560 - mean_squared_error: 130.8000 - mean_absolute_error: 9.8535
Epoch 44/100
768/768 [==============================] - 0s 89us/sample - loss: 10.1496 - mean_squared_error: 120.8661 - mean_absolute_error: 9.4153
Epoch 45/100
768/768 [==============================] - 0s 84us/sample - loss: 10.1435 - mean_squared_error: 109.2852 - mean_absolute_error: 9.0433
Epoch 46/100
768/768 [==============================] - 0s 88us/sample - loss: 10.1377 - mean_squared_error: 103.8257 - mean_absolute_error: 8.6843
Epoch 47/100
768/768 [==============================] - 0s 93us/sample - loss: 10.1322 - mean_squared_error: 96.0696 - mean_absolute_error: 8.3516
Epoch 48/100
768/768 [==============================] - 0s 95us/sample - loss: 10.1268 - mean_squared_error: 87.2735 - mean_absolute_error: 7.9551
Epoch 49/100
768/768 [==============================] - 0s 84us/sample - loss: 10.1217 - mean_squared_error: 78.5033 - mean_absolute_error: 7.5730
Epoch 50/100
768/768 [==============================] - 0s 89us/sample - loss: 10.1168 - mean_squared_error: 73.8310 - mean_absolute_error: 7.2905
Epoch 51/100
768/768 [==============================] - 0s 96us/sample - loss: 10.1121 - mean_squared_error: 65.5390 - mean_absolute_error: 6.9333
Epoch 52/100
768/768 [==============================] - 0s 94us/sample - loss: 10.1076 - mean_squared_error: 62.0906 - mean_absolute_error: 6.6996
Epoch 53/100
768/768 [==============================] - 0s 93us/sample - loss: 10.1033 - mean_squared_error: 54.6800 - mean_absolute_error: 6.3188
Epoch 54/100
768/768 [==============================] - 0s 90us/sample - loss: 10.0991 - mean_squared_error: 49.5616 - mean_absolute_error: 6.0299
Epoch 55/100
768/768 [==============================] - 0s 91us/sample - loss: 10.0951 - mean_squared_error: 46.1167 - mean_absolute_error: 5.7707
Epoch 56/100
768/768 [==============================] - 0s 92us/sample - loss: 10.0913 - mean_squared_error: 41.3520 - mean_absolute_error: 5.4638
Epoch 57/100
768/768 [==============================] - 0s 98us/sample - loss: 10.0876 - mean_squared_error: 37.7623 - mean_absolute_error: 5.2175
Epoch 58/100
768/768 [==============================] - 0s 92us/sample - loss: 10.0840 - mean_squared_error: 35.0457 - mean_absolute_error: 4.9933
Epoch 59/100
768/768 [==============================] - 0s 85us/sample - loss: 10.0806 - mean_squared_error: 31.3319 - mean_absolute_error: 4.6894
Epoch 60/100
768/768 [==============================] - 0s 90us/sample - loss: 10.0773 - mean_squared_error: 28.8361 - mean_absolute_error: 4.5389
Epoch 61/100
768/768 [==============================] - 0s 92us/sample - loss: 10.0741 - mean_squared_error: 25.4872 - mean_absolute_error: 4.2417
Epoch 62/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0710 - mean_squared_error: 23.8177 - mean_absolute_error: 4.1007
Epoch 63/100
768/768 [==============================] - 0s 84us/sample - loss: 7.0914 - mean_squared_error: 101.6586 - mean_absolute_error: 7.4064
Epoch 64/100
768/768 [==============================] - 0s 87us/sample - loss: 5.5155 - mean_squared_error: 457.9014 - mean_absolute_error: 18.0251
Epoch 65/100
768/768 [==============================] - 0s 88us/sample - loss: 5.5088 - mean_squared_error: 392.7791 - mean_absolute_error: 16.7004
Epoch 66/100
768/768 [==============================] - 0s 98us/sample - loss: 5.5000 - mean_squared_error: 272.5681 - mean_absolute_error: 13.7357
Epoch 67/100
768/768 [==============================] - 0s 86us/sample - loss: 5.4928 - mean_squared_error: 188.3301 - mean_absolute_error: 11.4254
Epoch 68/100
768/768 [==============================] - 0s 84us/sample - loss: 5.4869 - mean_squared_error: 144.7060 - mean_absolute_error: 10.0285
Epoch 69/100
768/768 [==============================] - 0s 83us/sample - loss: 5.4821 - mean_squared_error: 160.2795 - mean_absolute_error: 10.7119
Epoch 70/100
768/768 [==============================] - 0s 94us/sample - loss: 5.4772 - mean_squared_error: 114.1624 - mean_absolute_error: 9.0611
Epoch 71/100
768/768 [==============================] - 0s 91us/sample - loss: 5.4729 - mean_squared_error: 81.6616 - mean_absolute_error: 7.5558
Epoch 72/100
768/768 [==============================] - 0s 88us/sample - loss: 5.4690 - mean_squared_error: 58.3508 - mean_absolute_error: 6.3955
Epoch 73/100
768/768 [==============================] - 0s 87us/sample - loss: 5.4654 - mean_squared_error: 44.2179 - mean_absolute_error: 5.5208
Epoch 74/100
768/768 [==============================] - 0s 91us/sample - loss: 5.4621 - mean_squared_error: 33.7123 - mean_absolute_error: 4.7988
Epoch 75/100
768/768 [==============================] - 0s 97us/sample - loss: 5.4590 - mean_squared_error: 26.4150 - mean_absolute_error: 4.2291
Epoch 76/100
768/768 [==============================] - 0s 86us/sample - loss: 5.4561 - mean_squared_error: 21.0693 - mean_absolute_error: 3.7517
Epoch 77/100
768/768 [==============================] - 0s 82us/sample - loss: 5.4533 - mean_squared_error: 16.6050 - mean_absolute_error: 3.3334
Epoch 78/100
768/768 [==============================] - 0s 88us/sample - loss: 5.4507 - mean_squared_error: 13.3761 - mean_absolute_error: 2.9838
Epoch 79/100
768/768 [==============================] - 0s 94us/sample - loss: 5.4482 - mean_squared_error: 10.6811 - mean_absolute_error: 2.6628
Epoch 80/100
768/768 [==============================] - 0s 90us/sample - loss: 5.4459 - mean_squared_error: 9.5541 - mean_absolute_error: 2.4669
Epoch 81/100
768/768 [==============================] - 0s 88us/sample - loss: 5.4438 - mean_squared_error: 11.0352 - mean_absolute_error: 2.7847
Epoch 82/100
768/768 [==============================] - 0s 83us/sample - loss: 5.4416 - mean_squared_error: 8.8724 - mean_absolute_error: 2.4634
Epoch 83/100
768/768 [==============================] - 0s 96us/sample - loss: 5.4395 - mean_squared_error: 7.0467 - mean_absolute_error: 2.1836
Epoch 84/100
768/768 [==============================] - 0s 86us/sample - loss: 8.0375 - mean_squared_error: 36.5165 - mean_absolute_error: 4.8413
Epoch 85/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0764 - mean_squared_error: 126.9524 - mean_absolute_error: 10.2454
Epoch 86/100
768/768 [==============================] - 0s 87us/sample - loss: 10.0705 - mean_squared_error: 121.3718 - mean_absolute_error: 9.9911
Epoch 87/100
768/768 [==============================] - 0s 88us/sample - loss: 10.0631 - mean_squared_error: 96.0288 - mean_absolute_error: 8.9118
Epoch 88/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0572 - mean_squared_error: 74.3836 - mean_absolute_error: 7.9054
Epoch 89/100
768/768 [==============================] - 0s 88us/sample - loss: 10.0524 - mean_squared_error: 62.7223 - mean_absolute_error: 7.1927
Epoch 90/100
768/768 [==============================] - 0s 90us/sample - loss: 10.0484 - mean_squared_error: 48.8116 - mean_absolute_error: 6.4094
Epoch 91/100
768/768 [==============================] - 0s 93us/sample - loss: 10.0448 - mean_squared_error: 39.1147 - mean_absolute_error: 5.7639
Epoch 92/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0417 - mean_squared_error: 32.7087 - mean_absolute_error: 5.2643
Epoch 93/100
768/768 [==============================] - 0s 85us/sample - loss: 10.0389 - mean_squared_error: 26.4908 - mean_absolute_error: 4.7354
Epoch 94/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0363 - mean_squared_error: 21.6876 - mean_absolute_error: 4.2937
Epoch 95/100
768/768 [==============================] - 0s 89us/sample - loss: 10.0340 - mean_squared_error: 17.6745 - mean_absolute_error: 3.8886
Epoch 96/100
768/768 [==============================] - 0s 86us/sample - loss: 10.0318 - mean_squared_error: 14.5234 - mean_absolute_error: 3.5385
Epoch 97/100
768/768 [==============================] - 0s 93us/sample - loss: 10.0298 - mean_squared_error: 12.3609 - mean_absolute_error: 3.2538
Epoch 98/100
768/768 [==============================] - 0s 91us/sample - loss: 10.0279 - mean_squared_error: 10.0023 - mean_absolute_error: 2.9435
Epoch 99/100
768/768 [==============================] - 0s 88us/sample - loss: 10.0261 - mean_squared_error: 8.6058 - mean_absolute_error: 2.7176
Epoch 100/100
768/768 [==============================] - 0s 88us/sample - loss: 10.0244 - mean_squared_error: 7.0987 - mean_absolute_error: 2.4628
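The code above drops the question's custom `root_mean_squared_error` because the snippet never defines it. If you would rather keep the regression loss from the question than switch to `binary_crossentropy`, a minimal sketch of such a function follows. Its body is an assumption, since the question does not show the asker's definition, and the toy data and model are hypothetical.

```python
import numpy as np
import tensorflow as tf


def root_mean_squared_error(y_true, y_pred):
    """RMSE usable as a Keras loss or metric (hypothetical stand-in for the
    asker's own function, which the question does not show)."""
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))


# Hypothetical toy regression data with a (n, 1) target to match Dense(1).
rng = np.random.default_rng(1)
X = rng.normal(size=(32, 3)).astype("float32")
Y = X.sum(axis=1, keepdims=True).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(8, activation="elu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam",
              loss=root_mean_squared_error,
              metrics=["mean_absolute_error", root_mean_squared_error])
history = model.fit(X, Y, epochs=1, batch_size=8, verbose=0)
print(history.history["loss"])
```

Because the function returns a scalar per batch, Keras just averages it over batches when reporting the epoch value. That is close to, but not exactly, the RMSE over the whole epoch.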
Hope this answers your question. Happy learning.