TensorFlow multiprocessing for image feature extraction



I have a basic function that takes the URL of an image and runs it through the VGG-16 CNN:

def convert_url(_id, url):
    im = get_image(url)
    return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'],
                                    feed_dict={input_tensor: im}))

I have a large set of URLs (~60,000) that I want to run this function over. Each iteration takes more than a second, which is far too slow, so I want to speed things up by using multiple processes in parallel. There is no shared state to worry about, so the usual pitfalls of multithreading aren't an issue.

However, I'm not sure how to actually get TensorFlow working with the multiprocessing package. I know you can't pass a TensorFlow session into the pool, so I tried initializing a separate instance of the session in each worker:

def init():
    global sess
    sess = tf.Session()

But when I actually launch the processes, they just hang indefinitely:

with Pool(processes=3, initializer=init) as pool:
    results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])
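
(For context: one pattern that tends to avoid this hang is to create everything TensorFlow-related inside the worker initializer and to use the 'spawn' start method, so no session or graph state is inherited from the parent via fork. A minimal sketch under that assumption, reusing slim, vgg_arg_scope, vgg_16, checkpoint_file, convert_url, and id_img_dict from above; these must be importable at module level for spawned workers:)

import multiprocessing as mp

import tensorflow as tf

def init():
    # Build a private graph, session, and restored weights in each worker;
    # nothing TensorFlow-related exists before the worker starts.
    global sess, input_tensor, end_points
    input_tensor = tf.placeholder(tf.float32, shape=(None, 224, 224, 3),
                                  name='input_image')
    scaled = (input_tensor / 255.0 - 0.5) * 2.0
    with slim.arg_scope(vgg_arg_scope()):
        _, end_points = vgg_16(scaled, is_training=False)
    sess = tf.Session()
    tf.train.Saver().restore(sess, checkpoint_file)

if __name__ == '__main__':
    mp.set_start_method('spawn')  # don't fork a parent that holds TF state
    with mp.Pool(processes=3, initializer=init) as pool:
        results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])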

Note that the TensorFlow graph is defined globally. I think this is the right approach, but I'm not sure:

input_tensor = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0 / 255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)

arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
    _, end_points = vgg_16(scaled_input_tensor, is_training=False)

saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)

Can anyone help me work through this? Thanks.

Forget Python's ordinary multithreading tools and use tensorflow.contrib.data.Dataset instead. Try something like the following. (Note that tf.read_file reads local files; if your URLs are genuinely remote, you'd need to download the images to disk first.)

urls = ['img1.jpg', 'img2.jpg', ...]
batch_size = 16
n_batches = len(urls) // batch_size  # do something more elegant for remainder

def load_img(url):
    # Read the file from disk and decode it as a 3-channel JPEG.
    image = tf.read_file(url, name='image_data')
    image = tf.image.decode_jpeg(image, channels=3, name='image')
    return image

def preprocess(img_tensor):
    # Scale pixel values to [-1, 1], matching the VGG preprocessing above.
    img_tensor = (tf.cast(img_tensor, tf.float32) / 255 - 0.5) * 2
    img_tensor.set_shape((256, 256, 3))  # whatever shape
    return img_tensor

dataset = tf.contrib.data.Dataset.from_tensor_slices(urls)
dataset = dataset.map(load_img).map(preprocess)
preprocessed_images = dataset.batch(
    batch_size).make_one_shot_iterator().get_next()

arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
    _, end_points = vgg_16(preprocessed_images, is_training=False)
output = end_points['vgg_16/fc7']

results = []
with tf.Session() as sess:
    tf.train.Saver().restore(sess, checkpoint_file)
    for i in range(n_batches):
        batch_results = sess.run(output)
        results.extend(batch_results)
        print('Done batch %d / %d' % (i + 1, n_batches))
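
If the input pipeline itself becomes the bottleneck, the map calls can be parallelized. A minimal sketch, assuming TensorFlow 1.4+, where this API moved from tf.contrib.data to tf.data and gained num_parallel_calls and prefetch:

dataset = tf.data.Dataset.from_tensor_slices(urls)
dataset = dataset.map(load_img, num_parallel_calls=4)    # decode JPEGs in parallel
dataset = dataset.map(preprocess, num_parallel_calls=4)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)  # overlap input reading with the forward pass
preprocessed_images = dataset.make_one_shot_iterator().get_next()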
