neurai.mp package#

Submodules#

This module implements the mixed-precision policy for NeurAI.

class neurai.mp.policy.Policy(param_dtype, compute_dtype, output_dtype)#

Bases: object

Encapsulates casting for inputs, outputs and parameters.

Parameters:
  • param_dtype (jnp.dtype) – The dtype that cast_to_param converts to (the dtype in which parameters are stored).

  • compute_dtype (jnp.dtype) – The dtype that cast_to_compute converts to (the dtype used for computation).

  • output_dtype (jnp.dtype) – The dtype that cast_to_output converts to (the dtype of outputs).
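
A minimal usage sketch, assuming Policy can be imported from the module path shown above (the jnp dtypes are the standard JAX dtypes):

import jax.numpy as jnp
from neurai.mp.policy import Policy

# Store parameters in float32, run computation and emit outputs in float16.
policy = Policy(param_dtype=jnp.float32,
                compute_dtype=jnp.float16,
                output_dtype=jnp.float16)

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
compute_params = policy.cast_to_compute(params)   # float leaves -> float16
outputs = policy.cast_to_output(compute_params)   # already float16 here
stored_params = policy.cast_to_param(compute_params)  # float leaves -> float32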

cast_to_compute(x)#

Converts floating point values to the compute dtype.

Parameters:

x (T) – The tree structure to be converted.

Return type:

TypeVar(T)

Returns:

Any – The converted tree structure based on compute_dtype.

cast_to_output(x)#

Converts floating point values to the output dtype.

Parameters:

x (T) – The tree structure to be converted.

Return type:

TypeVar(T)

Returns:

Any – The converted tree structure based on output_dtype.

cast_to_param(x)#

Converts floating point values to the param dtype.

Parameters:

x (T) – The tree structure to be converted.

Return type:

TypeVar(T)

Returns:

Any – The converted tree structure based on param_dtype.

with_output_dtype(output_dtype)#

Updates the output data type of the Policy.

Parameters:

output_dtype (jnp.dtype) – The new output data type.

Return type:

Policy

Returns:

Policy – A new Policy object with the updated output data type.
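
For instance, a small sketch deriving a policy that changes only the output dtype (reusing the hypothetical policy object from the sketch above):

# Same param/compute dtypes, but outputs are returned in full precision.
f32_out_policy = policy.with_output_dtype(jnp.float32)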

neurai.mp.policy.cast_to_full(tree)#

Ensures floating point leaves of the given tree are f32.

Parameters:

tree (T) – The tree structure to be converted.

Return type:

TypeVar(T)

Returns:

Any – The given tree with its floating point leaves cast to float32.

neurai.mp.policy.cast_to_half(tree)#

Ensures floating point leaves of the given tree are half precision.

Parameters:

tree (T) – The tree structure to be converted.

Return type:

TypeVar(T)

Returns:

Any – The given tree with its floating point leaves cast to half precision.
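
A brief sketch of the two casting helpers on a small tree, assuming the module path shown above; the half dtype is float16 or bfloat16 depending on the backend (see half_dtype below):

import jax.numpy as jnp
from neurai.mp.policy import cast_to_full, cast_to_half

tree = {"w": jnp.ones((2, 2), jnp.float32), "step": jnp.array(0, jnp.int32)}

half_tree = cast_to_half(tree)       # floating point leaves -> half precision; "step" is untouched
full_tree = cast_to_full(half_tree)  # floating point leaves -> float32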

neurai.mp.policy.get_policy(policy_name)#

Get the Policy type based on policy_name string.

Parameters:

policy_name (str) –

A string describing the policy.
Loose grammar supporting:
  • "c=f16" (params in full precision, compute and output in f16),

  • "p=f16,c=f16" (params, compute and output in f16),

  • "p=f16,c=bf16" (params in f16, compute and output in bf16).

For values that are not specified, params defaults to f32, compute follows params, and output follows compute (e.g. "c=f16" -> "p=f32,c=f16,o=f16").

Return type:

Policy

Returns:

Policy – A mixed precision policy parsed from a string.

Raises:

ValueError – If policy_name contains an unknown key; keys must be 'params', 'compute' or 'output'.
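
A short sketch of parsing a policy string, assuming the module path shown above:

import jax.numpy as jnp
from neurai.mp.policy import get_policy

policy = get_policy("c=bf16")  # equivalent to "p=f32,c=bf16,o=bf16"
x = {"w": jnp.ones((3, 3))}
x_bf16 = policy.cast_to_compute(x)  # floating point leaves cast to bfloat16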

neurai.mp.policy.half_dtype()#

Returns the half precision dtype for the current backend.

Return type:

dtype

Returns:

jnp.dtype – The half precision dtype for the current backend.

neurai.mp.policy.parse_dtype(value)#

Parses a string representing a dtype into a dtype object.

Parameters:

value (str) – String abbreviation for a floating point type.

Return type:

dtype

Returns:

jnp.dtype – The floating point dtype corresponding to the string abbreviation.

Raises:

ValueError – If an unknown dtype '{value}' is provided.
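
For example (a sketch; "f16" is assumed to be an accepted abbreviation, as in the policy strings above):

from neurai.mp.policy import parse_dtype

dtype = parse_dtype("f16")  # expected to map to jnp.float16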

This module implements loss scaling for mixed precision in NeurAI.

class neurai.mp.scale.DynamicLossScale(loss_scale, counter=<factory>, period=2000, factor=2, min_loss_scale=<factory>)#

Bases: object

Dynamic loss scale. Dynamic loss scaling tries to determine the largest loss scale value that will keep gradients finite. It does this by increasing the loss scale every period steps by factor if the grads remain finite, otherwise it reduces the loss scale by 1 / factor and resets the counter.

loss_scale = 2 ** 15
counter = 0
period = 2000
factor = 2

for step in range(num_steps):
  loss *= loss_scale
  grads /= loss_scale
  grads_finite = all_finite(grads)

  if grads_finite:
    counter += 1
    if counter == period:
      counter = 0
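      # first_finite presumably returns its first finite argument, so the
      # scale only grows while loss_scale * factor remains finite.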
      loss_scale = first_finite(loss_scale * factor, loss_scale)
  else:
    counter = 0
    loss_scale = max(1, loss_scale / factor)

Parameters:
  • loss_scale (jnp.ndarray) – The initial loss scale value.

  • counter (jnp.ndarray, optional) – The counter for adjusting the loss scale, by default np.zeros([], np.int32)

  • period (int, optional) – The period for adjusting the loss scale, by default 2000

  • factor (int, optional) – The factor for adjusting the loss scale, by default 2

  • min_loss_scale (jnp.ndarray, optional) – The minimum loss scale value, by default np.ones([], np.float32)

Examples

>>> loss_scale = DynamicLossScale(jnp.float32(2 ** 15))
>>> for _ in range(num_steps):
...   # compute loss
...   loss = loss_scale.scale(loss)
...   # compute grads
...   grads = loss_scale.unscale(grads)
...   grads_finite = all_finite(grads)
...   loss_scale = loss_scale.adjust(grads_finite)
...   # conditionally update params using grads
adjust(grads_finite)#

Adjusts the loss scale based on whether gradients are finite.

Parameters:

grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.

Return type:

DynamicLossScale

Returns:

DynamicLossScale – The next state of the DynamicLossScale object.

scale(tree)#

Scale the input tree by the loss scale.

Parameters:

tree (T) – The input tree structure to be scaled.

Return type:

TypeVar(T)

Returns:

Any – The scaled tree structure.

tree_flatten()#

Flatten the DynamicLossScale object into data and meta components.

Return type:

Tuple[Tuple[Array, ...], Tuple[int, int]]

Returns:

Tuple[_Data, _Meta] – A tuple containing the data and meta components.

classmethod tree_unflatten(meta, data)#

Unflatten the data and meta components into a DynamicLossScale object.

Parameters:
  • meta (_Meta) – The meta component.

  • data (_Data) – The data component.

Return type:

DynamicLossScale

Returns:

DynamicLossScale – The unflattened DynamicLossScale object.

unscale(tree)#

Unscales the input tree by the inverse of the loss scale.

Parameters:

tree (T) – The input tree structure to be unscaled.

Return type:

TypeVar(T)

Returns:

Any – The unscaled tree structure.
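
A more complete, hypothetical training-step sketch tying the pieces together; the loss function, learning rate and trees are placeholders, and all_finite/select_tree are the helpers documented below:

import jax
import jax.numpy as jnp
from neurai.mp.scale import DynamicLossScale, all_finite, select_tree

def loss_fn(params, batch):
  # Placeholder loss; replace with the real model loss.
  return jnp.mean((params["w"] * batch["x"] - batch["y"]) ** 2)

def train_step(params, batch, loss_scale):
  def scaled_loss(p):
    return loss_scale.scale(loss_fn(p, batch))

  grads = jax.grad(scaled_loss)(params)
  grads = loss_scale.unscale(grads)

  grads_finite = all_finite(grads)
  loss_scale = loss_scale.adjust(grads_finite)

  # Only commit the update when every gradient is finite.
  new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
  params = select_tree(grads_finite, new_params, params)
  return params, loss_scale

params = {"w": jnp.ones(())}
batch = {"x": jnp.ones(()), "y": jnp.zeros(())}
loss_scale = DynamicLossScale(jnp.float32(2 ** 15))
params, loss_scale = train_step(params, batch, loss_scale)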

class neurai.mp.scale.NoOpLossScale#

Bases: object

No-op loss scale does nothing.

adjust(grads_finite)#

Adjusts the loss scale.

Parameters:

grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.

Returns:

NoOpLossScale – The loss scale object, unchanged.

scale(tree)#

Scale the input tree.

Parameters:

tree (T) – The input tree structure.

Return type:

TypeVar(T)

Returns:

Any – The scaled tree structure.

unscale(tree)#

Unscale the input tree.

Parameters:

tree (T) – The input tree structure.

Return type:

TypeVar(T)

Returns:

Any – The unscaled tree structure.

class neurai.mp.scale.StaticLossScale(loss_scale)#

Bases: object

Scales and unscales by a fixed constant.

Parameters:

loss_scale (Array) – The fixed constant used for scaling and unscaling.

adjust(grads_finite)#

Adjusts the loss scale.

Parameters:

grads_finite (jnp.ndarray) – Boolean scalar indicating whether gradients are finite.

Returns:

StaticLossScale – The adjusted loss scale object.

scale(tree)#

Scales the input tree by a fixed constant.

Parameters:

tree (T) – The input tree structure to be scaled.

Return type:

TypeVar(T)

Returns:

Any – The scaled tree structure.

unscale(tree)#

Unscales the input tree by the inverse of the fixed constant.

Parameters:

tree (T) – The input tree structure to be unscaled.

Return type:

TypeVar(T)

Returns:

Any – The unscaled tree structure.
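
A minimal sketch, assuming the import path mirrors the class path above:

import jax.numpy as jnp
from neurai.mp.scale import StaticLossScale

loss_scale = StaticLossScale(jnp.float32(128.0))
scaled_loss = loss_scale.scale(jnp.float32(0.25))     # 0.25 * 128 = 32.0
grads = loss_scale.unscale({"w": jnp.float32(64.0)})  # 64 / 128 = 0.5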

neurai.mp.scale.all_finite(tree)#

Check if all elements in the tree are finite.

Parameters:

tree (T) – The input tree structure.

Return type:

Array

Returns:

jnp.ndarray – A boolean scalar indicating whether the input arrays are finite.

neurai.mp.scale.register_empty_pytree(cls)#

Register a custom Pytree node.

Parameters:

cls – The class to register as a Pytree node.

neurai.mp.scale.select_tree(pred, a, b)#

Selects elements from two trees based on a boolean scalar predicate.

Parameters:
  • pred (jnp.ndarray) – Boolean scalar indicating which elements to select.

  • a (T) – First tree structure.

  • b (T) – Second tree structure.

Return type:

TypeVar(T)

Returns:

Any – A tree structure containing elements selected from either a or b based on the predicate.
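
For example, committing a parameter update only when the gradients are finite (a sketch; the trees and learning rate are placeholders):

import jax.numpy as jnp
from neurai.mp.scale import all_finite, select_tree

params = {"w": jnp.ones((2,))}
grads = {"w": jnp.array([0.1, jnp.inf])}  # deliberately non-finite

new_params = {"w": params["w"] - 0.01 * grads["w"]}
grads_finite = all_finite(grads)                        # False for this tree
params = select_tree(grads_finite, new_params, params)  # keeps the old params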

neurai.mp.scale.warn_if_not_floating(x, var_name)#

Produces a warning if the given array does not have a floating type.

This function handles an edge case where JAX passes an object() placeholder to determine the structure of user-defined pytrees during compilation. The JAX documentation recommends explicitly checking whether the array in question has type object.

From the JAX documentation: “The __init__ and __new__ methods of custom PyTree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases.”

See: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

Parameters:
  • x (Union[jnp.ndarray, object]) – Any object.

  • var_name (str) – A useful name to put in error messages.

Raises:

TypeError – If {var_name} does not have a floating point dtype (got {x_dtype}).

Return type:

None

Module contents#