我在pytorch下使用yoloV3。我(pred[:, 2:4] > min_wh).all(1)
遇到了这段代码,不知道它的功能。谁能帮忙?谢谢!
我担心的是使用().all(1)
.我知道.all()
或.any()
,但不知道.all(1)
。请解释.all(1)
,谢谢。
根据文档 https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all
第一个参数dim
有all(dim)
。这意味着它与all()
相同,但仅在所选维度上。它基本上用于选择宽度和高度都大于min_wh
的预测(行(。
在您的情况下,pred
具有形状(number_of_predictions, 7)
或
[
[x, y, w, h, object_conf, class_conf, class],
[x, y, w, h, object_conf, class_conf, class],
...
]
pred[:, 2:4] > min_wh
后,结果将是这样的
[
[True, False],
[True, True],
[False, False],
...
]
我们要选择宽度和高度都大于min_wh
的行,因此我们需要使用all(1)
。
因为
all()
会给你True
是否所有元素都True
,False
否则
all(0)
会给你形状(2,)
的张量,例如[True, False]
。如果第一列中的所有元素都True
,则第一个元素将被True
,否则False
。如果第二列中的所有元素都True
,则第二个元素将被True
,否则False
。
all(1)
会给你形状(number_of_predictions,)
的张量, 其中,仅当行中的所有元素都True
时,才True
每个元素。