在client_output
weights_delta = attr.ib()
client_weight = attr.ib()
model_output = attr.ib()
client_loss = attr.ib()
之后,我将client_output
以序列的形式通过a = tff.federated_collect(client_output)
和round_model_delta = tff.federated_map(selecting_fn,a)
在这里。我宣布'
@tff.tf_computation() # append
def selecting_fn(a):
#TODO
return round_model_delta
在这里。在服务器上平均的过程中,我想通过选择一些损失值较小的客户端来平均weights_delta
。所以我试着通过a.weights_delta
访问它,但它不起作用。
tff.federated_collect
返回一个位于tff.SERVER
的tff.SequenceType
,您可以以相同的方式操作,例如客户端数据集通常在tff.tf_computation
修饰的方法中处理。
注意,必须在tff.federated_computation
的作用域中使用tff.federated_collect
操作符。您可能想要做的[*]是使用tff.federated_map
操作符将其传递给tff.tf_computation
。进入tff.tf_computation
后,您可以将其视为tf.data.Dataset
对象,并且tf.data
模块中的所有内容都可用。
我猜。更详细地说明你想达到的目标将会有所帮助。