I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels have many more entries than others, and some have very few. I want to filter a limited number of entries for some of these labels in order to get a balanced dataset. Below you can see my attempt, but filtering 1k entries per label (33 distinct labels) takes more than 24 hours.
import tensorflow as tf

tf.compat.as_str(
    bytes_or_text='str', encoding='utf-8'
)

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
features, feature_lists = detect_schema(dataset)
# Decoding TFRecord serialized data
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']

dataset = dataset.map(lambda x: tf.py_function(func=decode_data, inp=[x], Tout=(tf.string, tf.string)))
# Filtering and concatenating the samples
def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Filtering the chosen label
        locals()[label] = dataset.filter(
            lambda x, y: tf.greater(
                tf.reduce_sum(tf.cast(tf.equal(tf.constant(label, dtype=tf.int64), y), tf.float32)),
                tf.constant(0.)))
        # Appending a limited sample
        datasets_list.append(locals()[label].take(sample_size))
    concat_dataset = datasets_list[0]
    # Concatenating the datasets
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset

balanced_data = balanced_dataset(tabledataset, labels_list=list(decod_dic.values()), sample_size=1000)
One way to solve this, instead of running a separate filter pass over the whole dataset for every label, is the group_by_window method, where the window_size is the per-class sample size (1k in your case):
ds = ds.group_by_window(
    # Use the label as the grouping key
    key_func=lambda _, l: l,
    # Reduce each window to a batch of sample_size
    reduce_func=lambda _, window: window.batch(sample_size),
    # Use sample_size as the window size
    window_size=sample_size)
This forms single-class batches of size sample_size. But there is a problem: each class will produce multiple batches, while we only want one batch per class.
To fix this, we attach a count to every batch and then filter on count==0, which keeps the first batch of every class.
Let's define an example:
import numpy as np

labels = np.array(sum([[label]*repeat for label, repeat in zip([0, 1, 2], [100, 200, 15])], []))
features = np.arange(len(labels))
np.unique(labels, return_counts=True)
# (array([0, 1, 2]), array([100, 200, 15]))
# There are 3 labels chosen for simplicity, with their counts shown above.
sample_size = 15  # we choose to pick a sample of 15 from each class
We create a dataset from the above inputs:
ds = tf.data.Dataset.from_tensor_slices((features, labels))
In the window function above, we modify reduce_func to maintain a counter, so each batch will carry 3 elements (X_batch, y_batch, label_counter):
def reduce_func(key, window):
    # Look up how many batches of this class were already emitted
    count = table.lookup(key)
    table.insert(key, count + 1)
    # Attach the running count as the third element of the batch
    return window.batch(sample_size).map(lambda x, y: (x, y, count))
# Group by window
ds = ds.group_by_window(
    # Use the label as the grouping key
    key_func=lambda _, l: l,
    # Reduce each window to a counted batch of sample_size
    reduce_func=reduce_func,
    # Use sample_size as the window size
    window_size=sample_size)
The class_count logic inside reduce_func is implemented as a tf.lookup.experimental.MutableHashTable, since an ordinary Python counter cannot be updated from inside a traced tf.data function. Its initialization is shown below:
n_classes = 3

keys = tf.range(0, n_classes, dtype=tf.int64)
vals = tf.zeros_like(keys, dtype=tf.int64)

table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.int64,
                                                value_dtype=tf.int64,
                                                default_value=-1)
table.insert(keys, vals)
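As a quick sanity check (a minimal sketch, not part of the original pipeline): every inserted key starts at 0, and any key that was never inserted falls back to default_value=-1.

# Hypothetical sanity check: keys 0..2 were inserted with value 0,
# while key 99 was never inserted, so it returns default_value=-1.
print(table.lookup(tf.constant([0, 1, 2, 99], dtype=tf.int64)).numpy())
# -> [ 0  0  0 -1]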
Now we keep only the batches where count==0, i.e. the first batch of each class, and drop the count element to form (X, y) batch pairs:
ds = ds.filter(lambda x, y, count: count==0)
ds = ds.map(lambda x, y, count: (x, y))
Output:
for x, y in ds:
    print(x.numpy(), y.numpy())
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[100 101 102 103 104 105 106 107 108 109 110 111 112 113 114] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[300 301 302 303 304 305 306 307 308 309 310 311 312 313 314] [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
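To connect this back to the TFRecord case in the question, where the label (subject) is a string, the label can first be mapped to an int64 id, because both the group_by_window key and the hash-table keys must be int64. Below is a minimal self-contained sketch under that assumption; decoded_ds, labels_list and the title/subject values are hypothetical stand-ins for the question's variables, not code from the original answer:

import tensorflow as tf

labels_list = [b'politics', b'sports', b'tech']  # hypothetical label strings
sample_size = 2                                  # kept small for the demo

# Hypothetical stand-in for the decoded (title, subject) dataset.
decoded_ds = tf.data.Dataset.from_tensor_slices((
    tf.constant([b't1', b't2', b't3', b't4', b't5', b't6', b't7']),
    tf.constant([b'politics', b'sports', b'tech', b'politics',
                 b'sports', b'politics', b'tech'])))

# Map each label string to an int64 id, since window keys must be int64.
label_to_id = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(labels_list),
        values=tf.range(len(labels_list), dtype=tf.int64)),
    default_value=-1)

# One counter slot per class, exactly as in the toy example above.
count_table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1)
count_table.insert(tf.range(len(labels_list), dtype=tf.int64),
                   tf.zeros(len(labels_list), dtype=tf.int64))

def reduce_func(key, window):
    # Count how many batches of this class were already emitted
    count = count_table.lookup(key)
    count_table.insert(key, count + 1)
    return window.batch(sample_size).map(lambda x, y: (x, y, count))

balanced = (decoded_ds
            .group_by_window(
                key_func=lambda x, y: label_to_id.lookup(y),
                reduce_func=reduce_func,
                window_size=sample_size)
            .filter(lambda x, y, count: count == 0)
            .map(lambda x, y, count: (x, y)))

for x, y in balanced:
    print(x.numpy(), y.numpy())

Applied to the real 4M-entry file, this makes a single pass over the data instead of one filter pass per label, which is where the 24-hour runtime in the question came from.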