InvalidArgumentError:节点 while_1/Merge_1 的输入 1 从 while_1/NextIteration_1:0 接收到 float,与预期的 int32 不兼容



我有一个 TensorArray(a),用来存储在 tf.while_loop 中计算出的值。但是,我无法将这个 TensorArray 转换为 NumPy 数组。出于某种原因,int32 和 float32 之间似乎存在类型不匹配。

import time
import tensorflow as tf
import numpy as np

# Load the MNIST dataset via Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)

# Wrap the training images as a tensor; the same data is used for both SSIM inputs.
x_batch = tf.convert_to_tensor(x_train)
s_pred_im = tf.convert_to_tensor(x_batch)
# Loop bound, and a float32 TensorArray with 10 slots to collect per-iteration values.
iters = tf.constant(10)
a = tf.TensorArray(tf.float32, size=10)
# Loop predicate: keep iterating while the counter i is less than iters.
def cond(value, a, s_pred_im, x_batch, i, iters):
return tf.less(i, iters)
# Loop body: compute the summed SSIM, write it into slot i of the TensorArray,
# and return the loop variables with the counter incremented by one.
def body(value, a, s_pred_im, x_batch, i, iters):
# The incoming `value` is ignored and rebound to the new SSIM sum.
value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
a = a.write(i,value)
return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]
# NOTE(review): the first loop variable (`value`) starts as the Python int 0,
# while body rebinds it to the SSIM sum; the traceback reports a float/int32
# mismatch inside this loop -- presumably on this variable.
res = tf.while_loop(cond, body, [0, a, s_pred_im, x_batch, 0, iters])
# Stack the TensorArray returned by the loop (second loop variable) into one tensor.
b = res[1].stack()
# Evaluate the stacked tensor inside a session.
with tf.Session() as sess:
b.eval()

运行上述代码会产生以下错误:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_call(self, fn, *args)
1364     try:
-> 1365       return fn(*args)
1366     except errors.OpError as e:
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1347       # Ensure any changes to the graph are reflected in the runtime.
-> 1348       self._extend_graph()
1349       return self._call_tf_sessionrun(options, feed_dict, fetch_list,
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _extend_graph(self)
1387     with self._graph._session_run_lock():  # pylint: disable=protected-access
-> 1388       tf_session.ExtendSession(self._session)
1389 
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.
During handling of the above exception, another exception occurred:
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-72-5642d29d3bf6> in <module>
1 with tf.Session() as sess:
----> 2     b.eval()
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\framework\ops.py in eval(self, feed_dict, session)
796 
797     """
--> 798     return _eval_using_default_session(self, feed_dict, self.graph, session)
799 
800   def experimental_ref(self):
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\framework\ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
5405                        "the tensor's graph is different from the session's "
5406                        "graph.")
-> 5407   return session.run(tensors, feed_dict)
5408 
5409 
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
954     try:
955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
957       if run_metadata:
958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1178     if final_fetches or final_targets or (handle and feed_dict_tensor):
1179       results = self._do_run(handle, final_targets, final_fetches,
-> 1180                              feed_dict_tensor, options, run_metadata)
1181     else:
1182       results = []
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1357     if handle is None:
1358       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359                            run_metadata)
1360     else:
1361       return self._do_call(_prun_fn, handle, feeds, fetches)
~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_call(self, fn, *args)
1382                     '\nsession_config.graph_options.rewrite_options.'
1383                     'disable_meta_optimizer = True')
-> 1384       raise type(e)(node_def, op, message)
1385 
1386   def _extend_graph(self):
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.

附言:这是对我之前一篇帖子的修改;在那篇帖子里,我对 TensorArray 求值的方式是错误的。

这看起来是 tf.while_loop 中循环变量的数据类型不匹配:循环变量 value 的初始值是整数 0(int32),而 body 返回的 value 是 float32。把初始值改为 0.0 即可。请参考下面可以正常运行的代码。

import time
import tensorflow as tf
import numpy as np
#tf.compat.v1.disable_eager_execution()

# Load the MNIST dataset via Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

# Wrap the training images as a tensor; the same data feeds both SSIM inputs.
x_batch = tf.convert_to_tensor(x_train)
print(type(x_batch))
s_pred_im = tf.convert_to_tensor(x_batch)
print(type(s_pred_im))
# Loop bound, and a float32 TensorArray with one slot per iteration.
iters = tf.constant(10)
print(type(iters))
a = tf.TensorArray(tf.float32, size=10)
print(type(a))


def cond(value, a, s_pred_im, x_batch, i, iters):
    """Return the loop predicate: keep iterating while ``i < iters``."""
    return tf.less(i, iters)


def body(value, a, s_pred_im, x_batch, i, iters):
    """Run one loop iteration.

    Computes the sum of ``tf.image.ssim(s_pred_im, x_batch)``, writes it
    into slot ``i`` of the TensorArray ``a`` and returns the updated loop
    variables in the same order as the parameters, with ``i`` incremented
    by one. The incoming ``value`` is ignored; it only exists so that the
    loop variable keeps a consistent position and dtype across iterations.
    """
    value = tf.math.reduce_sum(
        tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size=28)
    )
    a = a.write(i, value)
    return [value, a, s_pred_im, x_batch, tf.add(i, 1), iters]


# The initial `value` is the float 0.0 (not the int 0) so that its dtype
# matches the float result that body() returns on every iteration.
res = tf.while_loop(cond, body, [0.0, a, s_pred_im, x_batch, 0, iters])
# Stack the TensorArray returned by the loop into a single tensor.
b = res[1].stack()

# Using the session itself as the context manager installs it as the default
# session (so b.eval() works) and closes it on exit, releasing its resources.
# NOTE(review): Tensor.eval() requires graph mode -- presumably TF 1.x, or
# TF 2.x with tf.compat.v1.disable_eager_execution() enabled; confirm.
with tf.compat.v1.Session():
    print(b.eval())

输出

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

相关内容

  • 没有找到相关文章

最新更新