Mixed Precision Training#

Author: HuaqiangZhang

Mixed Precision Training Basics#

Mixed precision training is a technique that mixes full- and half-precision floating point numbers during training, reducing memory bandwidth requirements and improving the computational efficiency of a given model.

The library supports mixed precision training in JAX through two key abstractions, "policies" and "loss scaling". The initial GPU and CPU version is derived from JMP; since NeurAI supports more complex network topologies and hardware types, dedicated solutions for additional scenarios will be provided later.

Loading the functionality:

from neurai.mp import policy
from neurai.mp import scale
import jax.numpy as jnp

half = jnp.float16  # On TPU this should be jnp.bfloat16
full = jnp.float32

Policies#

Users can configure basic policies with the ``policy`` class.

# Our policy specifies that we will store parameters in full precision but will compute and return output in half precision.
my_policy = policy.Policy(param_dtype=full, compute_dtype=half, output_dtype=half)

A policy object can be used to cast pytrees:

def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = (jnp.ones([3, 3], dtype=my_policy.param_dtype),
          jnp.zeros([3], dtype=my_policy.param_dtype))
x = jnp.ones([3], dtype=full)
y = layer(params, x)
assert y.dtype == half

The output dtype of a given policy can be replaced:

my_policy = my_policy.with_output_dtype(full)

Policies can also be defined from strings. This can be useful in some scenarios, for example when passing the policy as a command-line argument or a training-time hyperparameter:

my_policy = policy.get_policy("params=float32,compute=float16,output=float32")
float16 = policy.get_policy("float16")  # Everything in f16.
half = policy.get_policy("half")        # Everything in half (f16 or bf16).
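As a rough illustration of what such string parsing involves, the sketch below parses the same `"key=dtype"` format shown above; the `parse_policy` function and `DTYPES` table are hypothetical stand-ins, not NeurAI internals:

```python
import numpy as np

# Hypothetical dtype registry; the real library maps to JAX dtypes and
# resolves "half" to bfloat16 on TPU.
DTYPES = {"float32": np.float32, "float16": np.float16, "half": np.float16}

def parse_policy(s):
  if "=" not in s:  # a bare dtype name applies to all three roles
    d = DTYPES[s]
    return {"param": d, "compute": d, "output": d}
  fields = dict(kv.split("=") for kv in s.split(","))
  return {"param": DTYPES[fields["params"]],
          "compute": DTYPES[fields["compute"]],
          "output": DTYPES[fields["output"]]}

p = parse_policy("params=float32,compute=float16,output=float32")
assert p["param"] is np.float32
assert p["compute"] is np.float16
assert parse_policy("half")["output"] is np.float16
```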

Loss Scaling#

When training at reduced precision, you may want to use this feature to scale gradients into the representable range of the format in use. This is particularly important when training with float16 and less so with bfloat16. For details, see NVIDIA's official Training With Mixed Precision guide.
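A quick numeric illustration of why float16 needs this (plain NumPy, independent of the library): gradients below float16's smallest subnormal underflow to zero, while scaling the loss by S keeps them representable.

```python
import numpy as np

s = np.float32(2 ** 12)
tiny_grad = np.float32(1e-8)  # below float16's smallest subnormal (~6e-8)

# Without scaling, the gradient underflows to zero in half precision.
underflowed = np.float16(tiny_grad)

# Scaling the loss by S scales every gradient by S, moving it into
# float16 range; unscaling by 1/S in float32 recovers the magnitude.
survived = np.float16(tiny_grad * s)
recovered = np.float32(survived) / s

assert underflowed == 0.0
assert survived != 0.0
assert recovered > 0.0
```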

StaticLossScale scales the loss by a fixed, user-chosen factor S and the gradients by 1/S:

def my_loss_fn(params, loss_scale: scale.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: scale.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = scale.StaticLossScale(jnp.float32(2 ** 12))
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
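Conceptually, a static loss scale is just a multiply by S on the loss and by 1/S on each gradient leaf. A minimal self-contained sketch (not the NeurAI implementation; a flat dict stands in for a general pytree):

```python
import numpy as np

class StaticLossScale:
  """Illustrative stand-in: multiplies the loss by S and gradients by 1/S."""

  def __init__(self, loss_scale):
    self.loss_scale = np.float32(loss_scale)

  def scale(self, loss):
    return loss * self.loss_scale

  def unscale(self, grads):
    inv = np.float32(1) / self.loss_scale
    # The real library maps over an arbitrary pytree; a dict suffices here.
    return {k: g * inv for k, g in grads.items()}

ls = StaticLossScale(2 ** 12)
loss = np.float32(0.5)
grads = {"w": np.float32(4096.0)}  # gradient of the *scaled* loss
assert ls.scale(loss) == 2048.0
assert ls.unscale(grads)["w"] == 1.0
```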

DynamicLossScale periodically adjusts the loss scale during training to find the largest value of S that still produces finite gradients. Compared with choosing a static loss scale this is more convenient and more robust, at a small performance cost (between 1% and 5%):

def my_loss_fn(params, loss_scale: scale.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: scale.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, scale.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = scale.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite; if any element of any
    # gradient is non-finite the whole update is discarded.
    params = scale.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = scale.DynamicLossScale(jnp.float32(2 ** 16))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
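A common way to implement the `adjust` step is to grow the scale after a fixed number of consecutive finite steps and shrink it immediately on overflow. The sketch below uses typical defaults (period 2000, factor 2); these values and the class internals are illustrative and may differ from NeurAI's:

```python
import numpy as np

class DynamicLossScale:
  """Illustrative adjust() logic: grow S while gradients stay finite,
  shrink it immediately on overflow."""

  def __init__(self, loss_scale, counter=0, period=2000, factor=2):
    self.loss_scale = np.float32(loss_scale)
    self.counter = counter
    self.period = period
    self.factor = factor

  def adjust(self, grads_finite):
    if grads_finite:
      counter = self.counter + 1
      if counter == self.period:
        # A full period of finite gradients: try a larger scale.
        return DynamicLossScale(self.loss_scale * self.factor, 0,
                                self.period, self.factor)
      return DynamicLossScale(self.loss_scale, counter,
                              self.period, self.factor)
    # Non-finite gradients: shrink the scale and restart the counter.
    return DynamicLossScale(self.loss_scale / self.factor, 0,
                            self.period, self.factor)

ls = DynamicLossScale(2 ** 16, counter=1999)
assert ls.adjust(True).loss_scale == 2 ** 17   # grown after a full period
assert ls.adjust(False).loss_scale == 2 ** 15  # halved on overflow
```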

NoOpLossScale provides a no-op that does nothing when invoked:

loss_scale = scale.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1
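The pass-through behaviour asserted above can be mimicked with a trivial class (an illustrative sketch, not the NeurAI implementation), which makes NoOpLossScale drop-in compatible with the training loops shown earlier:

```python
class NoOpLossScale:
  """Illustrative no-op: everything passes through unchanged."""

  loss_scale = 1

  def scale(self, loss):
    return loss  # identical object, no multiply

  def unscale(self, grads):
    return grads  # identical pytree, no multiply

  def adjust(self, grads_finite):
    return self  # nothing to adjust

ls = NoOpLossScale()
loss = 3.0
grads = {"w": 1.0}
assert ls.scale(loss) is loss
assert ls.unscale(grads) is grads
assert ls.adjust(True) is ls
assert ls.loss_scale == 1
```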