# Source code for collie.optim.lomo
import torch
from torch.optim import Optimizer
import torch.distributed as dist
try:
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ..utils.dist_utils import env
from collie.log import logger
class Lomo(Optimizer):
    """
    Lomo optimizer: fuses the parameter update into the backward pass.

    A hook is registered on every trainable parameter of ``model``. While
    ``backward`` runs, the hook consumes each gradient that is already populated,
    applies an in-place SGD step ``p <- p * (1 - lr * weight_decay) - lr * grad``
    (weight decay only when ``weight_decay > 0``) and frees the gradient
    immediately. :meth:`fuse_update` builds the hook for the non-ZeRO case and
    :meth:`fuse_update_zero3` the one used when DeepSpeed ZeRO-3 is enabled
    (detected automatically).

    When ``clip_grad_norm`` is set, call :meth:`grad_norm` (a first backward pass
    that only measures the gradient norm) before :meth:`fused_backward` (the
    second pass that performs the update).

    :param model: model to optimize; only parameters with ``requires_grad`` get a hook
    :param lr: learning rate, default ``1e-3``; overwritten by the ``lr`` argument
        of :meth:`fused_backward`
    :param clip_grad_norm: maximum global L2 norm of the gradients. When the first
        backward sees an fp16 loss and this is ``None``, it is set to ``1.0``.

        .. note::
            clip_grad_norm must be positive.
    :param clip_grad_value: when given and positive, gradients are clamped to
        ``[-clip_grad_value, clip_grad_value]``
    :param weight_decay: multiplicative decay factor applied to the parameters
        before the gradient step when positive, default ``0.0``
    :param loss_scale_args: keyword arguments for :class:`DynamicLossScaler`, used
        when the first backward sees an fp16 loss; ``None`` means no arguments
    """

    def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None, weight_decay=0.0,
                 loss_scale_args=None):
        self.model = model
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value
        self.loss_scaler = None
        # ``None`` default avoids a mutable ``{}`` shared between all instances.
        self.loss_scale_args = {} if loss_scale_args is None else loss_scale_args
        self.weight_decay = weight_decay
        self.do_weight_decay = bool(self.weight_decay > 0.0)

        # for grad norm
        if self.clip_grad_norm is not None and self.clip_grad_norm <= 0:
            raise ValueError(f"clip_grad_norm should be positive, got {self.clip_grad_norm}.")
        # While True, the hooks only record per-parameter gradient norms (see grad_norm).
        self.gather_norm = False
        self.grad_norms = []
        self.clip_coef = None

        # Pick the update closure according to whether DeepSpeed ZeRO-3 is active.
        self.zero3_enabled = is_deepspeed_zero3_enabled()
        if self.zero3_enabled:
            self.grad_func = self.fuse_update_zero3()
        else:
            self.grad_func = self.fuse_update()

        # The loss dtype (fp16 or not) is inspected on the first backward call.
        self.first_backward = True
        # Register the hook; it is invoked for each trainable parameter during backward.
        for p in self.model.parameters():
            if p.requires_grad:
                p.register_hook(self.grad_func)
        defaults = dict(lr=lr, clip_grad_norm=clip_grad_norm, clip_grad_value=clip_grad_value)
        super().__init__(self.model.parameters(), defaults)

    @property
    def dp_rank(self):
        """Data-parallel rank of this process, read from ``env``."""
        return env.dp_rank

    @property
    def world_size(self):
        """World size, read from ``env``."""
        return env.world_size

    def _drop_grad_on_overflow(self, p):
        """
        Drop ``p.grad`` and flag the step if loss scaling detects an overflow.

        :return: ``True`` when an overflow was already flagged in this pass or
            ``p.grad`` contains inf/NaN (the caller should stop processing),
            ``False`` otherwise or when loss scaling is inactive.
        """
        scaler = self.loss_scaler
        if scaler and (scaler.has_overflow_serial or scaler._has_inf_or_nan(p.grad)):
            p.grad = None
            scaler.has_overflow_serial = True
            return True
        return False

    def _pop_fp32_grad(self, p):
        """Return ``p.grad`` as fp32, divided by the loss scale if any, and free ``p.grad``."""
        grad_fp32 = p.grad.to(torch.float32)
        p.grad = None
        if self.loss_scaler:
            grad_fp32.div_(self.loss_scaler.loss_scale)
        return grad_fp32

    def _clip_grad_(self, grad):
        """Apply value clipping, then norm clipping with :attr:`clip_coef`, to ``grad`` in place."""
        if self.clip_grad_value is not None and self.clip_grad_value > 0:
            # Clipping gradients by their value
            grad.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
        if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None:
            # Rescale with the coefficient computed by the grad_norm pass
            grad.mul_(self.clip_coef)

    def _apply_update_(self, param_fp32, grad_fp32):
        """In-place step ``param <- param * (1 - lr * weight_decay) - lr * grad``."""
        if self.do_weight_decay:
            param_fp32.mul_(1.0 - self.lr * self.weight_decay)
        param_fp32.add_(grad_fp32, alpha=-self.lr)

    def fuse_update(self):
        """
        Build the gradient hook used when ZeRO-3 is not enabled.

        :return: ``func(x)``, a closure for ``Tensor.register_hook``. Each call walks
            all trainable parameters whose ``grad`` is populated, frees that
            gradient, and either records its L2 norm (while :attr:`gather_norm` is
            set) or updates the parameter in place. It returns ``x`` unchanged.
        """
        def func(x):
            with torch.no_grad():
                for p in self.model.parameters():
                    if not p.requires_grad or p.grad is None:
                        continue
                    if self._drop_grad_on_overflow(p):
                        break
                    grad_fp32 = self._pop_fp32_grad(p)
                    if self.gather_norm:
                        # Norm pass: gradients are measured here, the update happens in a second backward.
                        self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                        continue
                    self._clip_grad_(grad_fp32)
                    p_fp32 = p.data.to(torch.float32)
                    self._apply_update_(p_fp32, grad_fp32)
                    p.data.copy_(p_fp32)
            return x
        return func

    def fuse_update_zero3(self):
        """
        Build the gradient hook used when DeepSpeed ZeRO-3 is enabled.

        Each populated gradient is averaged across ranks with
        ``torch.distributed.all_reduce`` before use. Only this rank's slice of the
        flattened gradient is applied to ``p.ds_tensor``.

        :return: ``func(x)``, a closure for ``Tensor.register_hook`` that returns ``x`` unchanged.
        """
        def func(x):
            with torch.no_grad():
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False)
                    if self._drop_grad_on_overflow(p):
                        break
                    grad_fp32 = self._pop_fp32_grad(p)
                    if self.gather_norm:
                        # Norm pass: gradients are measured here, the update happens in a second backward.
                        self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                        continue
                    # NOTE(review): assumes p.ds_tensor is this rank's flat shard, with shard
                    # r covering elements [r * shard_size, (r + 1) * shard_size) - confirm.
                    numel = grad_fp32.numel()
                    partition_size = p.ds_tensor.numel()
                    # Clamp both ends so a shard lying entirely in padding gives an empty slice
                    # instead of a negative narrow() length.
                    start = min(partition_size * self.dp_rank, numel)
                    end = min(start + partition_size, numel)
                    partitioned_grad_fp32 = grad_fp32.view(-1).narrow(0, start, end - start)
                    self._clip_grad_(partitioned_grad_fp32)
                    # Only copy the shard to fp32 when an update is actually applied.
                    param_fp32 = p.ds_tensor.to(torch.float32)
                    partitioned_p = param_fp32.narrow(0, 0, end - start)
                    self._apply_update_(partitioned_p, partitioned_grad_fp32)
                    p.ds_tensor[: end - start] = partitioned_p
            return x
        return func

    def _maybe_enable_loss_scaling(self, loss):
        """
        On the first backward call only, create a :class:`DynamicLossScaler` if
        ``loss`` is fp16, and default ``clip_grad_norm`` to ``1.0`` in that case.
        """
        if not self.first_backward:
            return
        self.first_backward = False
        if loss.dtype == torch.float16:
            self.loss_scaler = DynamicLossScaler(**self.loss_scale_args)
            if self.clip_grad_norm is None:
                self.clip_grad_norm = 1.0
                logger.rank_zero_warning(
                    "Loss scale is recommended to be used with grad norm to get better performance. "
                    "Set grad norm to 1.0."
                )

    def fused_backward(self, loss, lr):
        """
        Run backward on ``loss`` and update the parameters through the hooks.

        :param loss: the model loss
        :param lr: learning rate for this step (stored in :attr:`lr`)
        :raises ValueError: if norm clipping is enabled but :meth:`grad_norm` has not
            produced a clip coefficient yet
        """
        self._maybe_enable_loss_scaling(loss)
        self.lr = lr
        # Users need call grad_norm themselves and then call backward_step
        if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is None:
            raise ValueError(
                "clip_grad_norm is not None, but clip_coef is None. "
                "Please call optimizer.grad_norm() before optimizer.fused_backward()."
            )
        if self.loss_scaler:
            # NOTE(review): update_scale(overflow=False) is not called in this file, so the
            # scale only grows if a caller invokes it - confirm this is intended.
            loss = loss * self.loss_scaler.loss_scale
        loss.backward()
        # Per the original design, the last parameter's gradient is not ready when the
        # hooks run, so invoke the hook once more; its argument is a placeholder.
        self.grad_func(0)

    def grad_norm(self, loss):
        """
        Measure the global gradient L2 norm and compute the clipping coefficient.

        Runs ``loss.backward(retain_graph=True)`` with :attr:`gather_norm` set, so the
        hooks only record per-parameter norms and free the gradients (no update).
        Then :attr:`grad_norms` holds the stacked norms and :attr:`clip_coef` is
        ``min(1, clip_grad_norm / (total_norm + 1e-6))`` when ``clip_grad_norm`` is set.
        If the loss scaler detects an overflow, the scale is reduced, all gradients
        are cleared and :attr:`clip_coef` is left unchanged. Callers can check
        ``loss_scaler.has_overflow_serial`` to skip the step.

        :param loss: the model loss
        """
        self._maybe_enable_loss_scaling(loss)
        self.gather_norm = True
        self.grad_norms = []
        try:
            if self.loss_scaler:
                self.loss_scaler.has_overflow_serial = False
                loss = loss * self.loss_scaler.loss_scale
            loss.backward(retain_graph=True)
            # Flush the gradient of the last parameter (see fused_backward).
            self.grad_func(0)

            if self.loss_scaler and self.loss_scaler.has_overflow_serial:
                self.loss_scaler.update_scale(overflow=True)
                with torch.no_grad():  # clear gradients
                    for p in self.model.parameters():
                        p.grad = None
                return

            with torch.no_grad():
                # The norm is computed over all gradients together, as if they were
                # concatenated into a single vector.
                self.grad_norms = torch.stack(self.grad_norms)
                total_norm = torch.norm(self.grad_norms, 2.0)
                if self.clip_grad_norm is not None:
                    self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)
        finally:
            # Always leave norm-gathering mode, including after an overflow or an error.
            self.gather_norm = False
class DynamicLossScaler:
    """
    Dynamic loss scaler that adjusts the loss scale during training.

    On an overflowing step the scale is divided by ``scale_factor``, subject to
    the hysteresis below and never below ``min_scale``. On a non-overflowing
    step, if the number of iterations since the last overflow is a multiple of
    ``scale_window``, the scale is multiplied by ``scale_factor``.

    :param init_scale: initial loss scale, default ``2 ** 32``
    :param scale_factor: factor used to shrink or grow the scale, default ``2.``
    :param scale_window: iterations without overflow between two scale increases, default ``1000``
    :param min_scale: lower bound of the scale, default ``1``
    :param delayed_shift: hysteresis: how many overflowing updates are needed
        before the scale is actually reduced, default ``1``
    :param consecutive_hysteresis: if ``True``, the hysteresis counter is restored
        to ``delayed_shift`` on every non-overflowing update; otherwise only when the
        scale is increased. Default ``False``.
    :param raise_error_at_min_scale: raise when an overflow must reduce a scale that
        is already at ``min_scale``, default ``True``
    :param dtype: training dtype, only used to decide whether the overflow message
        mentions the scale, default ``torch.half``
    """

    def __init__(self,
                 init_scale=2 ** 32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False,
                 raise_error_at_min_scale=True,
                 dtype=torch.half):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis
        self.raise_error_at_min_scale = raise_error_at_min_scale
        self.dtype = dtype
        # Set by the optimizer hooks when a gradient overflow is seen in the current pass.
        self.has_overflow_serial = False

    @property
    def loss_scale(self):
        """Current loss scale."""
        return self.cur_scale

    @staticmethod
    def _get_rank():
        """Return the distributed rank, or ``0`` when no process group is initialized."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_rank()
        return 0

    def _has_inf_or_nan(self, x):
        """
        Return ``True`` if tensor ``x`` contains inf or NaN.

        Checks the fp32 sum of ``x``: inf/NaN elements make the sum inf or NaN.
        """
        try:
            # If x is half, .float() costs an extra copy, but it is needed when
            # .sum() returns a one-element tensor of the same dtype as x.
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            # Only a failed conversion counts as an overflow; any other
            # RuntimeError is propagated.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        # ``cpu_sum != cpu_sum`` is the NaN test.
        return cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum

    def update_scale(self, overflow):
        """
        Update the scale after one step.

        :param overflow: whether the gradients of this step overflowed
        :raises Exception: on overflow when the scale is already at ``min_scale``
            and ``raise_error_at_min_scale`` is set
        """
        rank = self._get_rank()
        if overflow:
            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
                    raise Exception(
                        "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.")
                next_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
                if rank == 0:
                    overflow_msg = f"[LOMO] OVERFLOW! Rank {rank} Skipping step."
                    if self.dtype == torch.half:
                        overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}"
                    print(overflow_msg)
                self.cur_scale = next_scale
            else:
                # Hysteresis: absorb this overflow without touching the scale.
                if rank == 0:
                    overflow_msg = f"[LOMO] OVERFLOW! Rank {rank} Skipping step."
                    if self.dtype == torch.half:
                        overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis - 1}"
                    print(overflow_msg)
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                if rank == 0:
                    hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}"
                    print(hysteresis_msg)
                self.cur_hysteresis = self.delayed_shift
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.delayed_shift
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1