Tensorflow - 有没有办法按标签分隔tf.data.Dataset?



我知道我可以在将数据加载到我的网络之前通过它们的标签来分隔我的数据。假设有 3 个类,标签为 0,1,2。我可以通过以下方式做到这一点:

dataset1 = tf.data.TextLineDataset(train_csv_file1).map(_parse_csv_train)
dataset2 = tf.data.TextLineDataset(train_csv_file2).map(_parse_csv_train)
dataset3 = tf.data.TextLineDataset(train_csv_file3).map(_parse_csv_train)

我只是对以下内容感到好奇:

假设我们有数据集:

dataset = tf.data.TextLineDataset(train_csv_file).map(_parse_csv_train)

其中包含来自 3 个类的所有数据,

有没有办法调用一些像dataset.selectDataByLabel(label=="2")这样的函数[这是一个虚构的函数],以便我可以根据它们的标签将数据集分成 3 个部分?

所以最后我选择按csvs分隔文件,即生成csvs,每个csv只包含一个类的数据。当类太多时,这可能不是一个完美的解决方案,但就我而言,只有 5 个类,所以没关系。

最新更新