collie.callbacks.has_monitor_callback 源代码
import functools
from abc import ABC
from typing import Any, Callable, Dict, Optional, Union
from .callback import Callback
from .utils import _get_monitor_value
from collie.log import logger
from collie.utils import apply_to_collection
from collie.utils.utils import _check_valid_parameters_number
__all__ = ['HasMonitorCallback', 'ResultsMonitor']
class CanItemDataType(ABC):
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class ResultsMonitor:
r"""监控某个数值并评估结果是否有所改善的监视器。
可用于监控某个数值,并通过 :meth:`is_better_results` 等接口检测结果是否变得
更好。
:param monitor: 监控的 metric 值。
* 为 ``None`` 时,不设置监控值。
* 为 ``str`` 时,
CoLLiE 将尝试直接使用该名称从 ``evaluation`` 的结果中寻找,如果最终在
``evaluation`` 结果中没有找到完全一致的名称,则将使用最长公共字符串算法
从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor``。
* 为 :class:`Callable` 时,
则接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作
为 ``monitor`` 的结果,如果当前结果中没有相关的 ``monitor`` 值则返回
``None``。
:param larger_better: monitor 是否为越大越好;
"""
def __init__(self,
monitor: Optional[Union[str, Callable]],
larger_better: bool = True):
self.set_monitor(monitor, larger_better)
self._log_name = self.__class__.__name__
def set_monitor(self, monitor, larger_better):
if callable(monitor): # 检查是否能够接受一个参数
_check_valid_parameters_number(
monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
if self.monitor is not None:
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor
def itemize_results(self, results):
r"""执行结果中所有对象的 :meth:`item` 方法(如果没有则忽略),使得 Tensor
类型的数据转为 python 内置类型。
:param results:
:return:
"""
return apply_to_collection(
results, dtype=CanItemDataType, function=lambda x: x.item())
def get_monitor_value(self, results: Dict) -> Union[float, None]:
r"""获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 **最长公共字符
串算法** 匹配的方式寻找。
:param results: 评测结果;
:return: monitor 的值;如果为 ``None``,表明此次没有找到合适的monitor;
"""
if len(results) == 0 or self.monitor is None:
return None
# 保证所有的 tensor 都被转换为了 python 特定的类型
results = self.itemize_results(results)
use_monitor, monitor_value = _get_monitor_value(
monitor=self.monitor, real_monitor=self._real_monitor, res=results)
if monitor_value is None:
return monitor_value
# 第一次运行
if isinstance(self.monitor, str) and \
self._real_monitor == self.monitor and \
use_monitor != self.monitor:
logger.rank_zero_warning(
f'We can not find monitor:`{self.monitor}` for '
f'`{self.log_name}` in the evaluation result (with keys as '
f'{list(results.keys())}), we use the `{use_monitor}` as the '
'monitor.',
once=True)
# 检测到此次和上次不同。
elif isinstance(self.monitor, str) and \
self._real_monitor != self.monitor and \
use_monitor != self._real_monitor:
logger.rank_zero_warning(
f'Change of monitor detected for `{self.log_name}`. '
f'The expected monitor is:`{self.monitor}`, '
f'last used monitor is:`{self._real_monitor}` '
f'and current monitor is:`{use_monitor}`. '
'Please consider using a customized monitor function when the '
'evaluation results are varying between validation.')
self._real_monitor = use_monitor
return monitor_value
def is_better_monitor_value(self,
monitor_value: float,
keep_if_better=True):
"""检测 ``monitor_value`` 是否是更好的。
:param monitor_value: 待检查的 ``monitor_value``。如果为 ``None``,返
回 False;
:param keep_if_better: 如果传入的 ``monitor_value`` 值更好,则将其保存下
来;
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value,
self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better
def is_better_results(self, results, keep_if_better=True):
r"""检测给定的 ``results`` 是否比上一次更好,如果本次 results 中没有找到相
关的 monitor 返回``False``。
:param results: evaluation 结果;
:param keep_if_better: 如果传入的 ``monitor_value`` 值更好,则将其保存下
来;
:return:
"""
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return False
return self.is_better_monitor_value(
monitor_value, keep_if_better=keep_if_better)
def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""传入的两个值中,是否 ``monitor_value1`` 的结果更好。
:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better
@property
def monitor_name(self):
r"""返回 monitor 的名字,如果 monitor 是个 Callable 的函数,则返回该函数的
名称。
:return:
"""
if callable(self.monitor):
try:
monitor = self.monitor
while isinstance(monitor, functools.partial):
monitor = monitor.func
monitor_name = monitor.__qualname__
except Exception:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候
# real_monitor被初始化为monitor了
monitor_name = str(self.monitor)
return monitor_name
@property
def log_name(self) -> str:
"""内部用于打印当前类别信息使用。
:return:
"""
return self._log_name
@log_name.setter
def log_name(self, value):
self._log_name = value
[文档]class HasMonitorCallback(ResultsMonitor, Callback):
r"""对特定数值进行监控的 ``Callback``。
该 callback 不直接使用,作为其它相关 callback 的父类使用,如果 callback
有使用 monitor 可以继承该 ``Callback``。其已经实现了下面的功能:
1. 判断 monitor 合法性;
2. 在需要时,根据 trainer 的 monitor 设置自己的 monitor 名称。
:param monitor: 监控的 metric 值:
* 为 ``None`` 时,不设置监控值。
* 为 ``str`` 时,
CoLLiE 将尝试直接使用该名称从 ``evaluation`` 的结果中寻找,如果最终在
``evaluation`` 结果中没有找到完全一致的名称,则将使用最长公共字符串算法
从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor``。
* 为 :class:`Callable` 时,
则接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作
为 ``monitor`` 的结果,如果当前结果中没有相关的 ``monitor`` 值则返回
``None``。
:param larger_better: monitor 是否为越大越好;
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为
``True``,且没检测到设置 monitor 会报错;
"""
def __init__(self, monitor, larger_better, must_have_monitor=False):
super().__init__(monitor, larger_better)
self.must_have_monitor = must_have_monitor
[文档] def on_after_trainer_initialized(self, trainer):
r"""对于必须要有 monitor 设置的 callback ,该函数会进行检查。
:param trainer:
:return:
"""
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(
f'No `monitor` is set for {self.log_name}. '
f'You can set it in the initialization or through Trainer.')