我正在尝试修改代码以使用多个 GPU 训练我的转换器图像标题模型以进行平行主义。我不知道"列表"指的是什么



当我字符串修改代码以使用多个gpu训练我的图像字幕模型时,会出现错误。我不知道"列表"指的是什么。也许是输入的问题,但我不知道为什么是错误的。

Meshed-Memory Transformer Training
Let's use 3 GPUs!
Training starts
Epoch 0 - train:   0%|                                                                                       | 0/9440 [00:08<?, ?it/s]
Traceback (most recent call last):
File "train.py", line 257, in <module>
train_loss = train_xe(model, dataloader_train, optim, text_field)
File "train.py", line 82, in train_xe
out = model(detections, captions)
File "/data/zzw/anaconda3/envs/m2release/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/data/zzw/anaconda3/envs/m2release/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/data/zzw/anaconda3/envs/m2release/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
return replicate(module, device_ids)
File "/data/zzw/anaconda3/envs/m2release/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 174, in replicate
replica._buffers[key] = buffer_copies[j][buffer_idx]
IndexError: list index out of range

我知道这已经有一段时间了,但对于仍在挣扎的人来说,我在使用torch==1.1.0时也遇到了同样的问题。我通过将我的手电筒升级为1.7.1解决了这个问题

相关内容

  • 没有找到相关文章