我有一个TensorArray(a(来存储tf.while_oop中计算的值。但是,我无法将TensorArray转换为Numpy数组。出于某种原因,int32和float32之间似乎不匹配。
import time
import tensorflow as tf
import numpy as np
#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
x_batch = tf.convert_to_tensor(x_train)
s_pred_im = tf.convert_to_tensor(x_batch)
iters = tf.constant(10)
a = tf.TensorArray(tf.float32, size=10)
def cond(value, a, s_pred_im, x_batch, i, iters):
return tf.less(i, iters)
def body(value, a, s_pred_im, x_batch, i, iters):
value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
a = a.write(i,value)
return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]
res = tf.while_loop(cond, body, [0, a, s_pred_im, x_batch, 0, iters])
b = res[1].stack()
with tf.Session() as sess:
b.eval()
这样做会产生以下错误-
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
1364 try:
-> 1365 return fn(*args)
1366 except errors.OpError as e:
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1347 # Ensure any changes to the graph are reflected in the runtime.
-> 1348 self._extend_graph()
1349 return self._call_tf_sessionrun(options, feed_dict, fetch_list,
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _extend_graph(self)
1387 with self._graph._session_run_lock(): # pylint: disable=protected-access
-> 1388 tf_session.ExtendSession(self._session)
1389
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-72-5642d29d3bf6> in <module>
1 with tf.Session() as sess:
----> 2 b.eval()
~anaconda3envstestlibsite-packagestensorflow_corepythonframeworkops.py in eval(self, feed_dict, session)
796
797 """
--> 798 return _eval_using_default_session(self, feed_dict, self.graph, session)
799
800 def experimental_ref(self):
~anaconda3envstestlibsite-packagestensorflow_corepythonframeworkops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
5405 "the tensor's graph is different from the session's "
5406 "graph.")
-> 5407 return session.run(tensors, feed_dict)
5408
5409
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in run(self, fetches, feed_dict, options, run_metadata)
954 try:
955 result = self._run(None, fetches, feed_dict, options_ptr,
--> 956 run_metadata_ptr)
957 if run_metadata:
958 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1178 if final_fetches or final_targets or (handle and feed_dict_tensor):
1179 results = self._do_run(handle, final_targets, final_fetches,
-> 1180 feed_dict_tensor, options, run_metadata)
1181 else:
1182 results = []
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1357 if handle is None:
1358 return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359 run_metadata)
1360 else:
1361 return self._do_call(_prun_fn, handle, feeds, fetches)
~anaconda3envstestlibsite-packagestensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
1382 'nsession_config.graph_options.rewrite_options.'
1383 'disable_meta_optimizer = True')
-> 1384 raise type(e)(node_def, op, message)
1385
1386 def _extend_graph(self):
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.
附言:这是之前一篇文章的编辑,我试图错误地评估tensorarray的值。
这听起来像是tf.while_loop
中的数据类型不匹配。看看下面的工作代码。
import time
import tensorflow as tf
import numpy as np
#tf.compat.v1.disable_eager_execution()
#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
x_batch = tf.convert_to_tensor(x_train)
print(type(x_batch))
s_pred_im = tf.convert_to_tensor(x_batch)
print(type(s_pred_im))
iters = tf.constant(10)
print(type(iters))
a = tf.TensorArray(tf.float32, size=10)
print(type(a))
def cond(value, a, s_pred_im, x_batch, i, iters):
return tf.less(i, iters)
def body(value, a, s_pred_im, x_batch, i, iters):
value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
a = a.write(i,value)
return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]
res = tf.while_loop(cond, body, [0.0, a, s_pred_im, x_batch, 0, iters])
b = res[1].stack()
sess = tf.compat.v1.Session()
with sess.as_default():
print(b.eval())
输出
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]