Mixed Precision Training#
Author: HuaqiangZhang
Mixed Precision Training Basics#
Mixed precision training is a technique that mixes full-precision and half-precision floating-point numbers during training, reducing memory bandwidth requirements and improving the computational efficiency of a given model.
The library supports mixed precision training in JAX through two key abstractions, "policies" and "loss scaling". The initial GPU and CPU version derives from JMP; because NeurAI supports a broader range of network topologies and hardware, dedicated solutions for additional application scenarios will be provided later.
Imports:
import jax
import jax.numpy as jnp

from neurai.mp import policy
from neurai.mp import scale

half = jnp.float16  # On TPU this should be jnp.bfloat16
full = jnp.float32
Policies#
Users can configure a basic policy with the ``policy.Policy`` class:
# Our policy specifies that we will store parameters in full precision but will compute and return output in half precision.
my_policy = policy.Policy(param_dtype=full, compute_dtype=half, output_dtype=half)
The policy object can then be used to cast pytrees:
def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params["w"], params["b"]
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = {"w": jnp.ones([3, 3], dtype=my_policy.param_dtype),
          "b": jnp.zeros([3], dtype=my_policy.param_dtype)}
x = jnp.ones([3], dtype=full)
y = layer(params, x)
assert y.dtype == half
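Since the initial version derives from JMP, the policy presumably also exposes a ``cast_to_param`` method for casting values back to the parameter dtype, e.g. after an optimizer update. This is an assumption based on JMP's API, not confirmed for neurai.mp; a minimal sketch:

# Assumption: `cast_to_param` exists as in JMP. Casting updated parameters
# back to the storage dtype keeps the master copy in full precision.
grads = jax.grad(lambda p: layer(p, x).sum())(params)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
new_params = my_policy.cast_to_param(new_params)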
The output dtype of a given policy can be replaced:
my_policy = my_policy.with_output_dtype(full)
Policies can also be defined from strings, which can be useful in some scenarios, for example when passing a policy as a command-line argument or as a training hyperparameter:
my_policy = policy.get_policy("params=float32,compute=float16,output=float32")
float16 = policy.get_policy("float16") # Everything in f16.
half = policy.get_policy("half") # Everything in half (f16 or bf16).
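For example, a minimal sketch of reading the policy string from a command-line flag (the ``--mp_policy`` flag name is illustrative, not part of the library):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mp_policy", default="params=float32,compute=float16,output=float32")
args = parser.parse_args()
my_policy = policy.get_policy(args.mp_policy)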
Loss Scaling#
When training in reduced precision, users may use loss scaling to shift gradients into the representable range of the format in use. This is particularly important when training with float16 and less so for bfloat16. For details, see NVIDIA's Training With Mixed Precision guide.
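The failure mode that loss scaling guards against is underflow: float16 cannot represent magnitudes below roughly 6e-8, so small gradient contributions silently become zero. A small illustration in plain JAX (no library API involved):

# The true product below is 1e-9, which is smaller than float16's smallest
# subnormal (~6e-8), so the result flushes to zero.
tiny = jnp.float16(1e-5) * jnp.float16(1e-4)
assert tiny == 0.0
# Scaling the loss by S scales every gradient by S as well (chain rule),
# lifting small gradients back into float16's representable range.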
StaticLossScale scales the loss by a fixed, user-chosen factor S and the gradients by 1/S:
def my_loss_fn(params, loss_scale: scale.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: scale.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = scale.StaticLossScale(jnp.float32(2 ** 12))

for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
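Because powers of two rescale floats exactly (only the exponent changes), ``scale`` followed by ``unscale`` round-trips a value without rounding error. A quick sanity check using only the methods shown above:

s = scale.StaticLossScale(jnp.float32(2 ** 12))
x = jnp.float32(0.5)
assert s.unscale(s.scale(x)) == x  # exact for power-of-two scales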
DynamicLossScale periodically adjusts the loss scale during training to find the largest value of S that still produces finite gradients. This is more convenient and more robust than choosing a static loss scale, but has a small performance impact (between 1% and 5%):
def my_loss_fn(params, loss_scale: scale.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: scale.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, scale.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = scale.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = scale.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = scale.DynamicLossScale(jnp.float32(2 ** 16))

for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
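The adjustment schedule in the initial version presumably follows JMP's defaults (an assumption; neurai.mp may override them): the scale is doubled after a fixed number of consecutive finite steps and is halved as soon as non-finite gradients appear. A minimal sketch of calling ``adjust`` directly:

ls = scale.DynamicLossScale(jnp.float32(2 ** 16))
ls = ls.adjust(jnp.array(True))   # finite grads: counts toward a periodic increase
ls = ls.adjust(jnp.array(False))  # non-finite grads: the scale is reduced (halved in JMP)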
NoOpLossScale is a no-op loss scale; applying it does nothing:
loss_scale = scale.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1
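A common pattern is to pick the loss scale from the policy's compute dtype; the ``make_loss_scale`` helper below is an illustrative sketch, not part of the library:

def make_loss_scale(compute_dtype):
  # float16 has a narrow exponent range and usually needs loss scaling;
  # bfloat16 and float32 generally train fine without it.
  if compute_dtype == jnp.float16:
    return scale.DynamicLossScale(jnp.float32(2 ** 15))
  return scale.NoOpLossScale()

loss_scale = make_loss_scale(my_policy.compute_dtype)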