在tensorflow API的ParameterServerTraining教程代码中,在model.fit
节
def dataset_fn(input_context):
global_batch_size = 64
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
x = tf.random.uniform((10, 10))
y = tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
dataset = dataset.shard(
input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
return dataset
dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
而且据说
The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.
这是否意味着数据集必须在每个工作服务器的相同存储上(假设参数服务器和工作服务器是不同的机器)?
或者在一台机器上的参数服务器是否有任何方式可以将数据发送给工人进行培训,而不需要工人机器直接将数据集存储在ParameterServerStrategy中,我不理解?
为了社区的利益在这里回答。
从注释部分:
(如果有人有同样的疑问)经过进一步的研究,我发现,我们可以在存在参数server的1台服务器上启动协调器我们可以使用
tf.distribute.Server()
,减少呼叫或训练呼叫从协调器。点击此链接tensorflow.org/api_docs/python/tf/distribute/Server