Depthwise cross-correlation with TensorFlow 2



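I want to implement the depthwise cross-correlation layer described in SiamRPN++ with TensorFlow 2 and Keras. It should be a subclass of the Keras Layer class so that it can be used flexibly. My implementation compiles correctly, but TensorFlow throws an error during training:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Specified a list with shape [8,24,32] from a tensor with shape [24,8,32]

Here is my code. What am I doing wrong?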

    import tensorflow as tf
    from tensorflow.keras.layers import Layer

    class CrossCorr(Layer):
        """
        Implements the cross correlation layer of siam_rpn_plus_plus
        """
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def build(self, inputs):
            super(CrossCorr, self).build(inputs)  # Be sure to call this at the end

        def call(self, inputs, **kwargs):
            def _corr(search_img, filter):
                x = tf.expand_dims(search_img, 0)
                f = tf.expand_dims(filter, -1)
                # use the feature map as kernel for the depthwise conv2d of tensorflow
                return tf.nn.depthwise_conv2d(input=x, filter=f, strides=[1, 1, 1, 1], padding='SAME')

            # Iteration over each batch
            out = tf.map_fn(
                lambda filter_simg: _corr(filter_simg[0], filter_simg[1]),
                elems=inputs,
                dtype=inputs[0].dtype
            )
            return tf.squeeze(out, [1])

        def compute_output_shape(self, input_shape):
            return input_shape

I call it like this:

    from tensorflow.keras.activations import relu
    from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D
    from tensorflow.keras.regularizers import L1L2

    def _conv_block(inputs, filters, kernel, strides, kernel_regularizer=None):
        x = Conv2D(filters, kernel, padding='same', strides=strides,
                   kernel_regularizer=kernel_regularizer)(inputs)
        x = BatchNormalization()(x)
        return Activation(relu)(x)

    def cross_correlation_layer(search_img, template_img, n_filters=None):
        n_filters = int(search_img.shape[-1]) if n_filters is None else n_filters
        tmpl = _conv_block(template_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))
        search = _conv_block(search_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))
        # calculate cross correlation by striding the generated "filter" over the image in depthwise manner
        cc = CrossCorr()([search, tmpl])
        # 1x1 conv to make it a separable convolution
        cc = Conv2D(filters=n_filters, kernel_size=1, strides=1)(cc)
        # apply one more filter over it
        fusion = _conv_block(cc, n_filters, 3, 1)
        return fusion
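
For reference, the operation the CrossCorr layer is meant to compute can be spelled out in plain Python: each channel of the search feature map is correlated with the same channel of the template, with no mixing across channels. This is only a minimal sketch for checking results on tiny inputs. The function name `depthwise_xcorr` and the nested-list layout are made up for illustration, and it evaluates only the "valid" positions, whereas the layer above uses padding='SAME'.

```python
def depthwise_xcorr(search, template):
    """Plain-Python reference for per-channel (depthwise) cross-correlation.

    search:   nested list with shape [H][W][C]
    template: nested list with shape [h][w][C]
    returns:  nested list with shape [H-h+1][W-w+1][C] ("valid" positions only)
    """
    H, W, C = len(search), len(search[0]), len(search[0][0])
    h, w = len(template), len(template[0])
    if len(template[0][0]) != C:
        raise ValueError("search and template need the same number of channels")

    out = []
    for i in range(H - h + 1):
        row = []
        for j in range(W - w + 1):
            # One output value per channel: the template is slid over the
            # search map without flipping, and channels are never mixed.
            row.append([
                sum(search[i + u][j + v][c] * template[u][v][c]
                    for u in range(h) for v in range(w))
                for c in range(C)
            ])
        out.append(row)
    return out


# Hypothetical toy data: a 3x3 search map and a 2x2 template, 2 channels each.
search = [[[float(r * 3 + col), 1.0] for col in range(3)] for r in range(3)]
template = [[[1.0, 2.0] for _ in range(2)] for _ in range(2)]
result = depthwise_xcorr(search, template)
print(result[0][0])  # top-left output position, one value per channel
```

Comparing a layer's output against a slow reference like this on a small input is one way to confirm that the batch, spatial, and channel axes are in the order the layer expects.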

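After trying to run your code, I realized that the layers you call through _conv_block expect a batch of images. Here is the beginning of your cross_correlation_layer function, modified from the version above: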

    def cross_correlation_layer(search_img, template_img, n_filters=None):
        n_filters = int(search_img.shape[-1]) if n_filters is None else n_filters
        template_img = tf.expand_dims(template_img, -1)
        search_img = tf.expand_dims(search_img, -1)
        tmpl = _conv_block(template_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))
        search = _conv_block(search_img, n_filters, 3, 1, kernel_regularizer=L1L2(1e-5, 1e-4))

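The rest of your function stays the same. Hope this helps! (Unrelated to your question, but "filter" is the name of a Python built-in function, so it is best to avoid using it as a parameter name, since doing so shadows the built-in inside that function.)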
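
A quick check with the standard library shows how Python classifies the name "filter". The variable names below are only for illustration.

```python
import builtins
import keyword

# A reserved keyword (like "class" or "def") cannot be used as an identifier.
is_reserved_keyword = keyword.iskeyword("filter")

# A built-in name can be reassigned, but doing so hides the original
# function within that scope.
is_builtin_name = hasattr(builtins, "filter")

print(is_reserved_keyword)  # False
print(is_builtin_name)      # True
```

So using it as a parameter name is legal, but the built-in filter() is then unavailable under that name inside the function.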
