如何获得多标签分类问题的样本权重和类权重



我正试图为多标签分类问题构建一个神经网络。

情况

在一个输入图像中,可能有多个输出类(并且它们不是互斥的(。共有6个班。

示例

图像1中有类1、类2和类5。因此,输出如下[1,1,0,0,1,0]。

数据不平衡问题

基于该图像类型中出现的类的组合,我总共有32种独特类型的图像。因此,一个类型可以包含所有类(由[1,1,1,1,2,1]表示(,而另一个类型可能不包含任何类(由[0,0,0,0,0,0]表示(。

与其他图像(如不存在类的图像(相比,一些图像非常罕见(如同时包含类1、类3、类4和类6的图像(。从下面给出的数据中应该可以清楚地看到这一点。

Image Type         : No. of samples of that image type
[1, 0, 1, 1, 0, 1] : 1
[1, 0, 1, 0, 1, 1] : 2
[1, 1, 1, 0, 1, 1] : 2
[1, 1, 1, 1, 1, 1] : 2
[1, 0, 1, 1, 1, 1] : 3
[1, 1, 1, 1, 0, 1] : 3
[1, 0, 1, 0, 0, 1] : 3
[1, 1, 1, 0, 0, 1] : 4
[1, 1, 0, 1, 1, 1] : 4
[1, 1, 0, 1, 0, 1] : 7
[1, 1, 0, 0, 1, 1] : 7
[1, 0, 0, 1, 1, 1] : 8
[1, 0, 0, 1, 0, 1] : 16
[1, 1, 0, 0, 0, 1] : 21
[1, 0, 0, 0, 1, 1] : 28
[0, 1, 1, 0, 1, 1] : 53
[0, 1, 1, 1, 1, 1] : 63
[0, 0, 1, 1, 1, 1] : 70
[0, 0, 1, 0, 1, 1] : 78
[1, 0, 0, 0, 0, 1] : 122
[0, 1, 1, 1, 0, 1] : 141
[0, 1, 0, 1, 1, 1] : 159
[0, 1, 0, 0, 1, 1] : 239
[0, 0, 1, 1, 0, 1] : 265
[0, 1, 0, 1, 0, 1] : 283
[0, 0, 0, 1, 1, 1] : 366
[0, 1, 1, 0, 0, 1] : 491
[0, 0, 1, 0, 0, 1] : 712
[0, 1, 0, 0, 0, 1] : 1128
[0, 0, 0, 1, 0, 1] : 1183
[0, 0, 0, 0, 1, 1] : 2319
[0, 0, 0, 0, 0, 0] : 46431
Total no. of samples = 54,214 sample images

另一个问题是阶级的不平衡表现。由于总共有54214个图像样本,每个样本有6个类别。我们把这两个值相乘得到一个总数。54214*6=352284

下面给出的数据清楚地表明,类1(存在(是代表性最小的类。此外,我们可以看到,与阳性(1(相比,阴性(0(更多。

Absent(0) Present(1) Total(0 + 1)
Class 1  53981     233        54214
Class 2  52321     1893       54214
Class 3  51640     2574       54214
Class 4  51607     2607       54214
Class 5  50811     3403       54214
Class 6  46431     7783       54214
Total :  306791 +  18493   =  325284 

我正在使用Keras,我知道我们可以在训练模型时通过sample_weightclass_weight。由于这是一个多标签分类问题,我在最后一层使用sigmoid激活和二进制交叉熵损失。

问题

  1. 我应该如何计算sample_weight,以便更有力地表示稀有样本(如[1,0,1,1,0,1]类型的样本(?

  2. 在这种情况下,我应该如何计算class_weight,以便解决负数(0(多于正数(1(的问题?

  3. [可选/不太重要]如果我想比其他五个班更严厉地惩罚第六班(因为第六班是最重要的(,我该怎么办?

我知道可以使用scikit learn的compute_sample_weight和compute_class_weight之类的东西来计算它。

如果有人能提供一个解决方案并用数学方法解释它,那将非常有帮助。此外,如果我理解错误,请纠正我。

我相信有很多方法可以解决这个问题,但我的想法如下:

  • 有一个单独的模型来预测图像是否是任何类的一部分。这应该很简单
  • 如果图像是步骤1中任何类别的一部分,则预测图像是哪个特定类别的元素

划分问题可能是有益的,因为您可以在第一个模型上训练整个数据集,然后在第二步中进行选择性采样,以解决数据失衡问题。在第一步中,您可以最大限度地避免丢失任何信息,在第二步中,通过简化问题和解决数据失衡来帮助网络。

在第二步中,您可以选择:

  1. 六个独立的二进制分类模型,代表具有选择性采样的每个类别
  2. 一个具有选择性抽样的多标签分类模型

在第一个建议中,您将为每个模型选择样本,以便每个模型中标签0和1之间的比例为50/50。例如,对于类1,您将有233个该类的图像元素和233个非该类元素的其他任意选择的图像。通过这种方式,您没有数据失衡。如果您的数据失衡实际上是由于采样偏差造成的,则此选项是有意义的。

在第二个建议中,您将只使用任何类的元素数据进行训练。通过这种方式,您确实存在一些数据不平衡,但仍比最初少很多。如果需要,可以通过对特定类使用数据增强来更频繁地使用该类的图像进行训练,从而应用更复杂的选择性采样。在这种情况下,数据失衡将进一步减少。

然而,在现实世界中,一些数据不平衡实际上是有代表性的。这就是我个人赞同第二个建议的原因。

最新更新