TensorFlow freezes on session run when reading data from TFRecords



Here is the code:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
import numpy as np
from scipy.misc import imread
import glob

with open("./labels_510.txt") as f:
    lines = list(f.readlines())
labels = [str(w).replace("\n", "") for w in lines]
NCLASS = len(labels)
NCHANNEL = 3
WIDTH = 224
HEIGHT = 224
def getImageBatch(filenames, batch_size, capacity, min_after_dequeue):
    filenameQ = tf.train.string_input_producer(filenames, num_epochs=None)
    recordReader = tf.TFRecordReader()
    key, fullExample = recordReader.read(filenameQ)
    key_val = sess.run(key)
    print(key_val)
    features = tf.parse_single_example(
        fullExample,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })
    label = features['image/class/label']
    image_buffer = features['image/encoded']
    with tf.name_scope('decode_jpeg', [image_buffer], None):
        image = tf.image.decode_jpeg(image_buffer, channels=NCHANNEL)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.reshape(1 - tf.image.rgb_to_grayscale(image), [WIDTH * HEIGHT * NCHANNEL])
    label = tf.stack(tf.one_hot(label - 1, NCLASS))
    imageBatch, labelBatch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    print(imageBatch.shape)
    print(labelBatch.shape)
    return imageBatch, labelBatch

with gfile.FastGFile("./output_graph_510.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    sess.graph.as_default()
    tf.import_graph_def(graph_def)
    tf.global_variables_initializer().run()
    image_tensor, label_batch = getImageBatch(glob.glob("./images/tf_records/validation*"), 1, 10, 2)
    image_tensor = tf.reshape(image_tensor, (1, WIDTH, HEIGHT, NCHANNEL))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image_data = sess.run(image_tensor)
    # print(image_data.shape)
    # softmax_tensor = sess.graph.get_tensor_by_name('import/final_result:0')
    # predictions = sess.run(softmax_tensor, {'import/input:0': image_data})
    # predictions = np.squeeze(predictions)
    # print(predictions)
    coord.request_stop()
    coord.join(threads)

When I run it, it freezes after printing the following messages:

2017-08-17 12:33:10.235086: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235099: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235101: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235104: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235106: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.322321: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:893] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2017-08-17 12:33:10.322510: I tensorflow/core/common_runtime/gpu/gpu_device.cc:940] Found device 0 with properties: 
name: GeForce GTX 1050
major: 6 minor: 1 memoryClockRate (GHz) 1.493
pciBusID 0000:01:00.0
Total memory: 3.95GiB
Free memory: 2.23GiB
2017-08-17 12:33:10.322519: I tensorflow/core/common_runtime/gpu/gpu_device.cc:961] DMA: 0 
2017-08-17 12:33:10.322522: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 0:   Y 
2017-08-17 12:33:10.322529: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1050, pci bus id: 0000:01:00.0)
  • TensorFlow version: 1.2.1
  • Ubuntu 16.04
  • GeForce GTX 1050

The complete project can be found here: https://github.com/kindlychung/demo-load-pb-tensorflow

It freezes because you have not initialized the local variables associated with the queues used in tf.train.shuffle_batch. Local variables are typically temporary variables created for operations such as enqueue and dequeue, which keep track of the elements.
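
If you want to see what is involved, you can inspect the graph's local variables and check which of them are still uninitialized before touching the queues. This is a small diagnostic added here, not part of the original answer; it only assumes the TF 1.x functions tf.local_variables and tf.report_uninitialized_variables, run inside the same session block:

# Diagnostic only: print the local variables the input pipeline created and
# the names of those that have not been initialized yet.
print(tf.local_variables())
print(sess.run(tf.report_uninitialized_variables(tf.local_variables())))

The fix itself is to run the local variable initializer together with the global one, before starting the queue runners: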

...
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
image_data = sess.run(image_tensor)
print(image_data.shape)
...
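
For completeness, here is a minimal sketch of how the corrected session block could look. It is not taken from the original answer; it reuses the getImageBatch helper, constants, and file paths from the question, and it assumes the sess.run(key) / print(key_val) debug lines inside getImageBatch are removed, since they dequeue from the filename queue and would block for exactly the same reason. The key points are: build the input pipeline first, run both initializers, start the queue runners, and only then call sess.run on the pipeline.

with tf.Session() as sess:
    # Build the input pipeline first so that everything it needs exists
    # before the initializers run.
    image_tensor, label_batch = getImageBatch(glob.glob("./images/tf_records/validation*"), 1, 10, 2)
    image_tensor = tf.reshape(image_tensor, (1, WIDTH, HEIGHT, NCHANNEL))

    # Initialize global AND local variables; the local ones hold the queue bookkeeping.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    # Start the queue runners before any sess.run that dequeues from the pipeline.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    image_data = sess.run(image_tensor)
    print(image_data.shape)

    coord.request_stop()
    coord.join(threads)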

Latest update