我正在尝试学习两个图像嵌入之间的相似性矩阵(M),单个训练实例是一对图像-(锚,正)。因此,理想情况下,模型将为相似图像的嵌入返回0距离。
问题是,当我将距离矩阵(M)声明为tf时。变量,则返回错误在这条线上self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
TypeError: 'Variable' object is not iterable.
我认为我应该为M使用一个tensorflow数据类型,它是可迭代的
请告诉我如何解决这个问题
import tensorflow as tf
from tensorflow import keras
# metric learning model
class MetricLearningModel:
def __init__(self, lr):
self.optimizer = keras.optimizers.Adam(lr=lr)
self.lr = lr
self.loss_object = keras.losses.MeanSquaredError()
self.trainable_variables = tf.Variable(
(tf.ones((2048, 2048), dtype=tf.float32)),
trainable=True
)
def similarity_function(self, anchor_embeddings, positive_embeddings):
M = self.trainable_variables
X_i = anchor_embeddings
X_j = positive_embeddings
similarity_value = tf.matmul(X_j, M, name='Tensor')
similarity_value = tf.matmul(similarity_value, tf.transpose(X_i), name='Tensor')
# distance(x,y) = sqrt( (x-y)@M@(x-y).T )
return similarity_value
def train_step(self, anchor, positive):
anchor_embeddings, positive_embeddings = anchor, positive
# Calculate gradients
with tf.GradientTape() as tape:
# Calculate similarity between anchors and positives.
similarities = self.similarity_function(anchor_embeddings, positive_embeddings)
y_pred = similarities
y_true = tf.zeros(1)
print(y_true, y_pred)
loss_value = self.loss_object(
y_pred=y_true,
y_true=y_pred,
)
gradients = tape.gradient(loss_value, self.trainable_variables)
# Apply gradients via optimizer
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
metric_model = MetricLearningModel(lr=1e-3)
anchor, positive = tf.ones((1, 2048), dtype=tf.float32), tf.ones((1, 2048), dtype=tf.float32)
metric_model.train_step(anchor, positive)
pythonzip
函数需要可迭代的对象,例如列表或元组。
在你调用tape.gradient
,或optimizer.apply_gradients
,你可以把你的变量在一个列表来解决问题:
with tf.GradienTape() as tape:
gradients = tape.gradient(loss_value, [self.trainable_variables])
# Apply gradients via optimizer
self.optimizer.apply_gradients(zip(gradients, [self.trainable_variables]))
tape.gradient
尊重传递给计算梯度的sources
对象的形状,所以如果你给它一个列表,你将从中得到一个列表。在文档中有说明:
的回报张量的列表或嵌套结构(或IndexedSlices,或None),源中的每个元素一个。返回的结构与源的结构相同。