我有视频字幕项目的数据集。用于训练的数据集管道构建为:
dataset = tf.data.Dataset.from_tensor_slices((videos , tf.ragged.constant(captions)))
我想读取进入训练步骤的所有batch_data,它看起来像:
class VideoCaptioningModel(keras.Model):
.
.
.
def train_step(self, batch_data):
batch_img, batch_seq = batch_data
batch_loss = 0
batch_acc = 0
print('batch_data=', batch_data)
.
.
输出为:
batch_data= (<tf.Tensor 'IteratorGetNext:0' shape=(None, 28, 1536) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, None, 8) dtype=int64>)
我尝试使用print('batch_data=', batch_data.numpy())
但我得到了:
AttributeError: 'tuple' object has no attribute 'numpy'
您的数据集由videos
和captions
组成,数据集中的每个条目都是tuple
。参见:
for x in dataset:
tf.print(x[0]) # videos
tf.print(x[1]) # captions
现在,请注意,可以在Eager Execution
模式下对tf.Tensor
调用.numpy()
,但元组没有此属性。所以试试:
tf.print('batch_data=', batch_data[0].numpy())
tf.print('batch_data=', batch_data[1].numpy())