我是TensorFlow的初学者,我正在尝试实现一个将批处理作为输入的函数。它必须将此批处理切成几个,对它们应用一些操作,然后将它们连接起来以构建要返回的新张量。通过我的阅读,我发现有一些实现的功能,如input_slice_producer和batch_join但我没有使用它们。 我在下面附上了我发现的解决方案,但它有点慢,不正确,无法检测到批次的当前大小。有谁知道更好的方法吗?
def model(x):
W_1 = tf.Variable(tf.random_normal([6,1]),name="W_1")
x_size = x.get_shape().as_list()[0]
# x is a batch of bigger input of shape [None,6], so I couldn't
# get the proper size of the batch when feeding it
if x_size == None:
x_size= batch_size
#intialize the y_res
dummy_x = tf.slice(x,[0,0],[1,6])
result = tf.reduce_sum(tf.mul(dummy_x,W_1))
y_res = tf.zeros([1], tf.float32)
y_res = result
#go throw all slices and concatenate them to get result
for i in range(1,x_size):
dummy_x = tf.slice(x,[i,0],[1,6])
result = tf.reduce_sum(tf.mul(dummy_x,W_1))
y_res = tf.concat(0, [y_res, result])
return y_res
TensorFlow函数tf.map_fn(fn, elems)
允许您将函数(fn
)应用于张量(elems
)的每个切片。例如,您可以按如下方式表达程序:
def model(x):
W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")
def fn(x_slice):
return tf.reduce_sum(x_slice, W_1)
return tf.map_fn(fn, x)
还可以使用tf.mul()
运算符上的广播更简洁地实现您的操作,该运算符使用 NumPy 广播语义,以及axis
参数来tf.reduce_sum()
。