输出一类CNN神经网络的置信度/概率



我有一个问题。我想为我的预测获得信心/概率。我怎样才能获得信心呢?我看了一下如何在CNN中使用tensorflow实现置信度?. 但我不明白我是如何得到这个预测的。

class CNN_Text:
def __init__(self, x, y):
self.x =x
self.y = y
def forward(self):
filter_sizes = [1,2,3,5]
num_filters = 32

inp = Input(shape=(maxlen, ))
x = Embedding(embedding_matrix.shape[0], 300, weights=[embedding_matrix], trainable=False)(inp)
x = SpatialDropout1D(0.4)(x)
x = Reshape((maxlen, embed_size, 1))(x)
conv_0 = Conv2D(num_filters, kernel_size=(filter_sizes[0], embed_size), kernel_initializer='normal',
                      activation='elu')(x)
conv_1 = Conv2D(num_filters, kernel_size=(filter_sizes[1], embed_size), kernel_initializer='normal',
                      activation='elu')(x)
conv_2 = Conv2D(num_filters, kernel_size=(filter_sizes[2], embed_size), kernel_initializer='normal',
                      activation='elu')(x)
conv_3 = Conv2D(num_filters, kernel_size=(filter_sizes[3], embed_size), kernel_initializer='normal',
                      activation='elu')(x)
maxpool_0 = MaxPool2D(pool_size=(maxlen - filter_sizes[0] + 1, 1))(conv_0)
maxpool_1 = MaxPool2D(pool_size=(maxlen - filter_sizes[1] + 1, 1))(conv_1)
maxpool_2 = MaxPool2D(pool_size=(maxlen - filter_sizes[2] + 1, 1))(conv_2)
maxpool_3 = MaxPool2D(pool_size=(maxlen - filter_sizes[3] + 1, 1))(conv_3)
z = Concatenate(axis=1)([maxpool_0
, maxpool_1
, maxpool_2
, maxpool_3
]) 
# z = Dropout(0.3)(z)  
z = Flatten()(z)
z = Dropout(0.3)(z)
outp = Dense(53, activation="softmax")(z)
model = Model(inputs=inp, outputs=outp)
model.summary()
return model
p1 = CNN_Text(...)
model = p1.forward()
model.compile(...)
history = model.fit(...)
pred = model.predict(...)

如何预测类

x = x.lower()
x = remove_URL(x)
x = remove_punct(x)
x = remove_stopwords(x)

x = tokenizer.texts_to_sequences([x])
x = pad_sequences(x, maxlen=maxlen)
pred = model.predict(x)
pred = pred.argmax(axis=1)
pred = le.classes_[pred]
return pred[0]

Softmax激活函数将网络的输出归一化,为您提供给定样本中53个类别中每个类别的预测概率。

pred = pred.argmax(axis=1)

这一行给出了预测概率最高的节点的索引。

pred = pred.max(axis=1)

这将给你相应的概率(如果你之前没有用argmax覆盖pred)。

相关内容

  • 没有找到相关文章

最新更新