如何访问序列类型的值?



client_output

中有以下属性
weights_delta = attr.ib()
client_weight = attr.ib()
model_output = attr.ib()
client_loss = attr.ib() 

之后,我将client_output以序列的形式通过a = tff.federated_collect(client_output)round_model_delta = tff.federated_map(selecting_fn,a)在这里。我宣布'

@tff.tf_computation()  # append
def selecting_fn(a):
#TODO
return round_model_delta

在这里。在服务器上平均的过程中,我想通过选择一些损失值较小的客户端来平均weights_delta。所以我试着通过a.weights_delta访问它,但它不起作用。

tff.federated_collect返回一个位于tff.SERVERtff.SequenceType,您可以以相同的方式操作,例如客户端数据集通常在tff.tf_computation修饰的方法中处理。

注意,必须在tff.federated_computation的作用域中使用tff.federated_collect操作符。您可能想要做的[*]是使用tff.federated_map操作符将其传递给tff.tf_computation。进入tff.tf_computation后,您可以将其视为tf.data.Dataset对象,并且tf.data模块中的所有内容都可用。

我猜。更详细地说明你想达到的目标将会有所帮助。

相关内容

  • 没有找到相关文章

最新更新