Shortcuts

collie.models.moss_moon.model 源代码

import json
import os
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
from deepspeed.pipe import LayerSpec
from torch import distributed as dist
from torch import nn
from transformers.activations import NewGELUActivation
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from collie.config import CollieConfig
from collie.driver.io import IODriver
from collie.log import logger
from collie.models.base import CollieModelForCausalLM
from collie.module import (
    ColumnParallelLinearWithoutBias,
    ColumnParallelLMHead,
    RowParallelLinearWithoutBias,
    VocabParallelEmbedding,
)
from collie.utils import env, progress
from collie.utils.utils import concat_tensor, dict_as_params, stack_tensor

from .utils import (
    _state_dict_to_load,
    _state_dict_to_save,
    _weight_name_in_current_rank,
    apply_rotary_pos_emb,
    create_sinusoidal_positions,
    set_index_dict,
)

__all__ = ["Moss003MoonForCausalLM"]


class MossAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    A single fused ``qkv_proj`` produces query/value/key. Rotary embeddings are
    applied to the first ``rotary_dim`` features of each query/key head, or to
    the whole head when ``rotary_dim`` is ``None``. Attention scores are
    computed in fp32 under a causal mask, and ``out_proj`` maps the merged
    heads back to ``n_embd``.

    :param config: model config; reads ``n_positions``, ``attn_pdrop``,
        ``resid_pdrop``, ``n_embd``, ``n_head`` and ``rotary_dim``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        max_positions = config.n_positions
        # Lower-triangular (causal) boolean mask of shape
        # (1, 1, n_positions, n_positions). ``persistent=False`` keeps it out
        # of the state dict.
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.embed_dim = config.n_embd
        self.num_attention_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        # sqrt(head_dim) in the default dtype; ``_attn`` divides the scores by it.
        self.scale_attn = torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        # Fused projection producing 3 * embed_dim features. With
        # gather_output=True the full result is presumably available on every
        # tensor-parallel rank (see ColumnParallelLinearWithoutBias).
        self.qkv_proj = ColumnParallelLinearWithoutBias(
            self.embed_dim,
            self.embed_dim * 3,
            bias=False,
            gather_output=True,
        )

        # input_is_parallel=False: it receives the full merged attention output.
        self.out_proj = RowParallelLinearWithoutBias(
            self.embed_dim, self.embed_dim, bias=False, input_is_parallel=False
        )
        self.rotary_dim = config.rotary_dim
        pos_embd_dim = self.rotary_dim or self.embed_dim
        # Plain attribute, not a registered buffer, so ``module.to(device)``
        # does not move it. ``forward`` moves it lazily to ``position_ids``'
        # device.
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

    def _split_heads(self, x, n_head, dim_head, mp_num):
        """Split the last axis of ``x`` into attention heads.

        ``x`` is expected to be ``(..., mp_num, n_head * dim_head // mp_num)``,
        as produced in ``forward``. The result is ``(..., n_head, dim_head)``,
        with the ``mp_num`` groups laid out consecutively along the head axis.
        """
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Move the head axis next to the head-feature axis and flatten both.

        For a rank-4 tensor laid out as in ``forward`` after its permute,
        i.e. ``(batch, heads, seq, head_dim)``, the result is
        ``(batch, seq, num_attention_heads * attn_head_size)``. Rank-5 inputs
        get the analogous swap of axes 2 and 3. Any other rank raises
        ``ValueError``.
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(
                f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}"
            )
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        """Scaled dot-product attention with causal masking.

        :param query: queries; the sequence axis is dim -2.
        :param key: keys; the sequence axis is dim -2 and may be longer than
            the query's when a cache is used.
        :param value: values matching ``key``.
        :param attention_mask: optional mask *added* to the scores after
            causal masking.
        :param head_mask: optional mask *multiplied* into the probabilities.
        :return: ``(attn_output, attn_weights)``; the weights are cast to
            ``value``'s dtype.
        """
        # compute causal mask from causal mask buffer
        # Rows for the current queries (the last query_length positions),
        # columns for all key_length keys, so cached prefixes are visible.
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.causal_mask[
            :, :, key_length - query_length : key_length, :key_length
        ]

        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        attn_weights = attn_weights / self.scale_attn
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        """Self-attention over ``hidden_states``.

        :param hidden_states: layer input; assumed ``(batch, seq, n_embd)``.
            TODO confirm.
        :param layer_past: cached ``(past_key, past_value)`` for this layer.
            These are concatenated in front of the new key/value along dim 1.
        :param attention_mask: additive mask added to the attention scores.
        :param position_ids: indices into ``embed_positions``. Must not be
            ``None``, because its ``.device`` is read unconditionally.
        :param head_mask: optional multiplicative mask on attention probabilities.
        :param use_cache: when true, also return ``stack_tensor([key, value])``.
        :return: ``(attn_output, present)``; ``present`` is ``None`` unless
            ``use_cache`` is true.
        """
        qkv = self.qkv_proj(hidden_states)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 4
        # View the fused output as mp_num groups. Each group holds three
        # chunks of local_dim features, in the order query, value, key.
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_dim * self.num_attention_heads // mp_num
        # Note the unpacking order: query, value, key.
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
        query = self._split_heads(
            query, self.num_attention_heads, self.head_dim, mp_num=mp_num
        )
        key = self._split_heads(
            key, self.num_attention_heads, self.head_dim, mp_num=mp_num
        )

        value = self._split_heads(
            value, self.num_attention_heads, self.head_dim, mp_num=mp_num
        )

        # Move the sinusoidal table to position_ids' device once and keep the
        # moved copy for later calls.
        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions

        # First half of the last axis holds sin, second half cos.
        sincos = embed_positions[position_ids]
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # Rotate only the first rotary_dim features of each head; the rest
            # pass through unchanged.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        if layer_past is not None:
            # Prepend the cached keys/values along dim 1. Assuming the
            # (batch, seq, heads, head_dim) layout at this point, that is the
            # sequence axis.
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1).to(value.dtype)

        if use_cache is True:
            # Presumably stacks key and value along a new leading axis, so
            # that present[0] is the key and present[1] the value.
            present = stack_tensor([key, value]).to(key.device)
        else:
            present = None

        # Swap axes 1 and 2 (seq <-> heads) for the matmuls in ``_attn``.
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )

        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_dim
        )
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)

        return outputs  # a, present


class MossMLP(nn.Module):
    """Feed-forward sub-layer of a MOSS block.

    ``fc_in`` is column-parallel and keeps its output partitioned
    (``gather_output=False``). ``fc_out`` is row-parallel and consumes those
    partitioned activations directly (``input_is_parallel=True``).

    :param intermediate_size: hidden width of the MLP (``MossBlock`` passes
        ``config.n_inner`` or ``4 * config.n_embd``).
    :param config: model config; ``n_embd`` and ``resid_pdrop`` are read.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        model_width = config.n_embd

        self.fc_in = ColumnParallelLinearWithoutBias(
            model_width, intermediate_size, gather_output=False
        )
        self.fc_out = RowParallelLinearWithoutBias(
            intermediate_size, model_width, input_is_parallel=True
        )
        self.act = NewGELUActivation()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states) -> torch.FloatTensor:
        """Apply ``fc_in -> GELU -> fc_out -> dropout`` to ``hidden_states``."""
        activated = self.act(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(activated))


class MossBlock(nn.Module):
    """A single MOSS transformer layer.

    Attention and MLP both read the same ``ln_1``-normalised input, and their
    outputs are summed with the residual (see ``_forward``). The layer
    communicates through a dict (see ``forward``), so it can be chained inside
    :class:`Moss003MoonModel` or used as a pipeline stage.

    :param config: model config; ``n_inner``, ``n_embd``,
        ``layer_norm_epsilon`` and ``gradient_checkpointing`` are read.
    :param layer_idx: position of this layer in the stack; selects this
        layer's entry in ``past_key_values``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MossAttention(config)
        self.mlp = MossMLP(inner_dim, config)
        self.config = config
        self.idx = layer_idx

        # Whether to emit key/value caches. Presumably flipped from outside
        # (e.g. via the model's ``set_cache`` / ``clean_cache`` helpers).
        self.use_cache = False
        # Input of the most recent eval-mode ``forward`` call.
        self.hidden_states = None

    def _forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        """Tensor-level layer computation.

        :return: ``(hidden_states, present)`` when ``use_cache`` is true,
            otherwise ``(hidden_states,)``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]

        # Parallel residual: the MLP consumes the same normalised input as attention.
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_output + feed_forward_hidden_states + residual

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions)

    def forward(self, inputs):
        """Run the layer on a dict of tensors and return the updated dict.

        ``inputs`` must contain ``hidden_states``. It may also contain
        ``attention_mask``, ``past_key_values``, ``new_past_key_values`` and
        ``position_ids``. The returned dict carries the new ``hidden_states``,
        the ``position_ids`` actually used, and whichever of the other entries
        are set, so it can be fed straight into the next layer.
        """
        hidden_states = inputs["hidden_states"]
        attention_mask = inputs.get("attention_mask", None)
        past_key_values = inputs.get("past_key_values", None)
        new_past_key_values = inputs.get("new_past_key_values", None)
        layer_past = past_key_values[self.idx] if past_key_values is not None else None
        if not self.training:
            self.hidden_states = hidden_states

        end_pos = hidden_states.shape[1]
        # past_key_values[0][0] is presumably layer 0's cached key, whose
        # dim 1 is the sequence axis (see the cache concat in MossAttention).
        past_length = 0 if past_key_values is None else past_key_values[0][0].size(1)
        position_ids = inputs.get("position_ids", None)
        if position_ids is None:
            # Create positions on the activations' device rather than the
            # current CUDA device, so CPU and multi-device runs work too.
            position_ids = torch.arange(
                past_length,
                end_pos + past_length,
                dtype=torch.long,
                device=hidden_states.device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, end_pos)

        if self.config.gradient_checkpointing and self.training:
            self.use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self._forward),
                hidden_states,
                None,
                attention_mask,
                position_ids,
            )
        else:
            outputs = self._forward(
                hidden_states,
                position_ids=position_ids,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=None,
                use_cache=self.use_cache,
            )

        if self.use_cache:
            present = outputs[1].unsqueeze(0)
            # 1, 2, bsz, seqlen, ...
            if new_past_key_values is None:
                # First layer: start the cache.
                new_past_key_values = present
            else:
                # Subsequent layers: append this layer's cache.
                new_past_key_values = concat_tensor([new_past_key_values, present]).to(
                    present.device
                )

        output = {"hidden_states": outputs[0]}
        if attention_mask is not None:
            output["attention_mask"] = attention_mask
        # Publish positions under the same key this method reads, so later
        # layers and pipeline stages reuse them instead of recomputing an arange.
        output["position_ids"] = position_ids
        if past_key_values is not None:
            output["past_key_values"] = past_key_values
        if new_past_key_values is not None:
            output["new_past_key_values"] = new_past_key_values

        return output


class Moss003MoonModel(nn.Module):
    """Decoder stack of MOSS-moon: embedding, ``n_layer`` blocks, final LayerNorm.

    :param config: model config; reads ``n_embd``, ``vocab_size``,
        ``embd_pdrop``, ``n_layer`` and ``layer_norm_epsilon``.
    """

    def __init__(self, config):
        super(Moss003MoonModel, self).__init__()
        self.embed_dim = config.n_embd
        self.vocab_size = config.vocab_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([MossBlock(config, i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        inputs_embeds=None,
        past_key_values=None,
        **kwargs,
    ):
        """Run the decoder stack.

        :param input_ids: token ids. Used for the embedding lookup unless
            ``inputs_embeds`` is given.
        :param attention_mask: 1/0 padding mask. It is converted here to an
            additive mask: 0 for kept positions, dtype minimum for masked ones.
        :param inputs_embeds: optional precomputed embeddings.
        :param past_key_values: cached keys/values from a previous call.
        :param kwargs: ``position_ids``, if present and not ``None``, is passed
            to every block (e.g. the left-padding-aware positions built by
            ``prepare_inputs``). Other keys are ignored.
        :return: :class:`BaseModelOutputWithPast` with the final hidden state,
            the input of every block plus the final output as
            ``hidden_states``, and the new cache (``None`` when caching is off).
        """
        # Fall back to the embeddings' batch size when only inputs_embeds is given.
        batch_size = (
            input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        )
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            dtype = self.wte.weight.dtype
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = self.drop(inputs_embeds)

        all_hidden_states = ()
        input_dict = {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
        position_ids = kwargs.get("position_ids", None)
        if position_ids is not None:
            # Without this, each block falls back to a plain arange and
            # ignores caller-provided positions (wrong for left-padded batches).
            input_dict["position_ids"] = position_ids
        for layer in self.h:
            all_hidden_states += (input_dict["hidden_states"],)
            input_dict.update(layer(input_dict))

        hidden_states = self.ln_f(input_dict["hidden_states"])
        all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            past_key_values=input_dict.get("new_past_key_values", None),
        )

    @classmethod
    def pipeline_layers(cls, config):
        """Build the ``(name, layer)`` list used to assemble the pipeline model.

        :param config: :class:`.CollieConfig` or a path/name to load it from.
        """
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)

        def pre_forward(input_dict):
            # Turn a 1/0 padding mask into an additive mask whose dtype matches
            # the DeepSpeed precision setting.
            input_ids = input_dict["input_ids"]
            attention_mask = input_dict.get("attention_mask", None)
            batch_size = input_ids.shape[0]
            if attention_mask is not None:
                if batch_size <= 0:
                    raise ValueError("batch_size has to be defined and > 0")
                attention_mask = attention_mask.view(batch_size, -1)
                attention_mask = attention_mask[:, None, None, :]
                dtype = torch.float32
                if "fp16" in config.ds_config:
                    if config.ds_config["fp16"].get("enabled", False):
                        dtype = torch.float16
                if "bf16" in config.ds_config:
                    # Read the same key that was just checked. The old
                    # ``"bf_16"`` lookup raised KeyError whenever bf16 was set.
                    if config.ds_config["bf16"].get("enabled", False):
                        dtype = torch.bfloat16
                attention_mask = attention_mask.to(dtype=dtype)
                attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
                input_dict["attention_mask"] = attention_mask
            return input_dict

        wte = dict_as_params("input_ids", "hidden_states")(
            VocabParallelEmbedding, config.vocab_size, config.n_embd
        )
        drop = dict_as_params("hidden_states", "hidden_states")(
            nn.Dropout, config.embd_pdrop
        )
        h = [LayerSpec(MossBlock, config, i) for i in range(config.n_layer)]
        ln_f = dict_as_params("hidden_states", "hidden_states")(
            nn.LayerNorm, config.n_embd, eps=config.layer_norm_epsilon
        )

        return [
            ("pre_forward", pre_forward),
            ("wte", wte),
            ("drop", drop),
            ("h", h),
            ("ln_f", ln_f),
        ]


class Moss003MoonForCausalLM(CollieModelForCausalLM):
    """
    Moss-moon model with 3D-parallel support.

    :param config: :class:`.CollieConfig`
    """

    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Moss003MoonModel(config)
        self.lm_head = ColumnParallelLinearWithoutBias(config.n_embd, config.vocab_size)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        inputs_embeds=None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ):
        """Run the decoder and project the final hidden state to vocabulary logits.

        Extra ``kwargs`` are forwarded to :class:`Moss003MoonModel`. ``loss``
        and ``attentions`` of the returned output are always ``None``.
        """
        output = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            **kwargs,
        )
        logits = self.lm_head(output.last_hidden_state)
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=None,
        )

    def set_cache(self, use_cache: bool = False):
        """Enable or disable key/value caching in every block."""
        self._set_use_cache(self.transformer.h, use_cache)

    def prepare_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[list] = None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        With a cache, only the last token is fed. When ``position_ids`` is not
        supplied but ``attention_mask`` is, positions are derived from the
        mask so that left-padded batches get correct positions.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
        }

    def clean_cache(self):
        """Drop stored hidden states and disable caching in every block."""
        self._clean_hidden_states([*self.transformer.h, self.lm_head])
        self._set_use_cache(self.transformer.h, False)

    @classmethod
    def pipeline_layers(cls, config):
        """Return the ``(name, layer)`` list for the pipeline model.

        :param config: :class:`.CollieConfig` or a path/name to load it from.
        """
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)
        transformer = Moss003MoonModel.pipeline_layers(config)
        lm_head = dict_as_params("hidden_states", "logits")(
            ColumnParallelLMHead,
            config.n_embd,
            config.vocab_size,
        )
        return [("transformer", transformer), ("lm_head", lm_head)]

    @staticmethod
    def load_parallel_state_dict(
        path: str,
        config: Union[CollieConfig, str],
        process_exclusion: bool = False,
        protocol: str = "file",
        **kwargs,
    ):
        """
        Load model weights from ``path``, which must hold a huggingface-format
        checkpoint.

        :param path: checkpoint directory.
        :param config: :class:`.CollieConfig` or a path/name to load it from.
        :param process_exclusion: whether each rank loads its weights
            independently and mutually exclusively. With large models this can
            help save memory.
        :param protocol: IO protocol passed to :class:`IODriver`.
        :return: a dict with the weights needed by the model on the current rank.
        """
        # Actually Moss only supports `hf` format
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)
        io_driver = IODriver.from_protocol(protocol)
        if not io_driver.exists(path) and protocol == "file":
            raise FileNotFoundError(f"folder {path} not found.")
        # With process_exclusion every process shows its own progress bar;
        # otherwise only rank 0 does.
        hide_progress = not process_exclusion and env.rank != 0
        # Defined up front so the return below never depends on the loop body.
        state_dict = OrderedDict()
        for cur_rank in range(dist.get_world_size()):
            # Every rank runs every iteration so the barriers stay matched;
            # each rank only loads during its own turn.
            if process_exclusion:
                dist.barrier()
            if cur_rank != env.rank:
                continue
            # With a pytorch_model.bin.index.json, each pipeline stage can load
            # only the shards it needs.
            index_file = os.path.join(path, "pytorch_model.bin.index.json")
            if io_driver.exists(index_file) and env.is_pipeline:
                weight_map = json.loads(io_driver.load(index_file, mode="r"))[
                    "weight_map"
                ]
                # Names of the weights this rank's layers need.
                cur_names = _weight_name_in_current_rank(weight_map.keys())
                weights = {weight_map[name] for name in cur_names}
            else:
                # No index file (or no pipeline): load every .bin file.
                weights = [
                    weight for weight in io_driver.list(path) if weight.endswith(".bin")
                ]
            desc = "Loading state dict"
            if process_exclusion:
                desc += f" on pp={env.pp_rank} tp={env.tp_rank} dp={env.dp_rank}"
            for weight in progress(weights, desc, disable=hide_progress):
                part_state_dict = io_driver.load(os.path.join(path, weight), mode="rb")
                state_dict.update(
                    _state_dict_to_load(
                        part_state_dict, env.tp_rank, config.tp_size, process_exclusion
                    )
                )
        return state_dict

    @staticmethod
    def save_parallel_state_dict(
        state_dict: dict,
        path: str,
        config: CollieConfig,
        process_exclusion: bool = False,
        protocol: str = "file",
        **kwargs,
    ):
        """
        Save model weights to ``path`` in the huggingface format.

        On dp rank 0, all tensor-parallel shards are merged onto tp_rank 0.
        Each pipeline stage is then saved as one shard of a sharded checkpoint,
        together with a ``pytorch_model.bin.index.json`` when pipelining.

        :param state_dict: model weights.
        :param path: output directory.
        :param config: :class:`.CollieConfig`.
        :param process_exclusion: whether each rank saves independently and
            mutually exclusively. With large models this can help save memory.
        :param protocol: IO protocol passed to :class:`IODriver`.
        """
        io_driver = IODriver.from_protocol(protocol)
        if env.rank == 0:
            config.save_pretrained(path, protocol=protocol)
        desc = "Saving state dict"
        # Without process_exclusion the progress bar is not shown at all.
        hide_progress = not process_exclusion or env.rank != 0
        index_dict = None
        for cur_pp_rank in progress(range(env.pp_size), desc, disable=hide_progress):
            if process_exclusion:
                dist.barrier()
            if env.dp_rank != 0:
                continue  # only dp_rank 0 saves
            if cur_pp_rank != env.pp_rank:
                continue  # each pipeline stage saves in its own turn
            # Gather tensor-parallel shards to tp rank 0.
            state_dict = _state_dict_to_save(
                state_dict, env.tp_rank, config.tp_size, env.tp_group, process_exclusion
            )
            if env.tp_rank != 0:
                continue  # save at tp_rank 0
            if env.is_pipeline:
                ckpt_name = (
                    f"pytorch_model-{env.pp_rank+1:05d}-of-{config.pp_size:05d}.bin"
                )
                index_dict = set_index_dict(state_dict, ckpt_name)
            else:
                ckpt_name = "pytorch_model.bin"
            io_driver.save(state_dict, os.path.join(path, ckpt_name))
        # gather_object is a collective over pp_group, so all stages must reach
        # it at the same point. Calling it inside the loop (each stage in a
        # different iteration) deadlocks against the per-iteration barrier
        # when process_exclusion is enabled. The participants (dp0/tp0 ranks)
        # are the same as before.
        if env.is_pipeline and env.dp_rank == 0 and env.tp_rank == 0:
            index_dicts = [None for _ in range(env.pp_size)]
            dist.gather_object(
                index_dict, index_dicts if env.pp_rank == 0 else None, group=env.pp_group
            )
            if env.pp_rank == 0:
                total_size = 0
                weight_map = {}
                for _index_dict in index_dicts:
                    total_size += _index_dict["total_size"]
                    weight_map.update(_index_dict["weight_map"])
                merged_dict = {
                    "metadata": {"total_size": total_size},
                    "weight_map": weight_map,
                }
                io_driver.save(
                    json.dumps(merged_dict, indent=2, sort_keys=True) + "\n",
                    os.path.join(path, "pytorch_model.bin.index.json"),
                )
        dist.barrier()