


class Attention(tf.keras.Model):
    """Additive (Bahdanau-style) attention over a sequence of feature vectors.

    Scores every position of ``features`` against a single query vector
    ``hidden`` and returns the attention-weighted sum of ``features``.
    """

    def __init__(self, units):
        """
        Args:
            units: Width of the projection used to compute the attention scores.
        """
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)  # projects the features (keys)
        self.W2 = tf.keras.layers.Dense(units)  # projects the query
        self.V = tf.keras.layers.Dense(1)       # collapses each position to one score

    def call(self, features, hidden):
        """Attend over ``features`` using ``hidden`` as the query.

        Args:
            features: Rank-3 tensor (batch_size, seq_len, feature_dim), e.g. the
                CNN_Encoder output. It must carry a time axis: a rank-2 tensor
                would broadcast against the (batch_size, 1, units) query across
                the *batch* axis and mix different examples together.
            hidden: Rank-2 query tensor (batch_size, query_dim).

        Returns:
            context_vector: (batch_size, feature_dim), weighted sum of ``features``.
            attention_weights: (batch_size, seq_len, 1), softmax-normalised over
                seq_len.
        """
        # (batch_size, 1, query_dim): add a time axis so the query broadcasts
        # over every position of `features`.
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        # (batch_size, seq_len, units)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        # (batch_size, seq_len, 1): V yields one scalar per position; the
        # softmax normalises over the sequence axis (axis=1).
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        # (batch_size, feature_dim) after summing out the sequence axis.
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights
class CNN_Encoder(tf.keras.Model):
    """Project pre-extracted CNN image features into the embedding space.

    The CNN features are extracted ahead of time (per the original notes they
    were dumped with pickle). This "encoder" is therefore only a fully
    connected layer followed by a ReLU.
    """

    def __init__(self, embedding_dim):
        """
        Args:
            embedding_dim: Output width of the fully connected projection.
        """
        super(CNN_Encoder, self).__init__()
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        """Apply the fully connected projection followed by a ReLU.

        Args:
            x: Feature tensor whose last axis is projected; presumably
                (batch_size, num_locations, feature_dim). TODO confirm
                num_locations: the original comments disagree (49 here, 64 in
                the attention layer).

        Returns:
            Tensor with the same leading axes as ``x`` and last axis
            ``embedding_dim``.
        """
        return tf.nn.relu(self.fc(x))
class RDN_Decoder(tf.keras.Model):
    """Two-LSTM caption decoder with visual and reflective attention.

    Each call decodes one time step:
      1. embed the current token and run it through ``lstm1``;
      2. attend over the image features;
      3. run ``lstm2``;
      4. project its output to a distribution over the vocabulary.
    """

    def __init__(self, embedding_dim, units, vocab_size):
        """
        Args:
            embedding_dim: Width of the token embedding.
            units: Hidden size of both LSTMs and of the attention projections.
            vocab_size: Number of output tokens.
        """
        super(RDN_Decoder, self).__init__()
        self.units = units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # NOTE(review): the LSTM constructor arguments were truncated in the
        # source. `call` unpacks (output, h, c) and reshapes a rank-3 output,
        # which requires return_sequences=True and return_state=True. Any other
        # original arguments (e.g. initializers) are unknown; confirm against
        # the original notebook.
        self.lstm1 = tf.keras.layers.LSTM(
            self.units,
            return_sequences=True,
            return_state=True,
        )
        self.lstm2 = tf.keras.layers.LSTM(
            self.units,
            return_sequences=True,
            return_state=True,
        )
        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)
        self.visual_attention = Attention(self.units)
        self.reflective_attention = Attention(self.units)

    def call(self, x, features, hidden_state1, hidden_state2):
        """Decode a single time step.

        Args:
            x: Token ids; presumably (batch_size, 1), one token per example.
                The concatenations below only line up for a single time step.
            features: Encoded image features attended by the visual attention,
                presumably the CNN_Encoder output
                (batch_size, num_locations, embedding_dim).
            hidden_state1: Previous hidden state of ``lstm1``, (batch_size, units).
            hidden_state2: Previous hidden state of ``lstm2``, (batch_size, units).

        Returns:
            predictions: (batch_size, vocab_size) softmax probabilities.
            hidden_state1: New hidden state of ``lstm1``, (batch_size, units).
            hidden_state2: New hidden state of ``lstm2``, (batch_size, units).
            attention_weights_v: Visual attention weights,
                (batch_size, num_locations, 1).
            attention_weights_r: Reflective attention weights, (batch_size, 1, 1).
        """
        # (batch_size, 1, embedding_dim)
        embedded = self.embedding(x)
        # lstm1 input = [previous h1 ; token embedding]
        # -> (batch_size, 1, units + embedding_dim)
        lstm1_input = tf.concat([tf.expand_dims(hidden_state1, 1), embedded], axis=-1)
        # NOTE(review): previous states are fed as *inputs*, not via
        # `initial_state`, so each LSTM starts from a zero state on every
        # step. Confirm that this is intended.
        _, hidden_state1, _ = self.lstm1(lstm1_input)

        context_vector_v, attention_weights_v = self.visual_attention(features, hidden_state1)

        # lstm2 input = [visual context ; previous h2 ; lstm1 input]
        lstm2_input = tf.concat(
            [
                tf.expand_dims(context_vector_v, 1),
                tf.expand_dims(hidden_state2, 1),
                lstm1_input,
            ],
            axis=-1,
        )
        output2, hidden_state2, _ = self.lstm2(lstm2_input)

        # Reflective attention. Passing the rank-2 hidden_state2 directly (as
        # before) broadcast it against the query across the *batch* axis. That
        # built (batch, batch, units) tensors, normalised the softmax over
        # other examples, and likely inflated GPU memory per step. Adding a
        # time axis keeps every example attending only to its own state.
        # TODO(review): reflective attention is meant to attend over the
        # history of lstm2 hidden states. Only the current state is available
        # here, so the weights are trivially 1.
        _, attention_weights_r = self.reflective_attention(
            tf.expand_dims(hidden_state2, 1), hidden_state1
        )

        # (batch_size, time_steps, units)
        projected = self.fc1(output2)
        # (batch_size * time_steps, units)
        projected = tf.reshape(projected, (-1, projected.shape[2]))
        # (batch_size * time_steps, vocab_size)
        logits = self.fc2(projected)
        # NOTE(review): if the training loop's loss_function expects logits
        # (from_logits=True), this softmax is applied twice. Confirm.
        predictions = tf.nn.softmax(logits)
        return predictions, hidden_state1, hidden_state2, attention_weights_v, attention_weights_r

    def reset_state(self, batch_size):
        """Return an all-zero (batch_size, units) tensor for use as a starting hidden state."""
        return tf.zeros((batch_size, self.units))
# Build the encoder/decoder pair. embedding_dim, units and vocab_size are
# hyperparameters defined outside this snippet.
encoder = CNN_Encoder(embedding_dim=embedding_dim)
decoder = RDN_Decoder(
    embedding_dim=embedding_dim,
    units=units,
    vocab_size=vocab_size,
)


--------------------------------------------------------------------------- ResourceExhaustedError                    Traceback (most recent call last) <ipython-input-63-e33dbe296f4b> in <module>()
13     for (batch, (img_tensor, target)) in enumerate(dataset):
---> 14         batch_loss, t_loss = train_step(img_tensor, target)
15         total_loss += t_loss
13 frames <ipython-input-62-b355d0692cf8> in train_step(img_tensor, target)
15         for i in range(1, target.shape[1]):
16             # passing the features through the decoder
---> 17             predictions, hidden_state1, hidden_state2, _, _ = decoder(dec_input, features, hidden_state1, hidden_state2)
19             loss += loss_function(target[:, i], predictions)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
966           with base_layer_utils.autocast_context_manager(
967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
969           self._handle_activity_regularization(inputs, outputs)
970           self._set_mask_metadata(inputs, outputs, input_masks)
<ipython-input-57-83f30c4f738b> in call(self, x, features, hidden_state1, hidden_state2)
81         # passing the concatenated vector to the lstm
---> 82         output2, hidden_state2, cell_state2 = self.lstm2(x)
84         # reflective attention as a separate model
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
653     if initial_state is None and constants is None:
--> 654       return super(RNN, self).__call__(inputs, **kwargs)
656     # If any of `initial_state` or `constants` are specified and are Keras
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
966           with base_layer_utils.autocast_context_manager(
967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
969           self._handle_activity_regularization(inputs, outputs)
970           self._set_mask_metadata(inputs, outputs, input_masks)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in call(self, inputs, mask, training, initial_state)    1179         if can_use_gpu:    1180           last_output, outputs, new_h, new_c, runtime = gpu_lstm(
-> 1181               **gpu_lstm_kwargs)    1182         else:    1183           last_output, outputs, new_h, new_c, runtime = standard_lstm(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major, go_backwards, sequence_lengths)    1390       biases=array_ops.split(full_bias, 8),    1391       shape=constant_op.constant([-1]),
-> 1392       transpose_weights=True)    1393     1394   if mask is not None:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in _canonical_to_params(weights, biases, shape, transpose_weights)    1234     return array_ops.transpose(w) if transpose_weights else w    1235 
-> 1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]    1237   biases = [array_ops.reshape(x, shape) for x in biases]    1238   return array_ops.concat(weights + biases, axis=0)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in <listcomp>(.0)    1234     return array_ops.transpose(w) if transpose_weights else w    1235 
-> 1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]    1237   biases = [array_ops.reshape(x, shape) for x in biases]    1238   return array_ops.concat(weights + biases, axis=0)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in convert(w)    1232   """    1233   def convert(w):
-> 1234     return array_ops.transpose(w) if transpose_weights else w    1235     1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in transpose(a, perm, name, conjugate)    2127     else:    2128       perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
-> 2129     return transpose_fn(a, perm, name=name)    2130     2131 
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in transpose(x, perm, name)   11176         pass  # Add nodes to the TensorFlow graph.   11177     except _core._NotOkStatusException as e:
> 11178       _ops.raise_from_not_ok_status(e, name)   11179   # Add nodes to the TensorFlow graph.   11180   _, _, _op, _outputs =
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)    6651   message = e.message + (" name: " + name if name is not None else "")    6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)    6654   # pylint: enable=protected-access    6655 
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
ResourceExhaustedError: OOM when allocating tensor with shape[1024,1024] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Transpose]


如果你想自己尝试,请使用这个 Google Colab 链接(数据生成和训练的代码很长,全部贴在这里会显得很臃肿)。只需按顺序运行各个单元格即可。


  1. 换一台更好的机器,或租用能提供更多内存(尤其是 GPU 显存)的云服务。
  2. 缩小神经网络的规模。我尝试了以下超参数,训练可以正常运行:

`BATCH_SIZE = 8`,`embedding_dim = 512`,`units = 512`。其余所有超参数保持不变。

