我如何实现函数isBatched
,以便它测试参数dataset
是否批处理?
import tensorflow as tf
print(tf.__version__)
def isBatched(dataset):
# I guess this is what @yudhiesh means
batch = next(iter(dataset))
return batch.shape.ndims > 0 and batch.shape[0] > 0
tensor1 = tf.range(100)
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
assert isBatched(dataset.batch(10)) == True, "T1 fails"
assert isBatched(dataset.batch(10).map(lambda x: x)) == True, "T2 fails"
assert isBatched(dataset.batch(10).filter(lambda x: True).xxx.yyy.zzz) == True, "T3 fails"
assert isBatched(dataset.repeat()) == False, "T4 fails"
tensor2 = tf.random.uniform([10, 10])
dataset = tf.data.Dataset.from_tensor_slices(tensor2)
assert isBatched(dataset) == False, "T5 fails"
不必考虑。batch().unbatch()的情况
我检查了是否有办法找到tf.data的批处理大小。数据集,它似乎要求最后一次调用是.batch()。在我的例子中,.batch可以出现在调用链中的任何地方。
如何从tensorflow数据集获得批处理大小?假设第一个维度是批。如果原始数据集是多维的,它将不起作用。
请把代码给我看看,因为我正在为明天的学生准备讲座。
如果你有一个数据集:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).
batch(4).prefetch(1)
你可以检查输入数据集,看看它是否是一个BatchDataset
:
x._input_dataset.__class__.__name__
'BatchDataset'
它是,所以它将有一个_batch_size
属性:
x._input_dataset._batch_size
<tf.Tensor: shape=(), dtype=int64, numpy=4>
也许最后第二个操作不是批处理,所以你可能需要使用_input_dataset
迭代地找到批处理数据集,像这样:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).
batch(4).prefetch(1).map(lambda x: x).cache()
x._input_dataset._input_dataset._input_dataset.__class__.__name__
'BatchDataset'
那么下面的解决方案在所有情况下都有效吗?
def labels_from_dataset(dataset):
if not isinstance(dataset, tf.data.Dataset):
raise TypeError('dataset is not a tf.data.Dataset')
input_dataset = dataset._input_dataset
while not hasattr(input_dataset, '_batch_size') and hasattr(input_dataset, '_input_dataset'):
input_dataset = input_dataset._input_dataset
if hasattr(input_dataset, '_batch_size'):
dataset = dataset.unbatch()
y_labels = []
for _, labels in dataset:
y_labels.append(labels.numpy())
return y_labels