TFF RuntimeError:试图捕获一个没有构建函数的EagerTensor



我有一个TFF模型运行,但我得到了一个错误。我提供了x和y,并像教程一样继续实现它。

TF version = 2.5.1TFF版本= 0.19.0

我的代码片段是
split = len(usr_data_set)
client_train_dataset = collections.OrderedDict()
for i in range(0, split):
client_name = "client_" + str(i)
xx, y = usr_data_set[i] # shape for client one [2441, 13055], for client two 
# [2420, 13055], for client three [2451, 13055]
data = collections.OrderedDict((('x', xx), ('y', y)))

client_train_dataset[client_name] = data
train_dataset = tff.simulation.datasets.TestClientData(client_train_dataset)
sample_dataset = train_dataset.create_tf_dataset_for_client(train_dataset.client_ids[0])
sample_element = next(iter(sample_dataset))
def preprocess(dataset):
NUM_EPOCHS = 5
BATCH_SIZE = 32
PREFETCH_BUFFER = 10
def batch_format_fn(element):
return collections.OrderedDict(
x=reshape(element['x'], [-1, 13055]),
y=reshape(element['y'], [-1, 2]))

return dataset.repeat(NUM_EPOCHS).batch(BATCH_SIZE).map(
batch_format_fn).prefetch(PREFETCH_BUFFER)
preprocessed_sample_dataset = preprocess(sample_dataset)
# sample_batch = nest.map_structure(lambda x: x.numpy(), next(iter(preprocessed_sample_dataset)))
def make_federated_data(client_data, client_ids):
return [preprocess(client_data.create_tf_dataset_for_client(x)) for x in client_ids]

# return make_federated_data(train_dataset, train_dataset.client_ids), preprocessed_sample_dataset
federated_train_data = make_federated_data(train_dataset, train_dataset.client_ids)
# federated_train_data, preprocessed_sample_dataset = tff_dataset(usr_data_set)
losses = tf.keras.losses.CategoricalCrossentropy()
metric = [tf.keras.metrics.CategoricalAccuracy()]
def CNN():
model = Sequential()
model.add(Reshape((13055, 1), input_shape=(13055,)))
model.add(Conv1D(8, kernel_size=7, padding='same', strides=3, activation='relu'))
model.add(MaxPooling1D(4, strides=2, padding='same'))
model.add(Conv1D(128, kernel_size=7, padding='same', strides=3, activation='relu'))
model.add(MaxPooling1D(4, strides=2, padding='same'))
model.add(Conv1D(64, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(MaxPooling1D(4, strides=2, padding='same'))
model.add(Conv1D(64, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(MaxPooling1D(4, strides=2, padding='same'))
model.add(Flatten())
model.add(Dense(units=64, activation='relu'))
model.add(Dense(units=64, activation='relu'))
model.add(Dense(units=2, activation='softmax'))
return model

def model_fn():
keras_model = CNN()
return tff.learning.from_keras_model(
keras_model,
input_spec=preprocessed_sample_dataset.element_spec,
loss=losses,
metrics=metric)
iterative_process = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
print(str(iterative_process.initialize.type_signature))

我读了另一个关于这个错误的帖子,但我的所有函数都在model_fn的范围内,我看不到任何其他问题。

完整的脚本错误是这样的,

File "/Users/amir/Documents/CODE/Python/FedGS/tff_dataset.py", line 175, in <module>
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/federated_averaging.py", line 270, in build_federated_averaging_process
model_update_aggregation_factory=model_update_aggregation_factory)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 631, in build_model_delta_optimizer_process
model_weights_type = model_utils.weights_type_from_model(model_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/model_utils.py", line 100, in weights_type_from_model
model = model()
File "/Users/amir/Documents/CODE/Python/FedGS/tff_dataset.py", line 170, in model_fn
metrics=metric)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 175, in from_keras_model
metrics=metrics))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 304, in __init__
tf.TensorSpec.from_tensor, self.report_local_outputs())
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
result = self._call(*args, **kwds)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 957, in _call
filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1974, in _call_flat
flat_outputs = forward_function.call(ctx, args_with_tangents)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 625, in call
executor_type=executor_type)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1189, in partitioned_call
args = [ops.convert_to_tensor(x) for x in args]
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1189, in <listcomp>
args = [ops.convert_to_tensor(x) for x in args]
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/profiler/trace.py", line 163, in wrapped
return func(*args, **kwargs)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1525, in convert_to_tensor
raise RuntimeError("Attempting to capture an EagerTensor without "
RuntimeError: Attempting to capture an EagerTensor without building a function.

谁能帮我解决这个问题?我想尽办法想解决这个问题,但没有成功。

我相信您需要model_fn中创建由lossesmetric变量持有的对象。像这样:

def model_fn():
keras_model = CNN()
return tff.learning.from_keras_model(
keras_model,
input_spec=preprocessed_sample_dataset.element_spec,
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.CategoricalAccuracy()])

问题是Keras指标通常创建需要在序列化中捕获的tf.Variable

最新更新