如何在tensorflow中更新3个通道的像素值



我试图实现的目标是更新张量的所有3个RGB通道中的像素。例如,

t1 = tf.zeros([4, 4, 3], tf.int32)

t1是具有像素的4x4作为(x,y(和3个通道的张量。我用tf.where来寻找另一个形状相同的张量(如(中最高值的索引

t2 = tf.random.uniform([4, 4, 3], maxval=10, dtype=tf.int32)
tf.where(tf.math.reduce_max(t2), 1, t1)

但这只会改变t2中具有最高值的单个索引的值。我想要的是在所有3个频道中更新(x,y(。

例如,如果t2看起来像频道1->[[1 2],[3,4]]频道2->[[5 6],[7,8]]频道2->[[1 0],[0,0]]最大值为8,它是通道内的(1,1(索引。

我希望t1看起来像
Channel 1->[[0 0],[0,1]]频道2->[[0 0],[0,1]]频道2->[[0 0],[0,1]]

我怎样才能做到这一点?

我想这就行了。你需要减少到二维张量,然后将其平铺在通道维度上。

# Sample tensor for demo purposes with integers 0-7, the max will be in t2[1,1,1]
# Channel is dimension 2
t2 = tf.reshape([i for i in range(8)],(2,2,2))
print(f"{t2=}")
mask3 = tf.cast(t2 == tf.reduce_max(t2), tf.int32)     # 3-D integer mask, identifies the cell with max value
mask2 = tf.reduce_max(mask3, axis=2, keepdims=True)    # Reduce across channel dimension, find pixel with max value
t1 = tf.tile(mask2, (1, 1, t2.shape[2]))              # Tile to get 1 in target pixel across all channels
print(f"n{t1=}")

首先,我们将(部分(索引更新为

idx = tf.where(t2 == tf.reduce_max(t2))[..., :-1]

或者,如果可以有多个具有最大值的位置,并且您只想使用第一个,那么

idx = tf.where(t2 == tf.reduce_max(t2))[:1, ..., :-1]

有了部分索引,我们可以使用tf.tensor_scatter_nd_update更新t1

t1 = tf.tensor_scatter_nd_update(t1, idx, tf.ones((idx.shape[0],t1.shape[-1]), tf.int32))

这适用于具有两个或更多维度的任何张量;通道";维

最新更新