在 PyTorch Ignite 的自定义度量中使用 sklearn 的 f1_score



我想在 PyTorch Ignite 的自定义度量中使用 sklearn 的 f1_score。
我找不到一个好的解决方案。尽管在 PyTorch 的官方网站上有如下一个解决方案:

# Precision and recall are created with average=False, i.e. the metrics
# themselves do not average their result.
precision = Precision(average=False)
recall = Recall(average=False)
# F-beta with beta=1.0 (the F1 score), built on top of the two metric objects
# above and likewise left un-averaged.
F1 = Fbeta(beta=1.0, average=False, precision=precision, recall=recall)

但是,如果你需要 micro/macro/weighted(微观/宏观/加权)平均的 F1 分数,就不能使用这个例子。

如何结合 sklearn 库使用自定义度量?

解决方案是首先创建一个自定义度量:

import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from sklearn.metrics import f1_score

class F1Score(Metric):
    """Ignite metric that reports scikit-learn's ``f1_score`` for a whole run.

    Labels and predicted classes are collected batch by batch and the score is
    computed once, over everything seen since the last ``reset``. Averaging
    per-batch scores instead would weight a short final batch as heavily as a
    full one and does not equal the dataset-level score for macro/weighted
    averaging.

    Args:
        *args: Forwarded unchanged to ``Metric.__init__``.
        average: Value passed as ``average`` to ``sklearn.metrics.f1_score``
            (for example ``'micro'``, ``'macro'`` or ``'weighted'``).
            Defaults to ``'micro'``.
        **kwargs: Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(self, *args, average='micro', **kwargs):
        self.average = average
        # Create the accumulators before the base initialiser runs so the
        # instance is in a valid state whatever the base class does with it.
        self._targets = []
        self._predictions = []
        super().__init__(*args, **kwargs)

    def reset(self):
        """Discard everything accumulated so far."""
        self._targets = []
        self._predictions = []
        super(F1Score, self).reset()

    def update(self, output):
        """Store the labels and predicted classes of one batch.

        Args:
            output: Indexable whose first item is the model output and whose
                second item is the target tensor. The predicted class is the
                index of the largest value along dimension 1 of the model
                output.
        """
        y_pred, y = output[0].detach(), output[1].detach()
        predicted = torch.argmax(y_pred, dim=1)
        # Only class indices are kept (moved to CPU), not the raw model
        # outputs, so the accumulated state stays small.
        self._targets.append(y.cpu())
        self._predictions.append(predicted.cpu())

    def compute(self):
        """Return the F1 score over all batches seen since the last reset.

        Raises:
            NotComputableError: If ``update`` has not been called since the
                last reset.
        """
        if not self._targets:
            raise NotComputableError(
                'F1Score must have at least one example before it can be computed.')
        y_true = torch.cat(self._targets)
        y_hat = torch.cat(self._predictions)
        return f1_score(y_true.numpy(), y_hat.numpy(), average=self.average)

然后可以在 create_supervised_evaluator 或 create_supervised_trainer 中将其用作度量:

import logging
import torch
from ignite.engine import Events
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Fbeta
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
from metrics.f1score import F1Score

def inference(
        cfg,
        model,
        val_loader
):
    """Evaluate ``model`` on ``val_loader`` and log the resulting metrics.

    Builds a supervised evaluator with accuracy, precision, recall, the
    ``Fbeta``-based F1 and the custom ``F1Score`` metric, and logs them once
    the evaluator completes an epoch.

    Args:
        cfg: Configuration object; only ``cfg.MODEL.DEVICE`` is read here.
        model: Model passed to ``create_supervised_evaluator``.
        val_loader: Data passed to ``evaluator.run``.
    """
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger("template_model.inference")
    logger.info("Start inferencing")

    precision = Precision(average=False)
    recall = Recall(average=False)
    F1 = Fbeta(beta=1.0, average=False, precision=precision, recall=recall)
    metrics = {'accuracy': Accuracy(),
               'precision': precision,
               'recall': recall,
               'custom': F1Score(),
               'f1': F1}
    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device)

    # adding handlers using `evaluator.on` decorator API
    @evaluator.on(Events.EPOCH_COMPLETED)
    def print_validation_results(engine):
        # `engine` is the evaluator this handler is attached to. A distinct
        # local name avoids shadowing the metric-definition dict above.
        results = engine.state.metrics
        _avg_accuracy = results['accuracy']
        # The un-averaged metrics are reduced to one number with a plain mean.
        _precision = torch.mean(results['precision'])
        _recall = torch.mean(results['recall'])
        _f1 = torch.mean(results['f1'])
        _custom = results['custom']
        logger.info(
            "Test Results - Epoch: {} Avg accuracy: {:.3f}, precision: {:.3f}, recall: {:.3f}, f1 score: {:.3f}, custom: {:.2f}".format(
                engine.state.epoch, _avg_accuracy, _precision, _recall, _f1, _custom))

    evaluator.run(val_loader)

结果是:

Test Results - Epoch: 1 Avg accuracy: 0.758, precision: 0.776, recall: 0.766, f1 score: 0.759, custom: 0.76

或者,您可以使用 Ignite 自带的 Precision 和 Recall 实现,如下所示:

from ignite.metrics import Precision, Recall
# Both metrics are created with average=False, i.e. they do not average their
# own result.
precision = Precision(average=False)
recall = Recall(average=False)
# F1 written directly as arithmetic on the two metric objects
# (2 * P * R / (P + R)), then reduced with .mean().
# NOTE(review): presumably this yields the macro-averaged F1 (mean of
# per-class values) - confirm against the ignite metrics documentation.
F1 = (precision * recall * 2 / (precision + recall)).mean()

然后将F1添加到度量字典中。

相关内容

  • 没有找到相关文章

最新更新