删除张量切片作为TensorFlow层的一部分



版本号:tensorflow 2.3.0, numpy 1.18.5, python 3.8.2

我想在我的TensorFlow模型中使用第一层删除输入张量的一些选定切片。例如,我有一个(180, 90, 25)的输入形状(其中180是批大小),我想从最后一个维度删除索引indices = [3, 4, 5, 6, 7, 22, 23, 24]的列表,这样,在调用输入张量上的这一层之后,我将得到一个形状为(180, 90, 25 - len(indices))的张量,其中每个选定的(180, 90)形状的张量切片都通过索引最后一个维度来删除。

目前,我使用这个图层:

class RemoveSelectedIndices(tf.keras.layers.Layer):
def __init__(self, indices=[3,4,5,6,7,22,23,24]):
super(RemoveSelectedIndices, self).__init__(name="RemoveSelectedIndices")
self.indices = self.add_weight(name="indices", shape=len(indices), dtype=tf.int32, trainable=False,
initializer=lambda *args, **kwargs: indices)
def build(self, input_shape):
pass
def call(self, input_tensor):
X = tf.unstack(input_tensor, num=input_tensor.shape[-1], axis=2) # list of 25 (180, 90)-shaped slices
indices = sorted(list(self.indices.value().numpy()))
for i in reversed(indices):
del X[i]
X = tf.stack(X, axis=2) # restacking the list back together
return X

当我测试它时(通过创建numpy数组并使用tf.convert_to_tensor然后调用张量上的层),这工作得非常好,但是当我尝试使用此层作为第一层构建模型时,我得到一个错误:

import tensorflow as tf
from tensorflow.keras.layers import Input
inputs = Input(shape=(90, 25))
X = RemoveSelectedIndices()(inputs)
# gives me AttributeError: 'Tensor' object has no attribute 'numpy'
# references the line indices = sorted(list(self.indices.value().numpy()))

为什么会发生这种情况,我有什么办法可以绕过它吗?

(注意:我知道我可以对数据本身这样做,但是数据集是巨大的,除非我不得不这样做,否则我宁愿不要过多地扰乱数据集。)

提前感谢!

对于这种操作,您可以简单地使用Lambda层

to_remove = [3,4,5,6,7,22,23,24]
def select(X, to_remove):
X = tf.stack([X[...,i] for i in range(X.shape[-1]) if i not in to_remove], -1)
return X
inputs = Input(shape=(90, 25))
x = Lambda(lambda x: select(x, to_remove))(inputs)

相关内容

  • 没有找到相关文章

最新更新