# Shortcuts

# Source code for collie.metrics.accuracy

from typing import Dict, List, Optional

import numpy as np
import torch

from collie.log import logger
from collie.metrics.base import BaseMetric
from collie.utils.seq_len_to_mask import seq_len_to_mask


class AccuracyMetric(BaseMetric):
    """
    Metric that accumulates prediction accuracy over batches.

    :param gather_result: whether the inputs held by every process should be
        gathered automatically before being passed to :meth:`update` when the
        metric is computed. Forwarded unchanged to ``BaseMetric.__init__``.
    """

    def __init__(self, gather_result: bool = False):
        super().__init__(gather_result=gather_result)
        # Running counters; they are only modified by ``update`` and ``reset``.
        self.correct = 0
        self.total = 0

    def reset(self):
        """
        Reset the accumulated counters to zero.
        """
        self.correct = 0
        self.total = 0

    def get_metric(self) -> Dict:
        r"""
        Compute the final result from the statistics accumulated by
        :meth:`update`.

        :return: the evaluation result as a dict, for example::

                {"acc": float, 'total': float, 'correct': float}
        """
        return {
            # The small epsilon keeps the division defined when nothing has
            # been counted yet (``total == 0`` yields ``acc == 0.0``).
            "acc": round(self.correct / (self.total + 1e-12), 6),
            "total": self.total,
            "correct": self.correct,
        }

    def update(self, result: Dict):
        r"""
        Accumulate the accuracy statistics for one batch of predictions.

        :param result: a dict whose keys include at least ``"pred"`` and
            ``"target"``

            * pred - tensor of predictions, shaped ``torch.Size([B,])``,
              ``torch.Size([B, n_classes])``, ``torch.Size([B, max_len])`` or
              ``torch.Size([B, max_len, n_classes])``. When it has exactly one
              more dimension than ``target`` it is reduced with
              ``argmax(dim=-1)``.
            * target - tensor of gold labels, shaped ``torch.Size([B,])`` or
              ``torch.Size([B, max_len])``
            * seq_len - optional sequence lengths, either ``None`` or shaped
              ``torch.Size([B])``. It is only used when ``target`` has more
              than one dimension; positions where the derived mask is False
              are excluded from both counters.
        :raises AssertionError: if ``"pred"`` or ``"target"`` is missing.
        :raises RuntimeError: if the shapes of ``pred`` and ``target`` are
            incompatible.
        """
        # Raised explicitly rather than with ``assert`` so that the check is
        # still enforced when Python runs with ``-O``; the exception type and
        # message are the same as before.
        if "pred" not in result or "target" not in result:
            raise AssertionError("pred and target must in result, but they not.")
        pred = result["pred"]
        target = result["target"]

        # Under DDP the data has to be gathered manually, in which case the
        # values arrive as lists; otherwise inputs are assumed to be tensors.
        if isinstance(pred, list):
            pred = torch.stack(pred, dim=0)
        if isinstance(target, list):
            target = torch.stack(target, dim=0)

        seq_len = result.get("seq_len")
        masks = None
        if seq_len is not None and target.dim() > 1:
            # presumably a (B, max_len) boolean mask that is True on real
            # (non-pad) positions -- TODO confirm against seq_len_to_mask
            masks = seq_len_to_mask(seq_len=seq_len, max_len=target.size(1))

        if pred.dim() == target.dim():
            # NOTE(review): only the element counts are compared, not the
            # shapes, so same-rank tensors with different but broadcastable
            # shapes pass this check -- confirm whether shapes should match.
            if torch.numel(pred) != torch.numel(target):
                raise RuntimeError(
                    f"when pred have same dimensions with target, they should have same element numbers."
                    f" while target have shape:{target.shape}, "
                    f"pred have shape: {pred.shape}"
                )
        elif pred.dim() == target.dim() + 1:
            # The extra trailing dimension holds per-class scores.
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                logger.warning(
                    "You are not passing `seq_len` to exclude pad when calculate accuracy."
                )
        else:
            raise RuntimeError(
                f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
                f"{pred.shape[:-1]}, got {target.shape}."
            )

        matches = torch.eq(pred, target)
        if masks is not None:
            # Zero out the positions where the mask is False so they count
            # neither as correct nor towards the total.
            self.correct += torch.sum(matches.masked_fill(masks.eq(False), 0)).item()
            self.total += torch.sum(masks).item()
        else:
            self.correct += torch.sum(matches).item()
            self.total += pred.numel()