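I switched from a three-year-old version of PyTorch to stable PyTorch 1.9 on CentOS 7 (GPU-based). With the original paper's code unchanged, I get the following error. Is there a quick fix?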
(fashcomp) [jalal@goku fashion-compatibility]$ python main.py --name test_baseline --learned --l2_embed --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+ Number of params: 3191808
Traceback (most recent call last):
File "main.py", line 322, in <module>
main()
File "main.py", line 167, in main
train(train_loader, tnet, criterion, optimizer, epoch)
File "main.py", line 194, in train
for batch_idx, (img1, desc1, has_text1, img2, desc2, has_text2, img3, desc3, has_text3, condition) in enumerate(train_loader):
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "mtrand.pyx", line 905, in numpy.random.mtrand.RandomState.choice
TypeError: 'dict_keys' object cannot be interpreted as an integer
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/scratch3/research/code/fashion/fashion-compatibility/polyvore_outfits.py", line 338, in __getitem__
neg_im = self.sample_negative(outfit_id, pos_im, item_type)
File "/scratch3/research/code/fashion/fashion-compatibility/polyvore_outfits.py", line 235, in sample_negative
choice = np.random.choice(candidate_sets)
File "mtrand.pyx", line 907, in numpy.random.mtrand.RandomState.choice
ValueError: a must be 1-dimensional or an integer
and
$ pip freeze
absl-py==0.13.0
argon2-cffi==20.1.0
attrs==21.2.0
backcall==0.2.0
bleach==4.1.0
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
defusedxml==0.7.1
entrypoints==0.3
google-auth==1.35.0
google-auth-oauthlib==0.4.5
grpcio==1.39.0
h5py==3.3.0
idna==3.2
importlib==1.0.4
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==3.0.1
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==7.0.1
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
mistune==0.8.4
nbclient==0.5.4
nbconvert==6.1.0
nbformat==5.1.3
nest-asyncio==1.5.1
notebook==6.4.3
numpy==1.21.2
oauthlib==3.1.1
packaging==21.0
pandas==1.3.2
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prometheus-client==0.11.0
prompt-toolkit==3.0.20
protobuf==3.17.3
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.10.0
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
qtconsole==5.1.1
QtPy==1.10.0
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.7.1
Send2Trash==1.8.0
six==1.16.0
sklearn==0.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
terminado==0.11.1
testpath==0.5.0
threadpoolctl==2.2.0
torch==1.9.0
torch-tb-profiler==0.2.1
torchaudio==0.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
urllib3==1.26.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.1
widgetsnbextension==3.5.1
Link to the issue in the repository: https://github.com/mvasil/fashion-compatibility/issues/25
The problem is in polyvore_outfits.py:
[...]
candidate_sets = self.category2ims[item_type].keys()
attempts = 0
while item_out == item_id and attempts < 100:
    choice = np.random.choice(candidate_sets)
[...]
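candidate_sets is the object returned by the dict.keys() method. In Python 2 this was a list, but in Python 3 it is a dict_keys view object. The choice function in NumPy's random module accepts a 1-D array-like (such as a list) or an integer, but not a dict_keys object: NumPy cannot turn the view into a 1-D array, which is why the traceback ends with "a must be 1-dimensional or an integer".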
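The failure can be reproduced outside the project with a few lines. This is only a minimal sketch: the dictionary below is a hypothetical stand-in for self.category2ims[item_type], and its keys and values are made up for illustration.

```python
import numpy as np

# Hypothetical stand-in for self.category2ims[item_type]; the real
# contents come from the Polyvore dataset and are not shown in the post.
category2ims = {"item_a": [1], "item_b": [2], "item_c": [3]}

# On Python 3, keys() returns a view object, not a list.
candidate_sets = category2ims.keys()
view_type = type(candidate_sets).__name__
print(view_type)

# np.random.choice cannot interpret the view as a 1-D array or an int.
raised = False
try:
    np.random.choice(candidate_sets)
except (TypeError, ValueError) as exc:
    raised = True
    print(type(exc).__name__, exc)
```

With the NumPy version from the question, the exception printed should match the last line of the traceback.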
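A simple fix is to convert candidate_sets to a list explicitly, either when it is created:
candidate_sets = list(self.category2ims[item_type].keys())
or just before it is passed to np.random.choice:
choice = np.random.choice(list(candidate_sets))
The first form is preferable here, because the call sits inside a while loop and the second form rebuilds the list on every iteration.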
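As a quick check that the conversion is enough, the sketch below applies the fix to a small dictionary. The dictionary is a hypothetical stand-in for self.category2ims[item_type]; only the list(...) call and np.random.choice come from the code above.

```python
import numpy as np

# Hypothetical stand-in for self.category2ims[item_type].
category2ims = {"item_a": [1], "item_b": [2], "item_c": [3]}

# Convert the dict_keys view to a list once, before any sampling loop.
candidate_sets = list(category2ims.keys())

# np.random.choice now receives a 1-D sequence and returns one of the keys.
choice = np.random.choice(candidate_sets)

# The result is a NumPy string scalar; str() gives a plain Python string.
picked = str(choice)
print(picked)
```

Iterating over a dictionary yields its keys, so list(category2ims) would build the same list.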
You should convert the dict_keys object to a list, as the comment above says:
np.random.choice(list(candidate_sets))
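This is probably due to a version change. The traceback points to Python rather than NumPy: on Python 3, dict.keys() returns a dict_keys view instead of a list.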