Shortcuts

AccuracyMetric

class AccuracyMetric(gather_result=False)[源代码]

计算准确率的 metric

参数:

gather_result (bool, default: False) – 在计算 metric 的时候是否自动将各个进程上的输入进行聚合后再输入到 update 之中。

reset()[源代码]

重置参数

get_metric()[源代码]

get_metric() 函数将根据 update() 函数累计的评价指标统计量来计算最终的评价结果。

返回类型:

Dict

返回:

字典形式的评测结果,例如:

{"acc": float, 'total': float, 'correct': float}

update(result)[源代码]

update() 函数将针对一个批次的预测结果做评价指标的累计。

参数:

result (Dict) –

类型为 Dict 且 keys 至少包含[“pred”, “target”]

  • pred - 预测的 tensor, tensor 的形状可以是 torch.Size([B,])torch.Size([B, n_classes])torch.Size([B, max_len])torch.Size([B, max_len, n_classes])

  • target - 真实值的 tensor, tensor 的形状可以是 torch.Size([B,])torch.Size([B, max_len])torch.Size([B, max_len])

  • seq_len - 序列长度标记, 标记的形状可以是 None, 或者 torch.Size([B]) 。 如果 mask 也被传进来的话 seq_len 会被忽略