TFF: run a pretraining function on each client instead of build_federated_averaging_process



I would like each client to train its model with the pretrain function I have written below:

def pretrain(model):
    # Attach a projection head on top of the ResNet backbone
    resnet_output = model.output
    layer1 = tf.keras.layers.GlobalAveragePooling2D()(resnet_output)
    layer2 = tf.keras.layers.Dense(units=zdim * 2, activation='relu')(layer1)
    model_output = tf.keras.layers.Dense(units=zdim)(layer2)
    model = tf.keras.Model(model.input, model_output)

    iterations_per_epoch = determine_iterations_per_epoch()
    total_iterations = iterations_per_epoch * num_epochs
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(checkpoint, pretrain_save_path, max_to_keep=10)

    current_epoch = tf.cast(tf.floor(optimizer.iterations / iterations_per_epoch), tf.int64)
    batch = client_data(0).batch(2)
    epoch_loss = []
    for (image1, image2) in batch:
        loss, gradients = train_step(model, image1, image2)
        epoch_loss.append(loss)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # if tf.reduce_all(tf.equal(epoch, current_epoch + 1)):
        print("Loss after epoch {}: {}".format(current_epoch, sum(epoch_loss) / len(epoch_loss)))
        # print("Learning rate: {}".format(learning_rate(optimizer.iterations)))
        epoch_loss = []
        current_epoch += 1
        if current_epoch % 50 == 0:
            save_path = manager.save()
            print("Saved model for epoch {}: {}".format(current_epoch, save_path))

    save_path = manager.save()
    model.save("model.h5")
    model.save_weights("saved_weights.h5")

But, as we know, TFF provides a canned function for this:

iterative_process = tff.learning.build_federated_averaging_process(...)

So how should I proceed? Thanks.

There are a few ways to proceed along similar lines.

First, it is important to note that TFF is functional. State can be managed with things like writing to / reading from files (TF allows this), but that is not part of the interface TFF exposes to users. Anything that involves writing to or reading from files (i.e., manipulating state without passing it through function parameters and results) should at best be regarded as an implementation detail, and it is discouraged by TFF.

However, with a slight restructuring of the code above, I think this application can fit nicely into TFF's programming model. We would want to define something like:

@tff.tf_computation
@tf.function
def pretrain_client_model(model, client_dataset):
    # perhaps do dataset processing you want...
    for batch in client_dataset:
        # do model training
        pass
    return model.weights()  # or some tensor structure representing the trained model weights

Once your implementation looks like this, you can wire it into a custom iterative process. The canned function you mention (build_federated_averaging_process) really just constructs an instance of tff.templates.IterativeProcess; you are always free to write your own instance of this class.
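To make that structure concrete, here is a minimal, TFF-free sketch in plain Python of the initialize/next pattern that tff.templates.IterativeProcess formalizes. All state flows through function arguments and return values, as TFF requires; the pretrain_client update rule and the example data below are made up purely for illustration:

```python
def initialize():
    # Server state: just the global model weights here.
    return [0.0, 0.0]

def pretrain_client(weights, client_dataset):
    # Hypothetical stand-in for per-client pretraining: nudge each
    # weight halfway toward the mean of the client's data columns.
    means = [sum(col) / len(col) for col in zip(*client_dataset)]
    return [w + 0.5 * (m - w) for w, m in zip(weights, means)]

def next_fn(server_state, federated_data):
    # Each client trains starting from the broadcast server weights...
    client_weights = [pretrain_client(server_state, ds) for ds in federated_data]
    # ...and the server averages the returned weights (federated averaging).
    num_clients = len(client_weights)
    return [sum(ws) / num_clients for ws in zip(*client_weights)]

# Drive the process for a few rounds, threading the state explicitly.
state = initialize()
data = [[(1.0, 2.0), (3.0, 4.0)], [(5.0, 6.0), (7.0, 8.0)]]
for _ in range(3):
    state = next_fn(state, data)
```

The real thing would wrap these two functions as TFF computations and hand them to the IterativeProcess constructor, but the shape of the computation is the same: no hidden state, no file I/O, just state in and state out.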

Several tutorials walk through this process, which is probably the easiest way to start. For a full code example of a standalone iterative-process implementation, see simple_fedavg.py.
