Shortcuts

Trainer

class Trainer(model, config, tokenizer=None, loss_fn=GPTLMLoss(   (loss): CrossEntropyLoss() ), train_fn=None, eval_fn=None, optimizer=None, lr_scheduler=None, train_dataset=None, eval_dataset=None, callbacks=None, train_dataset_collate_fn=<collie.utils.padder.ColliePadder object>, eval_dataset_collate_fn=<collie.utils.padder.ColliePadder object>, server=None, monitors=[], metrics=None, evaluators=None)[源代码]

CoLLie 训练器,支持快速分布式训练和验证。

参数:
  • model (Module) –

    用于训练和验证的模型,可以使用 CoLLie 实现的模型或 transformers 提供的模型:

    • CoLLie 实现的模型 CollieModelForCausalLM 可支持的并行方式包括:张量并行、流水线并行、ZeRO

    • transformers 提供的模型 transformers.PreTrainedModel 只支持 ZeRO

  • config (CollieConfig) – 用于训练和验证的配置

  • tokenizer (Optional[PreTrainedTokenizerBase], default: None) – 用于训练和验证的分词器,该分词器将用于: * 保存模型时 trainer.save_model 时自动同时保存 tokenizer * 使用 EvaluatorForGeneration 进行基于生成的验证时,使用 tokenizer 对生成的结果进行解码 若无上述需求,可不传入 tokenizer

  • loss_fn (Callable, default: GPTLMLoss(   (loss): CrossEntropyLoss() )) – 用于计算 loss 的函数,默认使用 GPTLMLoss()

  • train_fn (Optional[Callable], default: None) – 用于训练的函数,默认使用 train_fn()

  • eval_fn (Optional[Callable], default: None) –

    用于验证的函数

    备注

    CoLLie 未提供默认的验证策略,若未传入 eval_fn,但传入了 eval_dataset,则会抛出异常。若不需要自定义验证循环, 可以考虑使用 CoLLie 定义的多种验证器,例如 EvaluatorForPerplexityEvaluatorForClassficationEvaluatorForGeneration 等。

  • optimizer (Optional[Optimizer], default: None) – 训练过程中的优化器,当为 None 的时候会尝试使用 config.ds_config 定义的优化器

  • lr_scheduler (Union[_LRScheduler, Callable[[Optimizer], _LRScheduler], None], default: None) – 训练过程中的学习率调度器;

  • train_dataset (Optional[Dataset], default: None) – 用于训练的数据集。

  • eval_dataset (Optional[Dataset], default: None) –

    用于验证的数据集。 CoLLie 可接收的 train_dataseteval_dataset 为可迭代对象,例如 torch.utils.data.DatasetList。 可以使用 CollieDatasetForTraining 快速将数据集转换为 CoLLie 可接收的数据集。

    备注

    当未提供 train_dataset_collate_fneval_dataset_collate_fn 时,train_dataseteval_dataset 的取值应当为 Dict 类型

    注意: 上述数据格式为训练所需的格式, 同时 CoLLie 提供了多种验证器, 所要求的格式各有不同, 详见 Evaluator

  • callbacks (Union[Callback, List[Callback], None], default: None) – 训练中触发的 Callback 类,可以是列表。

  • train_dataset_collate_fn (Optional[Callable], default: <collie.utils.padder.ColliePadder object at 0x7f06c7bc09d0>) – 用于训练数据集的 collate_fn

  • eval_dataset_collate_fn (Optional[Callable], default: <collie.utils.padder.ColliePadder object at 0x7f06c7bc0880>) –

    用于验证数据集的 collate_fntrain_dataset_collate_fneval_dataset_collate_fn 只可接受一个参数,为 train_dataseteval_dataset 迭代值组成的 List

    备注

    train_dataset_collate_fneval_dataset_collate_fn 的返回值必须是 Dict 类型

    注意: 上述数据格式为训练所需的格式, 同时 CoLLie 提供了多种验证器, 所要求的格式各有不同, 详见 Evaluator

    例如:

    from transformers import AutoTokenizer
    def collate_fn(batch):
        # batch = ["样本1", "样本2", ...]
        tokenizer = AutoTokenizer.from_pretrained("fnlp/moss-moon-003-sft", padding_side="left", trust_remote_code=True)
        input_ids = tokenizer(batch, return_tensors="pt", padding=True)["input_ids"]
        return {"input_ids": input_ids, "labels": input_ids}
    

  • server (Optional[Server], default: None) – 用于打开一个交互界面,随时进行生成测试,详见 Server

  • monitors (Sequence[BaseMonitor], default: []) – 用于监控训练过程的监控器,详见 BaseMonitor

  • metrics (Optional[Dict], default: None) –

    用于传给 Trainer 内部训练过程中的对 eval_dataset 进行验证。 其应当为一个字典,其中 key 表示 monitor,value 表示一个 metric,例如 {"acc1": Accuracy(), "acc2": Accuracy()}

    目前我们支持的 metric 的种类有以下几种:

    • Collie 自己的 metric:详见 BaseMetric

    • 继承 Collie 基类的自定义 Metric

  • evaluators (Optional[List], default: None) – 验证器。当传入多个 Evaluator 时会依次执行 evaluator 的验证方法。

callback_manager: CallbackManager
init_state_dict()[源代码]

初始化优化器的自身状态字典

state_dict()[源代码]

获取优化器的自身状态字典

load_state_dict(state_dict)[源代码]

加载优化器的自身状态

property global_batch_idx

获取当前全局步数

setup_parallel_model()[源代码]

初始化分布式模型。

train(dataloader=None)[源代码]

训练循环

参数:

dataloader (Optional[Iterable], default: None) – 用于训练的数据集,为 Iterable 对象 ,当为 None 时,使用由 train_dataset 生成的 train_dataloader

eval(dataloader=None)[源代码]

验证循环

参数:

dataloader (Optional[Iterable], default: None) – 用于验证的数据集,为 Iterable 对象 ,当为 None 时,使用 eval_dataset 生成的 eval_dataloader

static train_fn(trainer, batch, global_step)[源代码]

一次训练的基本单元

参数:
  • trainer – 训练器

  • batch (Dict) –

    一个 batch 的数据,类型为 Dict

    备注

    根据提供的 train_datasettrain_dataset_collate_fn 的不同,labels 的类型也会有所不同,详见 Trainer

  • global_step (int) – 当前的全局步数

返回类型:

float

返回:

当前 batch 的 loss

save_peft(path, selected_adapters=None, process_exclusion=False, protocol='file', **kwargs)[源代码]

保存 adapter 部分权重,当未使用 peft 时,该方法等同于 save_model

参数:
  • path (str) – 模型保存路径

  • selected_adapters (Optional[List[str]], default: None) – 保存时保存哪些 adapter;为 None 时则会保存 所有的 adapter

  • process_exclusion (bool, default: False) –

load_peft(path, adapter_name='default', is_trainable=False, process_exclusion=False, protocol='file', **kwargs)[源代码]

加载 adapter 部分权重,当未使用 peft 时,该方法等同于 load_model

参数:
  • path (str) – 模型保存路径

  • adapter_name (default: 'default') – 当前加载的 adapter 名称

  • is_trainable (bool, default: False) – 是否允许加载的 adapter 进行训练

  • process_exclusion (bool, default: False) –

save_model(path, process_exclusion=False, protocol='file', **kwargs)[源代码]

保存模型。

参数:
  • path (str) – 模型保存路径

  • process_exclusion (bool, default: False) – 是否开启进程互斥,当开启流水线并行时开启此项可以节省内存(仅限 CoLLie 内实现的模型,对 transformers 提供的模型本项无效)

save_checkpoint(path, process_exclusion=False, protocol='file', **kwargs)[源代码]

保存训练器断点功能

参数:
  • path (str) – 断点保存路径

  • process_exclusion (bool, default: False) – 是否开启进程互斥,当开启流水线并行时开启此项可以节省内存(仅限 CoLLie 内实现的模型,对 transformers 提供的模型本项无效)

load_checkpoint(path, process_exclusion=False, protocol='file', **kwargs)[源代码]

训练器断点加载

参数:
  • path (str) – 断点保存路径

  • process_exclusion (bool, default: False) – 是否开启进程互斥,当开启流水线并行时开启此项可以节省内存(仅限 CoLLie 内实现的模型,对 transformers 提供的模型本项无效)