是否有更简单的方法来获取张量的切片,如以下示例所示



我想对张量进行切片,例如numpy中的以下切片。我该怎么做?

# numpy array
a = np.reshape(np.arange(60), (3,2,2,5))
idx = np.array([0, 1, 0])
N = np.shape(a)[0]
mask = a[np.arange(N),:,:,idx]

# I have tried several solutions, but only the following success.
# tensors
import tensorflow as tf
import numpy as np

a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])
fn = lambda i: a[i][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(tf.shape(masks)))
    print(sess.run(masks))

是否有更简单的方法可以实现这一目标?

我可以使用功能tf.gathertf.gather_nd实现这一目标吗?非常感谢!

1. AROTH方法

我不确定这是最好的方法,但是它更快。您可以使用tf.boolean_mask代替tf.map_fn

import tensorflow as tf
import numpy as np
a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])
fn = lambda i: a[i,:,:][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
# new method
idx = tf.one_hot(idx2,depth=a.shape[-1])
masks2 = tf.boolean_mask(tf.transpose(a,[0,3,1,2]), idx)
with tf.Session() as sess:
    print('tf.map_fn version:n',sess.run(masks))
    print('tf.boolean_mask version:n',sess.run(masks2))
# print
tf.map_fn version:
 [[[ 0  5]
  [10 15]]
 [[21 26]
  [31 36]]
 [[40 45]
  [50 55]]]
tf.boolean_mask version:
 [[[ 0  5]
  [10 15]]
 [[21 26]
  [31 36]]
 [[40 45]
  [50 55]]]

2.绩效比较

矢量化方法1000迭代需要0.07s,而tf.map_fn方法1000迭代需要我的8GB GPU内存中的0.85s。矢量化方法的速度将比tf.map_fn()更快。

import datetime
...
with tf.Session() as sess:
    start = datetime.datetime.now()
    for _ in range(1000):
        sess.run(masks)
    end = datetime.datetime.now()
    print('tf.map_fn version cost time(seconds) : %.2f' % ((end - start).total_seconds()))
    start = datetime.datetime.now()
    for _ in range(1000):
        sess.run(masks2)
    end = datetime.datetime.now()
    print('tf.boolean_mask version cost time(seconds) : %.2f' % ((end - start).total_seconds()))
# print
tf.map_fn version cost time(seconds) : 0.85
tf.boolean_mask version cost time(seconds) : 0.07

我相信,随着a的形状增加,性能差异将变得更加明显。

另一种方法使用tf.gather_nd

import tensorflow as tf
import numpy as np

a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx = tf.range(tf.shape(a)[0])
idx2 = tf.constant([0,1,0])
indices = tf.stack([idx, idx2], axis=1)
a = tf.transpose(a, [0,3,1,2])
masks = tf.gather_nd(a, indices)
with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(tf.shape(masks)))
    print(sess.run(masks))

最新更新