我是Tensorflow新手,如果我的问题太基础或太愚蠢,请耐心等待;(
我试图在";用于语言理解的转换器模型"-Tensorflow网站教程(https://www.tensorflow.org/tutorials/text/transformer)。我的意图是让我的测试运行更快,当玩代码的时候。
我想我可以使用dataset.take(n)
方法来缩短训练数据集。我在从文件中读取原始数据集后添加了两行:
...
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']
# lines added to reduce dataset size
train_examples = train_examples.take(1000)
val_examples = val_examples.take(1000)
...
所得到的数据集(即train_examples
、val_examples
(似乎具有预期的大小,并且它们似乎正在工作,例如,与turorial中的下一个标记器一起工作。
然而,当我执行代码时,更具体地说,当它进入训练(即train_step(inp, tar)
(时,我会收到大量的错误消息和警告。错误消息和警告太长,无法在此处复制,但其中可能有一个重要部分:
...
/home/kst/python/tf/tf_env/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1105 set_shape
raise ValueError(
ValueError: Tensor's shape (8216, 128) is not compatible with supplied shape (4870, 128)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
...
训练部分中的一些张量的大小或形状似乎不合适。
.take(n)
不是Tensorflow中缩短数据集的好方法,有充分的理由吗?
有更好的方法吗?
谢谢!:(
我发现了问题!导致错误消息的不是.take()
方法。是检查点经理:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(transformer=transformer,
optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
我的磁盘上似乎有旧的检查点,这些检查点是在使用较大数据集运行时创建的。检查点管理器自动恢复了旧的检查点,并希望从那里继续。当然,旧张量的大小和形状与新张量不匹配(例如,词汇大小不同(。这创建了错误消息。当我删除旧的检查点(I.d.,.ipynb_checkpoints
目录(时,一切都很顺利!:-(