我正在尝试用 TensorFlow Federated(TFF)的 Federated Core API 实现联邦学习:
def create_keras_model(freeze_conv_layers=False):
    """Build the CNN classifier used by every TFF computation in this script.

    Architecture: three Conv2D(3x3, stride 1, 'same', ReLU) + MaxPooling2D(2x2)
    stages with 16, 64 and 128 filters, then Flatten and Dense layers of
    128, 64 and 10 units (softmax output). The input shape is (226, 232, 1).

    Args:
        freeze_conv_layers: If True, every Conv2D layer is marked
            ``trainable = False``. Only the three Dense layers at the top are
            then trained (transfer-learning / fine-tuning setup). The default,
            False, reproduces the original fully-trainable model.

    Returns:
        An uncompiled ``Sequential`` model.

    NOTE: with ``freeze_conv_layers=True`` the conv kernels and biases are no
    longer part of ``model.trainable_variables``. Any state built only from
    trainable variables therefore does not carry them. Load pretrained values
    into the frozen layers separately.
    """
    model = Sequential()
    model.add(Conv2D(16, kernel_size=(3, 3), strides=(1, 1), padding='same',
                     activation='relu', input_shape=(226, 232, 1)))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
    model.add(Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='same',
                     activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
    model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), padding='same',
                     activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(10, activation='softmax'))

    if freeze_conv_layers:
        # This must happen before the model is wrapped by
        # tff.learning.from_keras_model, which reads the Keras model's
        # trainable / non-trainable weight split when it wraps it.
        for layer in model.layers:
            if isinstance(layer, Conv2D):
                layer.trainable = False
    return model
def model_fn():
    """Wrap a freshly built Keras model as a ``tff.learning`` model.

    The input spec is taken from the first client dataset in
    ``federated_train_data``. Training uses sparse categorical cross-entropy
    and reports sparse categorical accuracy.
    """
    keras_model = create_keras_model()
    spec = federated_train_data[0].element_spec
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    metric_list = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=spec,
        loss=loss_fn,
        metrics=metric_list,
    )
def initialize_fn():
    """Return the trainable variables of a freshly built model.

    This is the plain (non-federated) sketch. A ``tff.federated_computation``
    with the same name, defined later in this file, replaces it.
    """
    return model_fn().trainable_variables
def next_fn(server_weights, federated_dataset):
    """Conceptual outline of one federated-averaging round (pseudo-code).

    NOTE(review): this version is not runnable as written.
    - ``broadcast`` and ``mean`` are not defined anywhere in the visible code.
    - ``client_update`` in this file takes
      ``(model, dataset, server_weights, client_optimizer)``, not the two
      arguments passed here.
    - ``server_update`` takes ``(model, mean_client_weights)``, not the one
      argument passed here.
    It is shadowed by the ``tff.federated_computation`` also named ``next_fn``
    that is defined later in this file.
    """
    # Broadcast the server weights to the clients.
    server_weights_at_client = broadcast(server_weights)
    # Each client computes their updated weights.
    client_weights = client_update(federated_dataset, server_weights_at_client)
    # The server averages these updates.
    mean_client_weights = mean(client_weights)
    # The server updates its model.
    server_weights = server_update(mean_client_weights)
    return server_weights
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    """Train the client's copy of the model on its local dataset.

    The model's trainable variables are first overwritten with
    ``server_weights``. Then one optimizer step is taken per batch in one
    pass over ``dataset``.

    Args:
        model: A ``tff.learning`` model, as built by ``model_fn``.
        dataset: The client's dataset of batches.
        server_weights: Structure matching ``model.trainable_variables``
            holding the current global weights.
        client_optimizer: Keras optimizer used for the local updates.

    Returns:
        ``model.trainable_variables`` after local training.

    NOTE: only trainable variables are synchronized with the server. Any
    non-trainable (e.g. frozen) weights keep whatever values the freshly
    built ``model`` has.
    """
    client_weights = model.trainable_variables
    # Start local training from the current global model.
    tf.nest.map_structure(lambda x, y: x.assign(y),
                          client_weights, server_weights)

    for batch in dataset:
        with tf.GradientTape() as tape:
            # Only the forward pass needs to be recorded.
            outputs = model.forward_pass(batch)
        # Take the gradient outside the tape context so the gradient
        # computation itself is not recorded on the tape.
        grads = tape.gradient(outputs.loss, client_weights)
        client_optimizer.apply_gradients(zip(grads, client_weights))

    return client_weights
@tf.function
def server_update(model, mean_client_weights):
    """Overwrite the server model's trainable variables with the client mean.

    Returns:
        ``model.trainable_variables`` after the assignment.
    """
    server_vars = model.trainable_variables

    def _copy_into(variable, value):
        return variable.assign(value)

    tf.nest.map_structure(_copy_into, server_vars, mean_client_weights)
    return server_vars
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
    """Average the float32 values held by the clients, using tff.federated_mean."""
    average = tff.federated_mean(client_temperatures)
    return average
@tff.tf_computation(tf.float32)
def add_half(x):
    """TensorFlow computation returning ``x + 0.5`` for a float32 input."""
    return tf.math.add(x, 0.5)
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
    """Apply ``add_half`` to every client's float32 value via federated_map."""
    shifted = tff.federated_map(add_half, x)
    return shifted
@tff.tf_computation
def server_init():
    """Return the initial trainable variables of a freshly built model."""
    return model_fn().trainable_variables
@tff.federated_computation
def initialize_fn():
    """Place the initial model weights from ``server_init`` on the server."""
    initial_weights = server_init()
    return tff.federated_value(initial_weights, tff.SERVER)
# TFF type of the model weights: the result type of server_init.
model_weights_type = server_init.type_signature.result
# Throw-away model instance, built only to read its input spec.
whimsy_model = model_fn()
# TFF type of one client's dataset: a sequence of elements matching that spec.
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    """Run local training for one client inside a TF computation.

    Builds a fresh model and an SGD optimizer (learning rate 0.01), then
    delegates to ``client_update``.
    """
    local_model = model_fn()
    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    return client_update(local_model, tf_dataset, server_weights, sgd)
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
    """Load the averaged client weights into a fresh model and return them."""
    return server_update(model_fn(), mean_client_weights)
# Client-placed datasets: one local sequence of batches per client.
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
# Server-placed model weights: the state carried between rounds.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    """Run one round of federated averaging.

    Returns:
        A 2-tuple ``(new_server_weights, client_weights)``: the updated
        server-placed weights (the process state) and the client-placed
        weights after local training.
    """
    # Server -> clients: every client receives the current global weights.
    weights_on_clients = tff.federated_broadcast(server_weights)
    # Local training on each client.
    trained_client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, weights_on_clients))
    # Clients -> server: mean of the trained weights.
    averaged_weights = tff.federated_mean(trained_client_weights)
    # Server step: load the average into the server model.
    new_server_weights = tff.federated_map(server_update_fn, averaged_weights)
    return new_server_weights, trained_client_weights
# Bundle the initialization and per-round computations into one process
# (positional order: initialize_fn, next_fn).
federated_algorithm = tff.templates.IterativeProcess(initialize_fn, next_fn)

# Initial server state: the freshly initialized model weights.
server_state = federated_algorithm.initialize()
并在每一轮结束后保存 server_state(即服务器端的模型权重):
# Create the checkpoint manager once rather than once per round.
# Retention follows its keep_total / keep_first settings.
checkpoint_manager = FileCheckpointManager(
    root_dir='/content/drive/MyDrive',
    prefix='fed_per_',
    step=1,
    keep_total=1,
    keep_first=True,
)

# `round_num` instead of `round` so the builtin round() is not shadowed.
for round_num in range(3, 15):
    # next_fn returns (new server weights, per-client trained weights).
    server_state, client_weights = federated_algorithm.next(
        server_state, federated_train_data)
    checkpoint_manager.save_checkpoint(state=server_state, round_num=round_num)
现在我想在一个新的联邦学习任务中复用这个预训练模型:卷积(CNN)层的权重保持固定(冻结),只训练最后 3 层(全连接层)的权重。
有人能告诉我怎么做吗?
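可以用一个 for 循环遍历 Keras 模型的各层,并通过 Keras Layer API 的 `trainable` 属性冻结不需要训练的层(例如所有 Conv2D 层)。注意,这一步必须在用 `tff.learning.from_keras_model` 包装模型之前完成。冻结后,这些层的权重不再属于 `trainable_variables`,因此不会出现在 `server_init` 返回的服务器状态中,也不会在 `client_update` 中被赋值。要使用预训练的卷积权重,需要另外从检查点中把它们加载到被冻结的层中: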
# Freeze one Keras layer. Per the Keras Layer API, trainable=False moves its
# weights out of `trainable_weights` so optimizers no longer update them.
# NOTE(review): `layer` is not defined in this snippet. It is assumed to be
# the loop variable of a loop over the layers to freeze (e.g. the Conv2D
# layers), run before the model is wrapped with tff.learning.from_keras_model.
layer.trainable = False