我们不知道输入张量的形状,我们不应该使用任何循环。只有缩减和索引操作。我们如何将每一行的最小值设置为零?
例如:
输入:
x = torch.tensor([[
[10, 20, 30]
[2, 5, 1]
]])
输出:
torch.tensor([
[0, 20, 30],
[2, 5, 0]
])
我想不通,也找不到任何相关的问题。我被卡住了。
这样做的一种方法是计算行最小值x.min(dim=-1)
,获得最小值x.min(dim=-1).values
(在多个最小元素的情况下索引不起作用(,使用比较获得指示非最小元素位置的掩码并乘以它:
axis = -1 # Minimum iterating over the last dimension
min_values = x.min(dim=axis).values # Get minimal values
min_values_shape_corrected = min_values.unsqueeze(axis) # Reshape the minimal values so we can compare it with `x`
mask = (x != min_values_shape_corrected) # Get the mask of non-minimal elements
result = x * mask # Multiplying by a boolean mask leaves only True elements and sets False ones to 0
或者在一行中
x * (x != x.min(dim=axis).values.unsqueeze(axis))