Consider this tensor:
a = tf.constant([0,1,2,3,5,6,7,8,9,10,19,20,21,22,23,24])
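I want to split it into 3 tensors (for this particular example), each containing a group of consecutive numbers. The expected output would be: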
output_tensor = [ [0,1,2,3], [5,6,7,8,9,10], [19,20,21,22,23,24] ]
Do you know how to do this? Is there a TensorFlow math function that can do it efficiently? I couldn't find anything.
For the example given, tf.split should work:
import tensorflow as tf
a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
print(tf.split(a, [4, 6, 6]))
Output:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5, 6, 7, 8, 9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]
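The second argument sets the size of each output tensor along the split axis (axis 0 by default), so here the first tensor has size 4, the second has size 6, and the third has size 6. Alternatively, you can pass a single int, as long as the tensor's size along the split axis is evenly divisible by it. In this case 3 won't work (16 / 3 = 5.333...), but 4 will: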
a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
print(tf.split(a, 4))
Output:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([5, 6, 7, 8], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 9, 10, 19, 20], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([21, 22, 23, 24], dtype=int32)>]
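Both rules described above can be checked directly. The sketch below assumes TensorFlow 2.x in eager mode; try_split is a hypothetical helper introduced only for illustration, which wraps tf.split and returns None when TensorFlow rejects the arguments. The last call relies on the underlying SplitV op accepting a single -1 in the size list and inferring it from whatever is left over.

```python
import tensorflow as tf

a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])


def try_split(x, num_or_size_splits):
    """Hypothetical helper: return tf.split's pieces, or None if the arguments are rejected."""
    try:
        return tf.split(x, num_or_size_splits)
    except (tf.errors.InvalidArgumentError, ValueError):
        return None


# An int must evenly divide the size of the split axis (16 here).
print(try_split(a, 3) is None)  # 3 does not divide 16, so the split is rejected
print([p.shape[0] for p in try_split(a, 4)])  # four pieces of size 4

# A size list must add up to the size of the split axis.
print(try_split(a, [4, 6, 5]) is None)  # 4 + 6 + 5 = 15, not 16, so it is rejected

# One entry may be -1; its size is inferred from the remainder (16 - 4 - 6 = 6).
print([p.shape[0] for p in try_split(a, [4, 6, -1])])
```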
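If the boundaries between runs of consecutive numbers are not known in advance, they can be found efficiently from the adjacent differences, converted into split sizes, and passed to tf.split: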
import tensorflow as tf

def compute_split_indices(x):
    adjacent_diffs = x[1:] - x[:-1]  # compute adjacent differences
    # positions where a new run starts (preceded by a gap larger than 1); shape [k, 1], dtype int64
    indices_where_not_continuous = tf.where(adjacent_diffs > 1) + 1
    # compute split sizes from the indices: the first start position, then the distances between starts
    splits = tf.concat([indices_where_not_continuous[:1],
                        indices_where_not_continuous[1:] - indices_where_not_continuous[:-1]], axis=0)
    splits_as_ints = [split.numpy().tolist()[0] for split in splits]  # convert to a list of ints (eager mode only)
    final_split_sizes = splits_as_ints + [len(x) - sum(splits_as_ints)]  # account for the rest of the tensor
    return final_split_sizes

if __name__ == "__main__":
    a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
    splits = compute_split_indices(a)
    print(tf.split(a, splits))
Output:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5, 6, 7, 8, 9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]
Note that the output is the same as when [4, 6, 6] is passed explicitly.
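If the grouping has to stay entirely inside TensorFlow ops (for example inside a tf.function, where .numpy() is not available), the same adjacent-difference idea can produce the row boundaries directly and hand them to tf.RaggedTensor.from_row_splits, which yields the nested output the question asked for. This is a sketch, assuming TensorFlow 2.x and a 1-D input sorted in ascending order; split_into_runs is a hypothetical helper name.

```python
import tensorflow as tf


def split_into_runs(x):
    """Hypothetical helper: group a sorted 1-D integer tensor into runs of consecutive values.

    Returns a tf.RaggedTensor with one row per run, built only from TensorFlow ops.
    """
    x = tf.convert_to_tensor(x)
    n = tf.shape(x, out_type=tf.int64)[0]
    # Same criterion as compute_split_indices: a gap larger than 1 starts a new run.
    # tf.where returns int64 indices of shape [k, 1]; [:, 0] flattens them to shape [k].
    breaks = tf.where(x[1:] - x[:-1] > 1)[:, 0] + 1
    # Row boundaries: 0, every run start, and the total length (all int64).
    row_splits = tf.concat(
        [tf.zeros([1], dtype=tf.int64), breaks, tf.reshape(n, [1])], axis=0)
    return tf.RaggedTensor.from_row_splits(x, row_splits)


a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
runs = split_into_runs(a)
print(runs.to_list())
# [[0, 1, 2, 3], [5, 6, 7, 8, 9, 10], [19, 20, 21, 22, 23, 24]]

# The row lengths are the split sizes, so a list of dense tensors is still available.
pieces = tf.split(a, runs.row_lengths())
```

For this input, runs.row_lengths() holds the same sizes [4, 6, 6] that compute_split_indices returns, so you can keep the ragged result or pass the lengths to tf.split if a Python list of dense tensors is more convenient.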