PyTorch model output - keep only boxes with scores above 0.3



I have a list (it actually contains a single element: one dictionary) shown below. It is the output of a pretrained PyTorch model on one test-set instance. The dictionary has three keys (boxes, labels, scores), all of tensor type. Each box has a corresponding score and label, and there are 100 boxes in total. Is there a quick way to keep only the boxes, labels and scores whose score is greater than 0.3? In this example that should leave just 5 boxes, along with their respective scores and labels.

output = [{'boxes': tensor([[0.0000e+00, 2.9095e+01, 7.3249e+01, 1.1387e+02],
[7.8610e+01, 1.9392e+01, 1.6580e+02, 1.0291e+02],
[3.6086e-01, 2.9609e+01, 1.0292e+02, 2.0285e+02],
[1.8569e+02, 2.3418e+01, 2.4397e+02, 1.4092e+02],
[1.9678e-03, 0.0000e+00, 5.8328e+01, 1.7467e+02],
[1.4161e+02, 1.5196e+02, 2.2797e+02, 2.3690e+02],
[1.5630e+02, 5.4246e+01, 2.1178e+02, 1.7170e+02],
[5.3407e+01, 6.4962e+01, 1.0892e+02, 1.8180e+02],
[1.0011e+02, 1.5188e+02, 1.8732e+02, 2.3737e+02],
[1.5080e+02, 3.9219e+01, 2.3776e+02, 1.2494e+02],
[8.9806e+01, 1.3143e+02, 1.7610e+02, 2.1669e+02],
[1.3518e+02, 1.2713e+02, 1.9257e+02, 2.4350e+02],
[1.1423e+02, 1.4989e+01, 1.7153e+02, 1.3093e+02],
[7.9036e+01, 1.1927e+00, 1.9153e+02, 1.7694e+02],
[8.4356e+01, 2.3523e+01, 1.4035e+02, 1.4181e+02],
[6.9645e+01, 1.5251e+02, 1.5582e+02, 2.3697e+02],
[1.4163e+02, 1.2086e+02, 2.2753e+02, 2.0553e+02],
[8.3618e+01, 1.0583e+02, 1.4110e+02, 2.2334e+02],
[3.2450e-01, 7.1444e+01, 7.1565e+01, 1.5488e+02],
[7.2167e+00, 4.9198e+01, 9.3541e+01, 1.3515e+02],
[3.8690e+01, 3.7546e+01, 1.2640e+02, 1.2457e+02],
[1.0393e+02, 8.4865e+01, 1.6160e+02, 2.0193e+02],
[9.6637e+00, 1.2074e+02, 9.1465e+01, 2.0829e+02],
[2.6140e+00, 8.6522e+01, 5.9267e+01, 2.0357e+02],
[1.6260e+02, 6.0580e+01, 2.4744e+02, 1.4646e+02],
[1.7624e+02, 7.5614e+01, 2.3260e+02, 1.9287e+02],
[1.2096e+02, 2.8686e+01, 2.0757e+02, 1.1476e+02],
[1.0993e+02, 1.1107e+02, 1.9594e+02, 1.9697e+02],
[3.8821e+01, 1.6499e+00, 1.5277e+02, 1.7680e+02],
[1.3592e+02, 1.7006e+00, 2.5177e+02, 1.7528e+02],
[4.3270e+01, 8.5363e+01, 9.8090e+01, 2.0313e+02],
[3.9082e+01, 1.6281e+02, 1.2582e+02, 2.4565e+02],
[1.0941e+02, 9.5967e+00, 1.9760e+02, 9.2006e+01],
[9.4279e+01, 5.4012e+01, 1.5129e+02, 1.7273e+02],
[1.7610e+02, 1.1657e+02, 2.3257e+02, 2.3306e+02],
[1.8356e+02, 1.2347e+02, 2.5438e+02, 2.0546e+02],
[1.4145e+00, 7.8904e+01, 7.9495e+01, 2.5600e+02],
[5.7602e+01, 9.8933e+01, 1.7596e+02, 2.5600e+02],
[1.3184e+02, 9.0243e+01, 2.1412e+02, 1.7617e+02],
[1.3507e+02, 4.4525e+01, 1.9094e+02, 1.6303e+02],
[1.0465e+01, 1.5353e+02, 9.2553e+01, 2.3779e+02],
[1.8336e+02, 1.5361e+02, 2.5428e+02, 2.3656e+02],
[1.9591e+02, 7.6679e+01, 2.5259e+02, 1.9309e+02],
[9.8446e+01, 7.9284e+01, 2.1062e+02, 2.5600e+02],
[1.3960e+02, 9.7938e+00, 2.2880e+02, 9.1881e+01],
[9.0553e+01, 0.0000e+00, 1.7578e+02, 7.1238e+01],
[1.8702e+00, 1.2331e+02, 4.8407e+01, 2.4824e+02],
[0.0000e+00, 4.9428e-01, 1.2664e+02, 1.4013e+02],
[7.9054e+01, 1.5236e+00, 2.3079e+02, 1.0113e+02],
[1.6006e+02, 6.4527e+01, 2.5467e+02, 2.5600e+02],
[0.0000e+00, 1.7543e+02, 1.8282e+02, 2.5560e+02],
[1.7264e+00, 1.7961e+02, 7.0644e+01, 2.5600e+02],
[1.8063e+02, 9.7504e+00, 2.5516e+02, 9.3108e+01],
[5.0636e+01, 1.3299e+02, 1.3221e+02, 2.1604e+02],
[3.1850e+01, 5.4289e+01, 8.8847e+01, 1.7312e+02],
[2.0640e+02, 3.2171e+01, 2.5485e+02, 1.5160e+02],
[1.9062e+01, 1.2459e+00, 1.7152e+02, 1.0271e+02],
[8.0108e+01, 1.8195e+02, 1.6510e+02, 2.5523e+02],
[6.4087e+00, 1.3263e+00, 9.6138e+01, 7.0220e+01],
[2.1170e+01, 2.3619e+00, 7.9999e+01, 1.0748e+02],
[5.7921e+01, 6.5922e-01, 1.4736e+02, 8.1241e+01],
[1.1025e+02, 1.8136e+02, 1.9421e+02, 2.5513e+02],
[6.1567e+01, 1.7640e+02, 2.5535e+02, 2.5600e+02],
[3.9355e+01, 3.4047e+00, 1.0319e+02, 8.7065e+01],
[5.0878e+01, 1.0217e+02, 1.3312e+02, 1.8664e+02],
[7.4605e+01, 5.4398e+01, 1.2882e+02, 1.7320e+02],
[1.7292e+02, 1.8397e+02, 2.5249e+02, 2.5558e+02],
[1.8037e+01, 9.5900e+01, 1.3505e+02, 2.5211e+02],
[1.4013e+02, 1.9098e+02, 2.2696e+02, 2.5415e+02],
[6.2275e+01, 8.4387e+01, 9.0523e+01, 1.4154e+02],
[1.7307e+01, 1.9287e+02, 1.1007e+02, 2.5532e+02],
[7.2651e+01, 6.8909e+01, 1.0101e+02, 1.2587e+02],
[1.6461e+02, 7.4065e+01, 1.9301e+02, 1.3071e+02],
[7.7585e+01, 5.3052e+01, 1.0652e+02, 1.0954e+02],
[1.6948e+02, 6.2818e+01, 1.9857e+02, 1.2058e+02],
[5.7015e+01, 1.0044e+02, 8.5367e+01, 1.5666e+02],
[7.7270e+01, 8.4819e+01, 1.0621e+02, 1.4148e+02],
[6.7998e-01, 5.9344e+01, 3.2433e+01, 1.0408e+02],
[5.7399e+01, 5.8315e+01, 8.5744e+01, 1.1517e+02],
[1.5450e+02, 5.2688e+01, 1.8301e+02, 1.1024e+02],
[6.7396e+01, 5.3385e+01, 9.6198e+01, 1.0940e+02],
[5.2431e+01, 1.1456e+02, 8.0534e+01, 1.7151e+02],
[5.2424e+01, 7.2901e+01, 8.0602e+01, 1.3118e+02],
[5.4905e+01, 7.4931e+01, 9.8882e+01, 1.1913e+02],
[1.7986e+02, 6.2481e+01, 2.0913e+02, 1.2018e+02],
[6.7338e+01, 1.0491e+02, 9.5506e+01, 1.6148e+02],
[1.7451e+02, 8.3800e+01, 2.0362e+02, 1.4150e+02],
[4.9071e+01, 4.8894e+01, 9.4031e+01, 9.3116e+01],
[1.0840e+00, 1.1331e+01, 3.7614e+01, 1.2899e+02],
[8.2344e+01, 6.8916e+01, 1.1157e+02, 1.2634e+02],
[1.6138e+02, 5.9034e+01, 2.0742e+02, 1.0366e+02],
[1.4473e+01, 2.2719e+01, 2.1446e+02, 1.3569e+02],
[4.7128e+01, 8.9411e+01, 7.5240e+01, 1.4684e+02],
[1.8501e+02, 1.1383e+02, 2.1348e+02, 1.7167e+02],
[5.4657e+01, 1.2145e+02, 9.8843e+01, 1.6540e+02],
[3.7407e+00, 2.7927e+01, 4.8678e+01, 7.2360e+01],
[1.6647e+02, 8.0292e+01, 2.1177e+02, 1.2410e+02],
[3.4396e+01, 6.4177e+01, 7.8244e+01, 1.0873e+02],
[8.7888e+01, 5.2212e+01, 1.1652e+02, 1.1079e+02],
[1.7443e+02, 1.1404e+02, 2.0353e+02, 1.7198e+02]], device='cuda:0'),
'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1], device='cuda:0'),
'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108, 0.2977, 0.2974, 0.2955, 0.2938,
0.2904, 0.2902, 0.2836, 0.2797, 0.2794, 0.2787, 0.2784, 0.2782, 0.2746,
0.2739, 0.2695, 0.2658, 0.2655, 0.2647, 0.2628, 0.2622, 0.2596, 0.2596,
0.2593, 0.2591, 0.2584, 0.2574, 0.2559, 0.2550, 0.2529, 0.2526, 0.2429,
0.2428, 0.2408, 0.2397, 0.2381, 0.2370, 0.2344, 0.2302, 0.2296, 0.2292,
0.2260, 0.2258, 0.2252, 0.2201, 0.2166, 0.2125, 0.2063, 0.2056, 0.2054,
0.2050, 0.2032, 0.2023, 0.2021, 0.1985, 0.1956, 0.1943, 0.1776, 0.1739,
0.1708, 0.1700, 0.1665, 0.1657, 0.1595, 0.1588, 0.1561, 0.1553, 0.1553,
0.1484, 0.1426, 0.1419, 0.1416, 0.1289, 0.1265, 0.1250, 0.1248, 0.1226,
0.1219, 0.1216, 0.1208, 0.1197, 0.1186, 0.1182, 0.1164, 0.1164, 0.1157,
0.1133, 0.1109, 0.1097, 0.1086, 0.1055, 0.1055, 0.1054, 0.1047, 0.1026,
0.1020], device='cuda:0')}]

So you are looking for something like this?

mask = output[0]['scores'] > 0.3
for key, val in output[0].items():
    output[0][key] = val[mask]
output[0]
{'boxes': tensor([[0.0000e+00, 2.9095e+01, 7.3249e+01, 1.1387e+02],
[7.8610e+01, 1.9392e+01, 1.6580e+02, 1.0291e+02],
[3.6086e-01, 2.9609e+01, 1.0292e+02, 2.0285e+02],
[1.8569e+02, 2.3418e+01, 2.4397e+02, 1.4092e+02],
[1.9678e-03, 0.0000e+00, 5.8328e+01, 1.7467e+02]]),
'labels': tensor([1, 1, 1, 1, 1]),
'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108])}

You can loop over the dictionaries and apply the mask to all three tensors:

result = []
for d in output:
    boxes, labels, scores = d['boxes'], d['labels'], d['scores']
    m = scores > .3
    result.append(dict(boxes=boxes[m], labels=labels[m], scores=scores[m]))

Or, using a dictionary comprehension:

result = []
for d in output:
    m = d['scores'] > .3
    result.append({k: v[m] for k, v in d.items()})

You will get:

>>> result
[{'boxes': tensor([[0.0000e+00, 2.9095e+01, 7.3249e+01, 1.1387e+02],
[7.8610e+01, 1.9392e+01, 1.6580e+02, 1.0291e+02],
[3.6086e-01, 2.9609e+01, 1.0292e+02, 2.0285e+02],
[1.8569e+02, 2.3418e+01, 2.4397e+02, 1.4092e+02],
[1.9678e-03, 0.0000e+00, 5.8328e+01, 1.7467e+02]]),
'labels': tensor([1, 1, 1, 1, 1]),
'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108])}]
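For reference, the same masking pattern can be tried as a minimal, self-contained sketch with small made-up CPU tensors (the real output above lives on `cuda:0`, but boolean-mask indexing works identically on any device):

```python
import torch

# Toy stand-in for the model output: 3 boxes instead of 100 (made-up values).
output = [{
    'boxes':  torch.tensor([[0., 0., 10., 10.],
                            [5., 5., 20., 20.],
                            [1., 1.,  8.,  8.]]),
    'labels': torch.tensor([1, 2, 1]),
    'scores': torch.tensor([0.45, 0.25, 0.31]),
}]

result = []
for d in output:
    m = d['scores'] > 0.3              # boolean mask of shape (N,)
    result.append({k: v[m] for k, v in d.items()})

# Only the two entries with score > 0.3 survive, in all three tensors.
print(result[0]['labels'])  # tensor([1, 1])
```

Because the same mask indexes all three tensors, the correspondence between each box, its label, and its score is preserved.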
