使用client_states作为状态的自定义聚合器



我想创建一个自定义聚合器,其中状态是每个客户端的唯一客户端状态。为了初始化,我可以像往常一样定义客户端状态,然后使用federated_collect来放置@SERVER,因为这是initialize_fn()想要的。我也可以在next_fn()中创建new_state。问题是一旦我不知道如何"广播"这些状态返回到客户端。通常情况下,federated_broadcast获取A@SERVER,然后复制与客户端数量相等的副本。所以对于两个客户端,它将是{A}@CLIENTS,让我们说(A A)。我想让AB@SERVER变成(A B)

我目前在聚合过程之外定义客户端状态,然后在迭代过程的run_one_round中传递这些状态。使用federated_collect从aggregator的测量中收集这些状态,然后在外部解栈。从联邦计算的外部来看,它是

server_state, train_metrics , client_states, aggregation_state = iterative_process.next(
server_state, sampled_train_data, client_states, aggregation_state)
client_states = [x for x in client_states]

在TFF

output = aggregation_process.next(aggregation_state, client_outputs.weights_delta, client_states)
new_aggregation_state = output.state
round_model_delta = output.result
new_client_states = output.measurements

在聚合器

measurements = tff.federated_collect(new_client_states)
return tff.templates.MeasuredProcessOutput(
state=new_state, result= round_model_delta, measurements=measurements)

但是我试图在聚合器中完全定义和处理这些客户端状态,以便我可以将此聚合器插入tff.learning.build_federated_averaging_process,如

iterative_process = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
model_update_aggregation_factory=my_aggregation_factory)

这可能吗?如果有,那是怎么回事?

tff.federated_collect可能不是这种情况下需要的工具,它将在TFF的未来版本中被删除(参见commit #030a406)。

或者,tff.federated_computation可以同时接受@CLIENTS参数作为输入,并且返回@CLIENT放置的值作为输出. 而不是首先收集服务器上的所有值(暗示系统正在通信状态);也许最好把这些值留给客户端。

当在模拟环境中执行TFF时(例如在Colab笔记本中调用tff.Computation),T@CLIENT放置的值将作为T对象的list返回;每个客户端一个。这可以用作将来调用tff.Computation的参数。

的例子:

@tff.tf_computation(tf.int32)
def sqrt(value):
return tf.math.sqrt(tf.cast(value, tf.float32))
@tff.federated_computation(tff.types.at_clients(tf.int32))
def federated_sqrt(values):
return tff.federated_map(sqrt, values)
client_values = [1,2,3,4]
federated_sqrt(client_values)
>>> [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=1.4142135>,
<tf.Tensor: shape=(), dtype=float32, numpy=1.7320508>,
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>]

重要告诫:输入和输出的顺序不一定保证相同。在存储库中的tensorflow_federated/python/examples/stateful_clients/目录中可以找到一个如何索引和跟踪调用状态的示例。

相关内容

  • 没有找到相关文章

最新更新