具有2x2输入的双向GRU



我正在构建一个网络,它将字符串分解为单词,单词分解为字符,嵌入每个字符,然后通过将字符聚合为单词和单词聚合为字符串来计算该字符串的向量表示。聚合是通过双向gru层进行的。
要测试这个东西,假设我对这个字符串中的5个单词和5个字符感兴趣。在这个例子中,我的变换是:

["Some string"] -> ["Some","strin","","",""] -> 
["Some_","string","_____","_____","_____"] where _ is the padding symbol ) -> 
[[1,2,3,4,0],[1,5,6,7,8],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] (shape 5x5)

接下来我有一个嵌入层,它把每个字符变成一个长度为6的嵌入向量。所以特征变成了一个5x5x6矩阵。然后我将这个输出传递给双向gru层,并执行一些我认为在这种情况下不重要的其他操作。

问题是当我用迭代器运行它时,比如

for string in strings:
output = model(string)

它似乎工作得很好(字符串是一个从5x5片创建的tf数据集),所以它是一堆5 × 5矩阵。

然而,当我转移到训练,或者在数据集级别使用像预测这样的函数时,模型失败了:

model.predict(strings.batch(1))
ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 5, 5, 6)

据我从文档中了解,双向层以3d张量作为输入:[batch, timesteps, feature],所以在这种情况下,我的输入形状应该看起来像:[batch_size,timesteps,(5,5,6)]

那么问题是我应该对输入数据应用哪种转换来获得这种形状?

对于双向输入层,如果您使用GRU,则使用return_sequences=True,以获得三维输出。因为GRU输出是2D的,所以return_sequences会给你3D输出。对于堆叠的双向层,输入的形状应为3D。

样例代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()
model.add(
layers.Bidirectional(layers.GRU(64, return_sequences=True), input_shape=(5, 10))
)
model.add(layers.Bidirectional(layers.GRU(32)))
model.add(layers.Dense(10))
model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_3 (Bidirection (None, 5, 128)            38400     
_________________________________________________________________
bidirectional_4 (Bidirection (None, 64)                41216     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650       
=================================================================
Total params: 80,266
Trainable params: 80,266
Non-trainable params: 0
___________________________

相关内容

  • 没有找到相关文章

最新更新