我想利用我在towardsdastascience发现的一个有前途的NN进行案例研究。
我拥有的数据形状是:
X_train:(1200,18,15)
y_train:(1200,18,1)
在这里,NN在其他层中具有GRU、Flatten和Dense。
def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5, optimizer='Adam'
, learning_rate=0.001, activation='relu', loss='mse'):
model = Sequential()
model.add(Bidirectional(GRU(layer1, return_sequences=True),input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(AveragePooling1D(2))
model.add(Conv1D(layer2, 3, activation=activation, padding='same',
name='extractor'))
model.add(Flatten())
model.add(Dense(layer3,activation=activation))
model.add(Dropout(dropout_rate))
model.add(Dense(1))
model.compile(optimizer=optimizer,loss=loss)
return model
twds_model=twds_model()
print(twds_model.summary())
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bidirectional_4 (Bidirection (None, 18, 64) 9216
_________________________________________________________________
average_pooling1d_1 (Average (None, 9, 64) 0
_________________________________________________________________
extractor (Conv1D) (None, 9, 32) 6176
_________________________________________________________________
flatten_1 (Flatten) (None, 288) 0
_________________________________________________________________
dense_3 (Dense) (None, 16) 4624
_________________________________________________________________
dropout_4 (Dropout) (None, 16) 0
_________________________________________________________________
dense_4 (Dense) (None, 1) 17
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None
不幸的是,我陷入了一种矛盾的错误陷阱,输入和输出形状不匹配。这里的错误在上面的情况下。
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
[[{{node loss_2/dense_4_loss/sub}}]]
[[{{node loss_2/mul}}]]
Train on 10420 samples, validate on 1697 samples
Epoch 1/8
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
878 initial_epoch=initial_epoch,
879 steps_per_epoch=steps_per_epoch,
--> 880 validation_steps=validation_steps)
881
882 def evaluate(self,
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
327
328 # Get outputs.
--> 329 batch_outs = f(ins_batch)
330 if not isinstance(batch_outs, list):
331 batch_outs = [batch_outs]
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasbackend.py in __call__(self, inputs)
3074
3075 fetched = self._callable_fn(*array_vals,
-> 3076 run_metadata=self.run_metadata)
3077 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3078 return nest.pack_sequence_as(self._outputs_structure,
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonclientsession.py in __call__(self, *args, **kwargs)
1437 ret = tf_session.TF_SessionRunCallable(
1438 self._session._session, self._handle, args, status,
-> 1439 run_metadata_ptr)
1440 if run_metadata:
1441 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonframeworkerrors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
526 None, None,
527 compat.as_text(c_api.TF_Message(self.status.status)),
--> 528 c_api.TF_GetCode(self.status.status))
529 # Delete the underlying status object from memory otherwise it stays alive
530 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
[[{{node loss_2/dense_4_loss/sub}}]]
[[{{node loss_2/mul}}]]
为了完成预期误差,y_train被重塑为(1200*18,1(:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-47-2a6d0761b794> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train_flat, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
774 steps=steps_per_epoch,
775 validation_split=validation_split,
--> 776 shuffle=shuffle)
777
778 # Prepare validation data.
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle)
2434 # Check that all arrays have the same length.
2435 if not self._distribution_strategy:
-> 2436 training_utils.check_array_lengths(x, y, sample_weights)
2437 if self._is_graph_network and not self.run_eagerly:
2438 # Additional checks to avoid users mistakenly using improper loss fns.
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining_utils.py in check_array_lengths(inputs, targets, weights)
454 'the same number of samples as target arrays. '
455 'Found ' + str(list(set_x)[0]) + ' input samples '
--> 456 'and ' + str(list(set_y)[0]) + ' target samples.')
457 if len(set_w) > 1:
458 raise ValueError('All sample_weight arrays should have '
ValueError: Input arrays should have the same number of samples as target arrays. Found 12117 input samples and 218106 target samples
使用的版本有:
Package Version
---------------------- --------------------
- nsorflow-gpu
-ensorflow-gpu 1.13.1
-rotobuf 3.11.3
-umpy 1.18.1
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.2.5
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.19.0
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.0.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator 1.13.0
tensorflow-gpu 1.13.1
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
tf-estimator-nightly 1.14.0.dev2019060501
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
提前感谢每一个指向运行代码的提示;-(!
editedit
在将tensorflow和keras更新到最新版本后,我收到了以下错误。尽管tensorlfow、CUDA 10.1和cudnn 8.0.2被完全删除并重新安装,但错误仍然存在。该错误是由我的原始代码和Fallen Aparts示例代码产生的。
UnknownError: Fail to find the dnn implementation.
[[{{node CudnnRNN}}]]
[[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]
Function call stack:
train_function -> train_function -> train_function
None
Epoch 1/4
---------------------------------------------------------------------------
UnknownError Traceback (most recent call last)
<ipython-input-1-64eb8afffe02> in <module>
27 print(twds_model.summary())
28
---> 29 twds_model.fit(X_train, y_train, epochs=4)
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
~Anaconda3envsTensorflowlibsite-packagestensorflowpythonkerasenginetraining.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerdef_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerdef_function.py in _call(self, *args, **kwds)
838 # Lifting succeeded, so variables are initialized and we can run the
839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
841 else:
842 canon_args, canon_kwds =
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerfunction.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2830
2831 @property
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerfunction.py in _filtered_call(self, args, kwargs, cancellation_manager)
1846 resource_variable_ops.BaseResourceVariable))],
1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
1849
1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerfunction.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1922 # No tape is watching; skip to running the function.
1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
1926 args,
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerfunction.py in call(self, ctx, args, cancellation_manager)
548 inputs=args,
549 attrs=attrs,
--> 550 ctx=ctx)
551 else:
552 outputs = execute.execute_with_cancellation(
~Anaconda3envsTensorflowlibsite-packagestensorflowpythoneagerexecute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
UnknownError: Fail to find the dnn implementation.
[[{{node CudnnRNN}}]]
[[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]
Function call stack:
train_function -> train_function -> train_function
相应的版本列表:
Package Version
------------------------ ---------------
- nsorflow-gpu
-ensorflow-gpu 2.3.0
-rotobuf 3.11.3
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.4.3
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.18.5
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.2.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 2.3.0
tensorboard-plugin-wit 1.7.0
tensorflow-gpu 2.3.0
tensorflow-gpu-estimator 2.3.0
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
我无法重现您的错误,请检查以下代码是否适用:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, GRU, Bidirectional, AveragePooling1D, Dense, Flatten, Dropout
import numpy as np
def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5, optimizer='Adam',
learning_rate=0.001, activation='relu', loss='mse'):
model = Sequential()
model.add(Bidirectional(GRU(layer1, return_sequences=True), input_shape=(X_train.shape[1], X_train.shape[2])))
model.add(AveragePooling1D(2))
model.add(Conv1D(layer2, 3, activation=activation, padding='same',
name='extractor'))
model.add(Flatten())
model.add(Dense(layer3, activation=activation))
model.add(Dropout(dropout_rate))
model.add(Dense(1))
model.compile(optimizer=optimizer, loss=loss)
return model
if __name__ == '__main__':
X_train = np.random.rand(1200, 18, 15)
y_train = np.random.rand(1200, 18, 1)
twds_model = twds_model()
print(twds_model.summary())
twds_model.fit(X_train, y_train, epochs=20)
好吧,这是对我有效的方法:
Tensorflow 2.3.0
Keras 2.4.2
CUDA 10.1
cuDNN 7.6.5
与此代码片段一起从github问题中检索
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # Set to -1 if CPU should be used CPU = -1 , GPU = 0
gpus = tf.config.experimental.list_physical_devices('GPU')
cpus = tf.config.experimental.list_physical_devices('CPU')
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized
print(e)
elif cpus:
try:
# Currently, memory growth needs to be the same across GPUs
logical_cpus= tf.config.experimental.list_logical_devices('CPU')
print(len(cpus), "Physical CPU,", len(logical_cpus), "Logical CPU")
except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized
print(e)
非常感谢一直陪伴在我身边的@Fallen Apart。如果你好奇,你可能也想在这里简要了解一下我的后续问题;-(。