如何使用TFF访问标签



我遵循这个图像分类教程和文本生成教程。所以我已经在我的数据集上实现了迁移学习,但我不知道在我做预测的时候如何访问标签。我将数据转换成正确的形状(tf.data.Dataset),因此我使用Keras模型进行预测。例如,如果我只想预测一个标签:keras_model.predict(federated_train_data[0])

federated_train_data包含以下元素:

(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(None,), dtype=tf.int64, name=None))

第一个张量是图像形状,第二个张量表示编码标签。

我的目标是说明什么是图像的真实和预测标签,例如:(预测类)

TLDR:当你有tf.data.Dataset时,是否有一种方法可以访问标签?

如果federated_train_datatf.data.Dataset.element_spec属性返回:

(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(None,), dtype=tf.int64, name=None))

则可以遍历数据集:

# Get the first batch
first_batch = next(iter(federated_train_data)) 
# Examine all batches
for batch in federated_train_data:
print(batch)

.element_spec我们知道每批是(features, labels)的一个2元组,所以我们可以使用第二个索引获得标签:

labesl = first_batch[1]
# Or unpack
features, labels = first_batch

结合模型预测:

for batch in federated_train_data:
features, labels = batch
predictions = keras_model.predict(features)
# Now we have all three pieces: features, labels, and predictions.

最新更新