Permute Layer: 2减去3的负维度大小



我有两个传感器输入,我之前已经应用了Concatenate层进行融合。它们都是时间序列数据,我现在尝试应用一个排列层。然而,当我这样做时,我得到错误:

对于输入形状为[?]的"{{node conv1d_334/conv1d}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1,1,1,1], explicit_paddings=[], padding="VALID", strides=[1,1,1,1], use_cudnn_on_gpu=true](conv1d_334/conv1d/ExpandDims_1, conv1d_334/conv1d/ExpandDims_1)",由2减去3引起的负维度大小:[?], 1 2 6249],[1、3、6249、128]。

我的输入都是输入维数为(1176, 6249, 1)的时间序列数据。有人能告诉我我哪里做错了吗?下面是一个示例代码:

lr = 0.0005
n_timesteps = 3750
n_features = 1
n_outputs = 3
def small_model(optimizer='rmsprop', init='glorot_uniform'):
signal1 = Input(shape=(X_train.shape[1:]))
signal2 = Input(shape=(X_train_phase.shape[1:]))

concat_signal = Concatenate()([signal1, signal2])

# x = InputLayer(input_shape=(None, X_train.shape[1:][0],1))(inputA)

x = Permute(dims=(2, 1))(concat_signal)
x = BatchNormalization()(x)
x = Conv1D(64, 5, activation='relu', kernel_initializer='glorot_normal')(x) #, input_shape=(None, 3750, n_features)
x = Conv1D(64, 5, activation='relu', kernel_initializer='glorot_normal')(x)
x = MaxPooling1D(5)(x)
x = Dropout(0.3)(x)

你的问题是,当你得到卷积时,你的时间维度(2)小于你指定的过滤器(5)。

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Permute
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv1D

# dummy data w/batch 32
X_train = tf.random.normal([32, 6249, 1])
X_train_phase = tf.random.normal([32, 6249, 1])

signal1 = Input(shape=(X_train.shape[1:]))
signal2 = Input(shape=(X_train_phase.shape[1:]))
concat_signal = Concatenate()([signal1, signal2])
x = Permute(dims=(2, 1))(concat_signal)
x = BatchNormalization()(x)
print(x.shape)
# (None, 2, 6249)

如果你看到tf.keras.layers.Conv1D的文档,你会注意到"valid"是默认的填充,这意味着没有填充。有很好的参考资料,《深度卷积算法指南》learn ",它很好地说明了输入大小、内核大小、跨步和填充之间的关系。

虽然我不确定你想用这个网络完成什么,但将参数padding="same"添加到你的卷积层将毫无问题地发送输入。

x = Conv1D(
filters=64,
kernel_size=5,
activation="relu",
padding="same",  # <= add this.
kernel_initializer="glorot_normal")(x)

最新更新