PyTorch如何将张量每行的最小值设置为零



我们不知道输入张量的形状,我们不应该使用任何循环。只有缩减和索引操作。我们如何将每一行的最小值设置为零?

例如:

输入:

x = torch.tensor([[
      [10, 20, 30]
      [2, 5, 1]
    ]])

输出:

torch.tensor([
  [0, 20, 30],
  [2, 5, 0]
])

我想不通,也找不到任何相关的问题。我被卡住了。

这样做的一种方法是计算行最小值x.min(dim=-1),获得最小值x.min(dim=-1).values(在多个最小元素的情况下索引不起作用(,使用比较获得指示非最小元素位置的掩码并乘以它:

axis = -1  # Minimum iterating over the last dimension
min_values = x.min(dim=axis).values  # Get minimal values
min_values_shape_corrected = min_values.unsqueeze(axis)  # Reshape the minimal values so we can compare it with `x`
mask = (x != min_values_shape_corrected)  # Get the mask of non-minimal elements
result = x * mask  # Multiplying by a boolean mask leaves only True elements and sets False ones to 0

或者在一行中

x * (x != x.min(dim=axis).values.unsqueeze(axis))

最新更新