neurai.mp package#
Submodules#
This module implements mixed precision policies for NeurAI.
- class neurai.mp.policy.Policy(param_dtype, compute_dtype, output_dtype)#
Bases:
object
Encapsulates casting for inputs, outputs and parameters.
- Parameters:
param_dtype (jnp.dtype) – The dtype used by cast_to_param.
compute_dtype (jnp.dtype) – The dtype used by cast_to_compute.
output_dtype (jnp.dtype) – The dtype used by cast_to_output.
- cast_to_compute(x)#
Converts floating point values to the compute dtype.
- Parameters:
x (T) – The tree structure to be converted.
- Return type:
T
- Returns:
T – The converted tree structure, based on compute_dtype.
- cast_to_output(x)#
Converts floating point values to the output dtype.
- Parameters:
x (T) – The tree structure to be converted.
- Return type:
T
- Returns:
T – The converted tree structure, based on output_dtype.
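A minimal usage sketch (the dictionary below is only illustrative; nothing beyond the documented constructor arguments and cast methods is assumed):
>>> import jax.numpy as jnp
>>> policy = Policy(param_dtype=jnp.float32,
...                 compute_dtype=jnp.float16,
...                 output_dtype=jnp.float16)
>>> params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
>>> x = policy.cast_to_compute(params)   # floating leaves -> float16
>>> y = policy.cast_to_output(x)         # floating leaves -> float16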
- neurai.mp.policy.cast_to_full(tree)#
Ensures floating point leaves of the given tree are f32.
- Parameters:
tree (T) – The tree structure to be converted.
- neurai.mp.policy.cast_to_half(tree)#
Ensures floating point leaves of the given tree are half precision.
- Parameters:
tree (T) – The tree structure to be converted.
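A round-trip sketch (which half precision dtype is used depends on the backend; see half_dtype below):
>>> tree = {"w": jnp.ones((2, 2), jnp.float32)}
>>> half = cast_to_half(tree)   # floating leaves -> half precision
>>> full = cast_to_full(half)   # floating leaves -> float32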
- neurai.mp.policy.get_policy(policy_name)#
Get a Policy based on the policy_name string.
- Parameters:
policy_name (str) –
- A policy specification string.
- Loose grammar supporting:
"c=f16" (params in f32, compute and output in f16),
"p=f16,c=f16" (params, compute, and output in f16),
"p=f16,c=bf16" (params in f16, compute in bf16, output in bf16).
For values that are not specified, params defaults to f32, compute follows params, and output follows compute (e.g. 'c=f16' -> 'p=f32,c=f16,o=f16').
- Return type:
Policy
- Returns:
Policy – A mixed precision policy parsed from the string.
- Raises:
ValueError – If a key in policy_name is not one of 'params', 'compute', or 'output'.
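A usage sketch (only the documented grammar and Policy cast methods are assumed):
>>> policy = get_policy("c=f16")               # expands to p=f32,c=f16,o=f16
>>> x = policy.cast_to_compute(jnp.ones(3))    # cast to float16
>>> y = policy.cast_to_output(x)               # kept in float16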
- neurai.mp.policy.half_dtype()#
Returns the half precision dtype for the current backend.
- Return type:
jnp.dtype
- Returns:
jnp.dtype – The half precision dtype for the current backend.
- neurai.mp.policy.parse_dtype(value)#
Parses a string representing a dtype into a dtype object.
- Parameters:
value (str) – String abbreviation for a floating point type.
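For illustration, assuming the same abbreviations used in the get_policy grammar:
>>> parse_dtype("f16")    # -> jnp.float16
>>> parse_dtype("bf16")   # -> jnp.bfloat16
>>> half_dtype()          # the backend's preferred half precision dtype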
This module implements loss scaling for mixed precision training in NeurAI.
- class neurai.mp.scale.DynamicLossScale(loss_scale, counter=<factory>, period=2000, factor=2, min_loss_scale=<factory>)#
Bases:
object
Dynamic loss scale. Dynamic loss scaling tries to determine the largest loss scale value that will keep gradients finite. It does this by increasing the loss scale every period steps by factor if the grads remain finite, otherwise it reduces the loss scale by 1 / factor and resets the counter.
loss_scale = 2 ** 15
counter = 0
period = 2000
factor = 2
for step in range(num_steps):
    loss *= loss_scale
    grads /= loss_scale
    grads_finite = all_finite(grads)
    if grads_finite:
        counter += 1
        if counter == period:
            counter = 0
            loss_scale = first_finite(loss_scale * factor, loss_scale)
    else:
        counter = 0
        loss_scale = max(1, loss_scale / factor)
- Parameters:
loss_scale (jnp.ndarray) – The initial loss scale value.
counter (jnp.ndarray, optional) – The counter for adjusting the loss scale, by default np.zeros([], np.int32)
period (int, optional) – The period for adjusting the loss scale, by default 2000
factor (int, optional) – The factor for adjusting the loss scale, by default 2
min_loss_scale (jnp.ndarray, optional) – The minimum loss scale value, by default np.ones([], np.float32)
Examples
>>> loss_scale = DynamicLossScale(jnp.float32(2 ** 15))
>>> for _ in range(num_steps):
...   # compute loss
...   loss = loss_scale.scale(loss)
...   # compute grads
...   grads = loss_scale.unscale(grads)
...   grads_finite = all_finite(grads)
...   loss_scale = loss_scale.adjust(grads_finite)
...   # conditionally update params using grads
- adjust(grads_finite)#
Adjusts the loss scale based on whether gradients are finite.
- Parameters:
grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.
- Return type:
DynamicLossScale
- Returns:
DynamicLossScale – The next state of the DynamicLossScale object.
- scale(tree)#
Scale the input tree by the loss scale.
- Parameters:
tree (T) – The input tree structure to be scaled.
- Return type:
T
- Returns:
T – The scaled tree structure.
- tree_flatten()#
Flatten the DynamicLossScale object into data and meta components.
- classmethod tree_unflatten(meta, data)#
Unflatten the data and meta components into a DynamicLossScale object.
- Parameters:
meta (_Meta) – The meta component.
data (_Data) – The data component.
- Return type:
DynamicLossScale
- Returns:
DynamicLossScale – The unflattened DynamicLossScale object.
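tree_flatten and tree_unflatten suggest the class is registered with JAX as a pytree, so instances can cross jit boundaries. A sketch under that assumption (adjust_step is an illustrative name):
>>> import jax
>>> @jax.jit
... def adjust_step(loss_scale, grads_finite):
...   return loss_scale.adjust(grads_finite)
>>> loss_scale = DynamicLossScale(jnp.float32(2 ** 15))
>>> loss_scale = adjust_step(loss_scale, jnp.array(True))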
- class neurai.mp.scale.NoOpLossScale#
Bases:
object
A no-op loss scale that leaves values unchanged.
- adjust(grads_finite)#
Returns the loss scale unchanged.
- Parameters:
grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.
- Returns:
NoOpLossScale – The loss scale object, unchanged.
- class neurai.mp.scale.StaticLossScale(loss_scale)#
Bases:
object
Scales and unscales by a fixed constant.
- Parameters:
loss_scale (jnp.ndarray) – The fixed loss scale value.
- adjust(grads_finite)#
Returns the loss scale unchanged; the scale is a fixed constant.
- Parameters:
grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.
- Returns:
StaticLossScale – The loss scale object, unchanged.
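A usage sketch (scale and unscale are implied by the class description but not listed above, so treat those method names as assumptions):
>>> loss_scale = StaticLossScale(jnp.float32(128.0))
>>> loss = loss_scale.scale(loss)       # multiply floating leaves by 128
>>> grads = loss_scale.unscale(grads)   # divide floating leaves by 128
>>> loss_scale = loss_scale.adjust(grads_finite)   # no-op: the scale is fixed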
- neurai.mp.scale.all_finite(tree)#
Check if all elements in the tree are finite.
- Parameters:
tree (T) – The input tree structure.
- Return type:
jnp.ndarray
- Returns:
jnp.ndarray – A boolean scalar indicating whether the input arrays are finite.
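For example:
>>> grads = {"w": jnp.ones(3), "b": jnp.array([jnp.inf])}
>>> all_finite(grads)   # boolean scalar; False here because of the inf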
- neurai.mp.scale.register_empty_pytree(cls)#
Register a custom Pytree node.
- Parameters:
cls – The class to register as a Pytree node.
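A sketch of what such a registration typically looks like (an illustration of the pattern, not the module's actual implementation):
import jax

class Empty:
    pass

# Flatten to no children and no metadata; rebuild by calling the class.
jax.tree_util.register_pytree_node(
    Empty,
    lambda obj: ((), None),
    lambda meta, children: Empty())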
- neurai.mp.scale.select_tree(pred, a, b)#
Selects elements from two trees based on a boolean scalar predicate.
- Parameters:
pred (jnp.ndarray) – Boolean scalar; selects from a where true and from b otherwise.
a (T) – First tree structure.
b (T) – Second tree structure.
- Return type:
T
- Returns:
T – A tree structure containing elements selected from either a or b based on the predicate.
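This is the usual building block for conditionally applying an update under jit, e.g. (new_params and params are illustrative names):
>>> grads_finite = all_finite(grads)
>>> params = select_tree(grads_finite, new_params, params)   # keep old params when grads are not finite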
- neurai.mp.scale.warn_if_not_floating(x, var_name)#
Produces a warning if the given array does not have a floating type.
This function handles an edge case where Jax passes in an object() instance to determine the structure of user-defined pytrees during compilation. The Jax documentation recommends explicitly checking whether the value in question has type object.
From the Jax documentation: “The __init__ and __new__ methods of custom PyTree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases.”
See: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
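A sketch of the recommended pattern (an illustration, not the module's actual implementation):
import warnings
import jax.numpy as jnp

def _warn_if_not_floating(x, var_name):
    if type(x) is object:
        # Jax may pass bare object() sentinels while inspecting custom pytrees.
        return
    if not jnp.issubdtype(jnp.result_type(x), jnp.floating):
        warnings.warn(f"{var_name} does not have a floating dtype")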