BaseMetric¶
- class BaseMetric(gather_result=False)[源代码]¶
Metric 的基类。
- 参数:
gather_result (
bool, default:False) – 在计算 metric 的时候是否自动将各个进程上的输入进行聚合后再输入到 update 之中。
- abstract update(result)[源代码]¶
- 参数:
result (
Dict) –经过 gather 后的输入。一般为如下格式的字典:
{ 'logits': [logit1, logit2, ..., logit_dp_size], 'labels': [label1, label2, ..., label_dp_size] }
其中
dp_size为 并行的卡数量
- gather(result)[源代码]¶
将不同进程上的 result 数据聚合在一起,使用了 DDP 情况。
- 参数:
result (
Dict[str,Tensor]) –:class Trainer 中 eval_fn 返回的结果。类型为 Dict[str, torch.Tensor]。 例如:
result = {'logits': logit, 'labels': label}
- 返回类型:
Dict[str,List]- 返回:
经过 gather 后的结果。类型为 Dict[str, torch.Tensor]。
当
dp_size不为 1 时 (即开启了数据并行的情况下), 会把不同 dp 进程的result按照第一个维度进行拼接。