我试图在TensorFlow中实现本文中的CSRA模块,而不是PyTorch,使用MobileNetV3作为特征提取器。
我只精通TensorFlow,我在从PyTorch转换到TensorFlow时遇到了一些麻烦。
实现该模块的官方源代码使用了一些在TensorFlow中不常用的函数,这些函数的tf文档提供的帮助很少。
源代码如下:
class CSRA(nn.Module): # one basic block
def __init__(self, input_dim, num_classes, T, lam):
super(CSRA, self).__init__()
self.T = T # temperature
self.lam = lam # Lambda
self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
self.softmax = nn.Softmax(dim=2)
def forward(self, x):
# x (B d H W)
# normalize classifier
# score (B C HxW)
score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
score = score.flatten(2)
base_logit = torch.mean(score, dim=2)
if self.T == 99: # max-pooling
att_logit = torch.max(score, dim=2)[0]
else:
score_soft = self.softmax(score * self.T)
att_logit = torch.sum(score * score_soft, dim=2)
return base_logit + self.lam * att_logit
score_soft = self.softmax(score * self.T)
att_logit = torch.sum(score * score_soft, dim=2)
return base_logit + self.lam * att_logit
我不需要将CSRA定义为类,我试图将模块拼接在一起作为keras功能模型。下面是我整理的代码:
inputs = tf.keras.Input(shape=(None, None, 3))
head = tf.keras.layers.Conv2D(2, kernel_size=1, padding='same', use_bias=False, input_shape=(None, None, None, 960))
features = base_model(inputs, training=False)
print(head.get_weights())
score = head(features) / tf.transpose((tf.linalg.normalize(head.get_weights(), axis=3)), perm=(0,1))
shape = scores.get_shape().as_list()
score = tf.reshape(score, [-1, shape[1] * shape[2], shape[3]])
# scores = tf.reshape(scores[1:2])
avg_scores = tf.keras.backend.mean(score, axis=1)
max_scores_act = tf.keras.activations.softmax(score, axis=1)
max_scores = tf.math.reduce_sum(max_scores_act * score, axis=1)
outputs = (avg_scores + max_scores*0.2)
model = tf.keras.Model(inputs, outputs)
我遇到的主要问题是在TensorFlow中获得score
的步骤。我目前面临的问题似乎是,head.get_weights()
返回的对象是一个空列表。我明白,因为head
是"空的",它的权重是零,但如何定义score
呢?
除此之外,我对翻译有很多疑问。由于pyTorch使用"通道优先";和TensorFlow "通道最后",我不得不改变在每个步骤中受到影响的暗,但我不知道我是否做得对,如果PyTorch中的所有操作都是TensorFlow所必需的。
在对python文档进行了相当多的挖掘之后,我基本弄清楚了
让我失望的是,代码不是使用预制层,而是从底层开始实现层
关于对FC层权重的访问,pytorch的weight
属性与tensorflow的get_weights()方法之间存在根本性的区别。前者返回层输出的权重,而后者返回层的权重。无论如何,这一步对于权值归一化是必要的,这在tensorflow