我试图模仿这里提供的联合学习实现:使用tff的clientData,以便清楚地理解代码。我已经到了需要澄清的地步。
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
dataset.batch(5)
指的是什么?这些批次是从数据中提取用于训练的,而这3个批次是用于测试的吗.take(5)
是什么意思
在行中:
dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
您首先将dataset
中的样本分成5个批次。然后,将map_fn
函数应用于dataset
中的每个批次(一次5个样本(。最后,使用dataset.take(5)
,您将从dataset
返回5个批次,其中每个批次有5个样本。
在您链接的示例中,client_data
包含多个tf
数据集。