如何从 mxnet nd 数组访问类标签的索引?输出必须是使用函数输入的标签名称的索引


#pred_probas is probabilities of each class type
def find_class_idx(label):
"""
Should return the class index of a particular label.
:param label: label of class
:type label: str
:return: class index
:rtype: int
"""
#ind = mx.nd.argmax(label, axis=1).astype('int')
topk_indices=mx.nd.topk(pred_probas,k=100)
return max(topk_indices)*100

因为 GluonCV 的网络输出类是列表类型的。我们可以使用此函数 list.index(label( 访问列表的索引

def find_class_idx(label):
"""
Should return the class index of a particular label.
:param label: label of class
:type label: str
:return: class index
:rtype: int
"""
return network.classes.index(label)

最新更新