Masking NA values in Keras multi-label regression



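I am trying to build a multi-label regression model with Keras. The labels contain many NA values, because not every instance was tested against every label. Here is a sample of my code: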

import numpy as np
import pandas as pd
from sklearn.datasets import make_multilabel_classification
X, _ = make_multilabel_classification(n_samples = 1000, sparse = True, n_features = 40, 
                                      return_indicator = 'sparse', allow_unlabeled = False)
y = pd.DataFrame(np.random.randint(0, 100, (1000, 10)))
na_ = np.random.choice([True, False], size=y.shape)
na_[na_.all(1),-1] = 0
y = y.mask(na_)
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
X_train,  X_test, y_train, y_test = train_test_split(X, y, random_state=42)
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.callbacks import ReduceLROnPlateau
from keras import regularizers
from keras.optimizers import RMSprop, Adam, SGD
sgd = SGD(lr=0.01, momentum=0.9, nesterov=True)
model = Sequential()
model.add(Dense(100, input_dim=40))
model.add(Activation('relu'))
model.add(Dense(10))
model.compile(loss='mean_squared_error', optimizer=sgd, metrics=['mae'])
hist = model.fit(X_train, y_train, epochs=500, verbose=1, validation_split=0.2)
scores = model.evaluate(X_test, y_test)

The multi-label target (y) to predict looks like this:

0   NaN     NaN     4.0     NaN     NaN     NaN     NaN     35.0    NaN     98.0
1   NaN     NaN     70.0    17.0    NaN     NaN     4.0     69.0    33.0    NaN
2   14.0    NaN     NaN     65.0    NaN     NaN     NaN     50.0    64.0    55.0
3   78.0    NaN     2.0     NaN     44.0    79.0    67.0    43.0    3.0     64.0
4   NaN     54.0    NaN     NaN     NaN     67.0    18.0    39.0    3.0     41.0

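I need to change the loss function by introducing a mask variable that masks out all the NaN labels, but I have not been able to implement this. Please help!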

You can mask the labels with the tf.where() function, as follows:

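import tensorflow as tf
from keras import backend as K

def mse_nan(y_true, y_pred):
    # Labels that were never measured appear as NaN in y_true.
    is_missing = tf.math.is_nan(y_true)  # spelled tf.is_nan in older TensorFlow 1.x releases
    # Zero out target and prediction at missing positions so they contribute no error.
    masked_true = tf.where(is_missing, tf.zeros_like(y_true), y_true)
    masked_pred = tf.where(is_missing, tf.zeros_like(y_true), y_pred)
    return K.mean(K.square(masked_pred - masked_true), axis=-1)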
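
To make the arithmetic of this loss concrete, here is a small NumPy sketch that mirrors it on toy data. The names mse_nan_np and mse_observed_np, and the sample arrays, are made up for illustration and are not part of the answer above. The second function is an assumed variant that divides by the number of observed labels in each row instead of by the total number of labels, which may or may not be what you want.

```python
import numpy as np

def mse_nan_np(y_true, y_pred):
    # NumPy mirror of mse_nan: missing labels contribute zero error,
    # but the mean is still taken over ALL label columns.
    missing = np.isnan(y_true)
    masked_true = np.where(missing, 0.0, y_true)
    masked_pred = np.where(missing, 0.0, y_pred)
    return np.mean(np.square(masked_pred - masked_true), axis=-1)

def mse_observed_np(y_true, y_pred):
    # Hypothetical variant: average the squared error over observed labels only.
    observed = ~np.isnan(y_true)
    diff = np.where(observed, y_pred - y_true, 0.0)
    n_observed = np.maximum(observed.sum(axis=-1), 1)  # guard against rows with no labels
    return np.square(diff).sum(axis=-1) / n_observed

# Toy data: row 0 has two observed labels, row 1 has all four.
y_true = np.array([[np.nan, 4.0, np.nan, 35.0],
                   [14.0, 2.0, 65.0, 50.0]])
y_pred = np.array([[10.0, 6.0, 20.0, 31.0],
                   [12.0, 2.0, 65.0, 54.0]])

loss_all = mse_nan_np(y_true, y_pred)            # [5.0, 5.0]
loss_observed = mse_observed_np(y_true, y_pred)  # [10.0, 5.0]
print(loss_all)
print(loss_observed)
```

Both rows have the same total squared error (20), so the first version gives them the same loss even though row 0 has only two measured labels. The variant weights row 0 more heavily. To use the TensorFlow version in the question's model, pass the function itself, as in model.compile(loss=mse_nan, optimizer=sgd). Note that a built-in metric such as 'mae' does not apply the mask, so it will most likely still report NaN.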
