collie.callbacks.checkpoint_callback 源代码
import sys
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from collie.log.logger import logger
from .callback import Callback
from .topk_saver import TopkSaver
__all__ = ['CheckpointCallback']
[文档]class CheckpointCallback(Callback):
r"""用于保存断点 ``checkpoint`` 的 ``Callback``。
其保存的文件目录以及文件名命名规则如下::
- folder/
- epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型
- epoch_{epoch_idx}-batch_{batch_idx}/ # 满足 every_n_batches 保存的模型
- last/ # 最后一个 epoch 的保存
- epoch_{epoch_idx}-batch_{batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
默认情况下,本 checkpoint 只保存了 model 的状态;如还需保存 Trainer 的状态
以断点重训的话,请使用 ``model_only=False``。
:param folder: 保存的文件夹,如果为 ``None`` ,默认使用当前文件夹。
:param every_n_epochs: 多少个 epoch 保存一次。
:param every_n_batches: 多少个 batch 保存一次。
:param process_exclusion: -- 是否互斥地执行保存操作;在模型规模较大时该参数可以
节省一定的内存。
:param model_only: 是否仅保存模型的权重;如果为 ``True`` 则仅会保存模型权重,
否则还会额外保存 optimizer、训练步数等断点信息以用于断点重训,可以通过
:meth:`.Trainer.load_checkpoint` 加载重新进行训练。该保存路径还可以通过
:meth:`.CollieForCausalLM.from_pretrained` 函数或者 :meth:`.Trainer.\
load_model` 加载到模型中;同时也可以直接加载到对应的 huggingface 模型中。
:param peft_only: 是否只保存 adapter;当未使用 ``peft`` 时该项无效
:param monitor: 监控的 metric 值。
* 为 ``str`` 时,
collie 将尝试直接使用该名称从 ``evaluation`` 的结果中寻找,如果最终在
``evaluation`` 结果中没有找到完全一致的名称,则将使用最长公共字符串算法
从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor``。
* 为 :class:`Callable` 时,
则接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作
为 ``monitor`` 的结果,如果当前结果中没有相关的 ``monitor`` 值则返回
``None``。
:param larger_better: monitor 的值是否时越大越好。
:param topk: 保存 monitor 结果中的 ``topk`` 个。
:param last: 如果为 ``True``,将在每次 epoch 运行结束都保存一次,会覆盖之前的
保存。如果为 ``False`` 则不会保存 ``last`` 文件。
:param max: 最多保留多少个通过 ``every_n_batches`` 和 ``every_n_epochs`` 保存
的权重(如果设置了的话);如果为 ``None`` 或 0,则会保留所有的权重文件。
:param kwargs: 传给 :meth:`.Trainer.save_checkpoint` 或者 :meth:`.Trainer.\
save_model` 、 :meth:`.Trainer.save_peft` 的额外参数。
"""
def __init__(
self,
folder: Optional[Union[str, Path]] = None,
every_n_epochs: Optional[int] = None,
every_n_batches: Optional[int] = None,
process_exclusion: bool = False,
model_only: bool = True,
peft_only: bool = True,
monitor: Optional[Union[str, Callable]] = None,
larger_better: bool = True,
topk: int = 0,
last: bool = False,
max: Optional[int] = None,
**kwargs):
super().__init__()
if every_n_epochs is not None:
if not isinstance(every_n_epochs, int) or every_n_epochs < 0:
raise ValueError(
'Parameter `every_n_epochs` should be an int and greater '
'than or equal to 0.')
if every_n_epochs is None or every_n_epochs == 0:
every_n_epochs = sys.maxsize # 使得没有数字可以整除
if every_n_batches is not None:
if not isinstance(every_n_batches, int) or every_n_batches < 0:
raise ValueError(
'Parameter `every_n_batches` should be an int and greater '
'than or equal to 0.')
if every_n_batches is None or every_n_batches == 0:
every_n_batches = sys.maxsize
if max is not None:
if not isinstance(max, int) and max < 0:
raise ValueError(
'Parameter `max` should be an int and greater than or '
'equal to 0.')
self.topk_saver = TopkSaver(
topk=topk,
monitor=monitor,
larger_better=larger_better,
folder=folder,
process_exclusion=process_exclusion,
model_only=model_only,
peft_only=peft_only,
**kwargs)
self.topk_saver.log_name = self.__class__.__name__
self.topk = topk
self.every_n_epochs = every_n_epochs
self.every_n_batches = every_n_batches
self.last = last
self.max = max if max is not None else 0
self.ckpt_queue = []
def on_after_trainer_initialized(self, trainer):
if self.topk_saver.topk_queue and trainer.evaluators is None:
logger.warning(
f'You set `topk={self.topk}`, but `eval_dataset` is '
'not set in Trainer.')
def on_evaluate_end(self, trainer, results):
self.topk_saver.save_topk(trainer, results)
def on_train_epoch_end(self, trainer):
if (trainer.epoch_idx + 1) % self.every_n_epochs == 0:
folder_name = f'epoch_{trainer.epoch_idx + 1}'
self.topk_saver.save(trainer, folder_name=folder_name)
self.ckpt_queue.append(folder_name)
if self.max > 0 and len(self.ckpt_queue) > self.max:
self.topk_saver.rm(self.ckpt_queue.pop(0))
if self.last:
folder_name = f'last'
self.topk_saver.save(trainer, folder_name=folder_name)
def on_train_batch_end(self, trainer, loss):
if (trainer.batch_idx + 1) % self.every_n_batches == 0:
folder_name = f'epoch_{trainer.epoch_idx}' \
f'-batch_{trainer.batch_idx + 1}'
self.topk_saver.save(trainer, folder_name=folder_name)
self.ckpt_queue.append(folder_name)
if self.max > 0 and len(self.ckpt_queue) > self.max:
self.topk_saver.rm(self.ckpt_queue.pop(0))
def on_save_checkpoint(self, trainer) -> Dict:
states = {}
states['topk_saver'] = self.topk_saver.state_dict()
return states
def on_load_checkpoint(self, trainer, states):
topk_saver_states = states['topk_saver']
self.topk_saver.load_state_dict(topk_saver_states)