"(pred[:, 2:4] > min_wh).all(1)' 在 YOLO(深度学习)中做什么?



我在pytorch下使用yoloV3。我(pred[:, 2:4] > min_wh).all(1)遇到了这段代码,不知道它的功能。谁能帮忙?谢谢!

我担心的是使用().all(1).我知道.all().any(),但不知道.all(1)。请解释.all(1),谢谢。

根据文档 https://pytorch.org/docs/stable/tensors.html#torch.BoolTensor.all

第一个参数dimall(dim)。这意味着它与all()相同,但仅在所选维度上。它基本上用于选择宽度和高度大于min_wh的预测((。

在您的情况下,pred具有形状(number_of_predictions, 7)

[
[x, y, w, h, object_conf, class_conf, class],
[x, y, w, h, object_conf, class_conf, class],
...
]

pred[:, 2:4] > min_wh后,结果将是这样的

[
[True, False],
[True, True],
[False, False],
...
]

我们要选择宽度和高度都大于min_wh的行,因此我们需要使用all(1)

因为

all()会给你True是否所有元素都TrueFalse否则

all(0)会给你形状(2,)的张量,例如[True, False]。如果第一中的所有元素都True,则第一个元素将被True,否则False。如果第二中的所有元素都True,则第二个元素将被True,否则False

all(1)会给你形状(number_of_predictions,)的张量, 其中,仅当行中的所有元素都True时,才True每个元素。

相关内容

  • 没有找到相关文章

最新更新