neurai.nn.layer package#

Submodules#

class neurai.nn.layer.activate.Acitivate(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

A base class for activation functions. This class inherits from Module and defines a __call__ method that applies an activation function to the inputs.

Parameters:
  • activate_fun (Callable) – The activation function to use. This should be a callable object that takes an input tensor and returns the tensor with the applied activation function.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
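
For orientation, a minimal usage sketch with one of the concrete subclasses defined below (mirroring their Examples blocks):

from neurai.nn import Relu
from jax import numpy as jnp

x = jnp.array([[-1.0, 2.0], [3.0, -4.0]])
y = Relu()(x)  # applies jax.nn.relu element-wise; negative entries become 0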

class neurai.nn.layer.activate.Celu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for celu activation function. This class inherits from Acitivate and uses jax.nn.celu as the default activation function.

Examples

from neurai.nn import Celu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Celu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.celu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(alpha=1.0)#

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.

Parameters:
  • x (Any) – input array

  • alpha (Any) – array or scalar (default: 1.0)

Return type:

Any

class neurai.nn.layer.activate.Elu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for elu activation function. This class inherits from Acitivate and uses jax.nn.elu as the default activation function.

Examples

from neurai.nn import Elu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Elu()(x, alpha=0.5)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.elu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(alpha=1.0)#

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
Parameters:
  • x (Any) – input array

  • alpha (Any) – scalar or array of alpha values (default: 1.0)

Return type:

Any

class neurai.nn.layer.activate.Gelu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for gelu activation function. This class inherits from Acitivate and uses jax.nn.gelu as the default activation function.

Examples

from neurai.nn import Gelu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Gelu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.gelu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(approximate=True)#

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters:
  • x (Any) – input array

  • approximate (bool) – whether to use the approximate or exact formulation.

Return type:

Any

class neurai.nn.layer.activate.Glu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for glu activation function. This class inherits from Acitivate and uses jax.nn.glu as the default activation function.

Examples

from neurai.nn import Glu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = Glu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.glu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(axis=-1)#

Gated linear unit activation function.
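
Computes the function below (following the standard jax.nn.glu definition, stated here for reference): the input is split into two halves \(x_1\) and \(x_2\) along axis, and

\[\mathrm{glu}(x) = x_1 \cdot \mathrm{sigmoid}(x_2)\]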

Parameters:
  • x (Any) – input array

  • axis (int) – the axis along which the split should be computed (default: -1)

Return type:

Any

class neurai.nn.layer.activate.HardSigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for hard_sigmoid activation function. This class inherits from Acitivate and uses jax.nn.hard_sigmoid as the default activation function.

Examples

from neurai.nn import HardSigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = HardSigmoid()(x)

expected output:

Array([[0.33333334, 0.8333334 ], [1.        , 0.        ]], dtype=float32)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_sigmoid.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Hard Sigmoid activation function.

Computes the element-wise function

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.HardSilu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for hard_silu activation function. This class inherits from Acitivate and uses jax.nn.hard_silu as the default activation function.

Examples

from neurai.nn import HardSilu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardSilu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_silu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Hard SiLU activation function.

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.HardSwish(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for hard_swish activation function. This class inherits from Acitivate and uses jax.nn.hard_swish as the default activation function.

Examples

from neurai.nn import HardSwish
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardSwish()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_swish.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Hard SiLU activation function (in jax.nn, hard_swish is an alias of hard_silu).

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.HardTanh(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for hard_tanh activation function. This class inherits from Acitivate and uses jax.nn.hard_tanh as the default activation function.

Examples

from neurai.nn import HardTanh
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardTanh()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_tanh.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Hard \(\mathrm{tanh}\) activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.LeakyRelu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for leaky relu activation function. This class inherits from Acitivate and uses jax.nn.leaky_relu as the default activation function.

Examples

from neurai.nn import LeakyRelu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LeakyRelu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.leaky_relu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(negative_slope=0.01)#

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.

Parameters:
  • x (Any) – input array

  • negative_slope (Any) – array or scalar specifying the negative slope (default: 0.01)

Return type:

Any

class neurai.nn.layer.activate.LogSigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for log_sigmoid activation function. This class inherits from Acitivate and uses jax.nn.log_sigmoid as the default activation function.

Examples

from neurai.nn import LogSigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSigmoid()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.log_sigmoid.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.LogSoftmax(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for log_softmax activation function. This class inherits from Acitivate and uses jax.nn.log_softmax as the default activation function.

Examples

from neurai.nn import LogSoftmax
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSoftmax()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.log_softmax.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(axis=-1, where=None, initial=None)#

Log-Softmax function.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).

\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
Parameters:
  • x (Any) – input array

  • axis (Union[int, Tuple[int, ...], None]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.

  • where (Optional[Any]) – Elements to include in the log_softmax.

  • initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.

Return type:

Any

class neurai.nn.layer.activate.LogSumexp(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for logsumexp activation function. This class inherits from Acitivate and uses jax.nn.logsumexp as the default activation function.

Examples

from neurai.nn import LogSumexp
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSumexp()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.logsumexp.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(axis=None, b=None, keepdims=False, return_sign=False)#

Compute the log of the sum of exponentials of input elements.

LAX-backend implementation of scipy.special.logsumexp().

Original docstring below.

Parameters:
  • a (array_like) – Input array.

  • axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.

  • b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.

  • return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).

Return type:

Union[Array, Tuple[Array, Array]]

Returns:

  • res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.

  • sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.

class neurai.nn.layer.activate.Mish(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for mish activation function. This class inherits from Acitivate and uses mish as the default activation function.
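
Mish (standard definition, assumed here) computes the element-wise function:

\[\mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))\]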

Examples

from neurai.nn import Mish
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Mish()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is mish.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.activate.Relu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for relu activation function. This class inherits from Acitivate and uses jax.nn.relu as the default activation function.

Examples

from neurai.nn import Relu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Relu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.relu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun: Callable = <function relu>#

class neurai.nn.layer.activate.Relu6(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for relu6 activation function. This class inherits from Acitivate and uses jax.nn.relu6 as the default activation function.

Examples

from neurai.nn import Relu6
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Relu6()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.relu6.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun: Callable = <function relu6>#

class neurai.nn.layer.activate.Selu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for selu activation function. This class inherits from Acitivate and uses jax.nn.selu as the default activation function.

Examples

from neurai.nn import Selu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Selu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.selu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.Sigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for sigmoid activation function. This class inherits from Acitivate and uses jax.nn.sigmoid as the default activation function.

Examples

from neurai.nn import Sigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = Sigmoid()(x)

expected output:

Array([[0.26894143, 0.8807971 ], [0.95257413, 0.01798621]], dtype=float32)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.sigmoid.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.Silu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for silu activation function. This class inherits from Acitivate and uses jax.nn.silu as the default activation function.

Examples

from neurai.nn import Silu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Silu()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.silu.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

SiLU activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.SoftSign(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for soft_sign activation function. This class inherits from Acitivate and uses jax.nn.soft_sign as the default activation function.

Examples

from neurai.nn import SoftSign
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = SoftSign()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.soft_sign.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Soft-sign activation function.

Computes the element-wise function

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.Softmax(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for softmax activation function. This class inherits from Acitivate and uses jax.nn.softmax as the default activation function.

Examples

from neurai.nn import Softmax
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softmax()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.softmax.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun(axis=-1, where=None, initial=None)#

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters:
  • x (Any) – input array

  • axis (Union[int, Tuple[int, ...], None]) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.

  • where (Optional[Any]) – Elements to include in the softmax.

  • initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.

Return type:

Any

neurai.nn.layer.activate.Softmax2D#

alias of Softmax

class neurai.nn.layer.activate.Softmin(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for softmin activation function. This class inherits from Acitivate and uses softmin as the default activation function.
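
Softmin (standard definition, assumed here) rescales elements by applying softmax to the negated input:

\[\mathrm{softmin}(x) = \mathrm{softmax}(-x)\]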

Examples

from neurai.nn import Softmin
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softmin()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is softmin.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.activate.Softplus(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for softplus activation function. This class inherits from Acitivate and uses jax.nn.softplus as the default activation function.

Examples

from neurai.nn import Softplus
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softplus()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.softplus.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
Parameters:

x (Any) – input array

Return type:

Any

class neurai.nn.layer.activate.Tanh(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Acitivate

A class for tanh activation function. This class inherits from Acitivate and uses jax.nn.tanh as the default activation function.

Examples

from neurai.nn import Tanh
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Tanh()(x)
Parameters:
  • activate_fun (Callable) – The activation function to use. Default is jax.nn.tanh.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

activate_fun()#

Compute hyperbolic tangent element-wise.

LAX-backend implementation of numpy.tanh().

Original docstring below.

Equivalent to np.sinh(x)/np.cosh(x) or -1j * np.tan(1j*x).

Parameters:

x (array_like) – Input array.

Returns:

y (ndarray) – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar.

class neurai.nn.layer.attention.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=VarianceScalingIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Class for an embedding layer in a neural network.

Examples

from neurai.nn import Embed
from jax import numpy as jnp

input = jnp.asarray([9, 6, 5, 7, 8, 8, 9, 2, 8])
embed = Embed(num_embeddings=10, features=12)
param = embed.init(input)
out = embed.run(param, input)
Parameters:
  • num_embeddings (int) – The size of the vocabulary or number of distinct items to embed.

  • features (int) – The dimensionality of the embedding.

  • dtype (Optional[Any], Optional) – The dtype of the input, by default None.

  • param_dtype (Any, Optional) – The dtype of the embedding parameters, by default jnp.float32.

  • embedding_init (Callable[..., jnp.ndarray], Optional) – The initializer for the embedding parameters, by default default_embed_init.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

attend(query)#

Attend over the embedding using a query array.

Parameters:

query (jnp.ndarray) – Array with the last dimension equal to the feature depth features of the embedding.

Return type:

Array

Returns:

jnp.ndarray – An array with the final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.

Raises:

TypeError – If the input query and the embedding have different data types.

neurai.nn.layer.attention.canonicalize_dtype(*args, dtype=None, inexact=True)#

Canonicalize an optional dtype to the definitive dtype.

If the dtype is None, this function will infer the dtype from the input arguments using jnp.result_type(). If dtype is not None, it will be returned unmodified, or an exception will be raised if the dtype is invalid.

Parameters:
  • *args (JAX array compatible values) – Input values to infer the dtype from. None values are ignored.

  • dtype (Optional dtype override) – If specified, the input arguments are cast to the specified dtype instead, and dtype inference is disabled.

  • inexact (bool) – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that do not work directly on integers, such as taking a mean.

Return type:

Any

Returns:

The dtype that *args should be cast to.
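
A short usage sketch (the resulting dtype follows jnp.result_type promotion, so the exact output depends on JAX's promotion rules):

from neurai.nn.layer.attention import canonicalize_dtype
from jax import numpy as jnp

a = jnp.arange(4)                   # integer input
b = jnp.ones(4, dtype=jnp.float16)
dt = canonicalize_dtype(a, b)       # inferred common dtype, guaranteed inexact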

neurai.nn.layer.attention.masked_softmax(X, valid_lens)#

Perform softmax operation by masking elements on the last axis.

Parameters:
  • X (np.ndarray) – The input 3D tensor to be softmaxed.

  • valid_lens (np.ndarray) – The 1D or 2D tensor containing the valid length for each sequence.

Returns:

np.ndarray – The softmaxed tensor.
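
A usage sketch, assuming valid_lens counts the unmasked positions on the last axis (one length per sequence):

from neurai.nn.layer.attention import masked_softmax
from jax import numpy as jnp

X = jnp.ones((2, 2, 4))              # (batch, queries, keys)
valid_lens = jnp.array([2, 3])       # keys to keep per batch element
out = masked_softmax(X, valid_lens)  # masked keys receive ~zero weight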

neurai.nn.layer.attention.promote_dtype(*args, dtype=None, inexact=True)#

Promotes input arguments to a specified or inferred dtype.

All args are cast to the same dtype. See canonicalize_dtype for how this dtype is determined.

promote_dtype is mostly a convenience wrapper around neurai.numpy.promote_types; the differences are that it automatically casts all inputs to the inferred dtype, allows inference to be overridden by a forced dtype, and optionally checks that the resulting dtype is inexact.

Parameters:
  • *args (Tuple) – JAX array compatible values. None values are returned as is.

  • dtype (Optional dtype override.) – If specified the arguments are cast to the specified dtype instead and dtype inference is disabled.

  • inexact (bool) – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that do not work directly on integers, such as taking a mean.

Return type:

List[Array]

Returns:

List ([jnp.ndarray]) – The arguments cast to arrays of the same dtype.
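
A usage sketch; both inputs come back cast to the common inexact dtype:

from neurai.nn.layer.attention import promote_dtype
from jax import numpy as jnp

a = jnp.arange(3)                    # integer input
b = jnp.ones(3, dtype=jnp.float32)
a2, b2 = promote_dtype(a, b)         # both cast to the same inexact dtype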

neurai.nn.layer.attention.sequence_mask(X, valid_len, value=0)#

Create a mask tensor for a given sequence length.

Parameters:
  • X (np.ndarray) – The input tensor to be masked.

  • valid_len (np.ndarray) – The 1D or 2D tensor containing the valid length for each sequence.

  • value (int, Optional) – The value to fill the masked elements with. Default is 0.

Returns:

np.ndarray – The masked tensor.
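
A usage sketch, assuming positions at or beyond each row's valid length are replaced by value:

from neurai.nn.layer.attention import sequence_mask
from jax import numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
out = sequence_mask(X, jnp.array([1, 2]), value=0)
# under this assumption: [[1, 0, 0], [4, 5, 0]]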

neurai.nn.layer.attention.transpose_output(X, num_heads)#

Transpose the output tensor after parallel computation of multiple attention heads.

Parameters:
  • X (np.ndarray) – The input tensor to be transposed.

  • num_heads (int) – The number of attention heads.

Returns:

np.ndarray – The transposed tensor.

neurai.nn.layer.attention.transpose_qkv(X, num_heads)#

Transposition for parallel computation of multiple attention heads.

Parameters:
  • X (np.ndarray) – The input tensor to be transposed.

  • num_heads (int) – The number of attention heads.

Returns:

np.ndarray – The transposed tensor.
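
A round-trip sketch of transpose_qkv together with transpose_output (documented above), assuming the usual multi-head layout where the hidden dimension is split evenly across heads:

from neurai.nn.layer.attention import transpose_qkv, transpose_output
from jax import numpy as jnp

X = jnp.ones((2, 5, 8))                # (batch, seq, num_hiddens)
Xh = transpose_qkv(X, num_heads=4)     # per-head layout, head_dim = 8 // 4
Y = transpose_output(Xh, num_heads=4)  # back to (2, 5, 8)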

class neurai.nn.layer.conv.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Applies a convolution to the inputs.

Parameters:
  • features (int) – Number of convolution filters (output channels).

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

check_kernel_dimensions(kernel, size)#

Check if the kernel dimensions match the expected size.

Parameters:
  • kernel (tuple, list) – The kernel dimensions.

  • size (int) – The expected size of the kernel.

Raises:

ValueError – If kernel does not have the expected size.
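
A short sketch of the check as seen from a subclass; per the Raises notes below, Conv2d expects exactly two kernel dimensions:

from neurai.nn import Conv2d

conv = Conv2d(features=3, kernel_size=(3, 3))   # passes the dimension check
# Conv2d(features=3, kernel_size=(3, 3, 3))     # would raise ValueError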

class neurai.nn.layer.conv.Conv1d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv

1-dimensional convolution layer.

Examples

from neurai.nn import Conv1d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 5, 1))
conv = Conv1d(features=4, kernel_size=(3,))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – Number of convolution filters (output channels).

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Raises:

ValueError – If kernel_size is not a single integer or a sequence containing exactly one integer.

class neurai.nn.layer.conv.Conv2d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv

Convolutional layer for 2D inputs.

Examples

from neurai.nn import Conv2d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 1))
conv = Conv2d(3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – Number of convolution filters (output channels).

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Raises:

ValueError – If kernel_size is not an integer or sequence of 2 integers.

class neurai.nn.layer.conv.Conv3d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv

3D convolution layer.

Examples

from neurai.nn import Conv3d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = Conv3d(3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – Number of convolution filters (output channels).

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Raises:

ValueError – If kernel_size is not an integer or sequence of 3 integers.

class neurai.nn.layer.conv.ConvTranspose(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Applies a transpose convolution to the inputs.

Parameters:
  • features (int) – The number of output channels (i.e., the number of filters/kernels).

  • kernel_size (Sequence[int]) – The dimensions of the convolutional kernel (e.g., [3, 3] for a 2D kernel).

  • strides (Union[None, int, Sequence[int]], Optional) – The stride of the convolution. Controls the step size at which the kernel is applied to the input. Default is 1.

  • padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.

  • out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.

  • use_bias (bool, Optional) – Whether to include a bias term in the convolution. Default is True.

  • rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.

  • kernel_init (Callable, Optional) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution initializer.

  • bias_init (Callable, Optional) – The initializer function for the bias term (if use_bias=True). Default is a normal distribution initializer.

  • mask (Optional[jnp.ndarray], Optional) – A mask for the kernel. If provided, it will be element-wise multiplied with the kernel matrix. Default is None, meaning no masking is applied.

  • precision (Optional[lax.Precision], Optional) – The numerical precision of the computation, used for high-precision calculations. Default is None, which uses the default precision.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

__call__(self, input: jnp.ndarray) → jnp.ndarray#

Performs the transpose convolution operation on the input.

Raises:
  • TypeError – If kernel_size is an integer instead of a tuple/list of integers.

  • ValueError – If the shape of the mask does not match the shape of the weights (kernel).
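
ConvTranspose is normally used through its dimension-specific subclasses; a minimal sketch mirroring the ConvTranspose2d example further below:

from neurai.nn import ConvTranspose2d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 1))
conv = ConvTranspose2d(features=3, kernel_size=(3, 3), strides=2)  # strides=2: assumed to upsample the spatial dims
param = conv.init(input)
out = conv.run(param, input)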

class neurai.nn.layer.conv.ConvTranspose1d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: ConvTranspose

One-dimensional transposed convolution layer.

Examples

from neurai.nn import ConvTranspose1d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 5, 4))
conv = ConvTranspose1d(features=10, kernel_size=(1,))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – The number of output channels.

  • kernel_size (Union[int, Tuple[int]]) – The dimensions of the convolutional kernel.

  • strides (Union[None, int, Tuple[int]]) – The stride of the convolution. Default is 1.

  • padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.

  • out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.

  • use_bias (bool) – Whether to include a bias term. Default is True.

  • rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.

  • kernel_init (Callable) – The initializer function for the convolutional kernel. Default is a normal distribution.

  • bias_init (Callable) – The initializer function for the bias term. Default is a normal distribution.

  • mask (Optional[jnp.ndarray]) – A mask for the kernel. Default is None.

  • precision (Optional[jax.lax.Precision]) – The numerical precision of the computation. Default is None.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

Raises:

ValueError – If kernel_size is not an integer or a tuple of integers.

class neurai.nn.layer.conv.ConvTranspose2d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: ConvTranspose

2D convolution transpose layer.

Examples

from neurai.nn import ConvTranspose2d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 1))
conv = ConvTranspose2d(features=3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – The number of output channels.

  • kernel_size (Union[int, Tuple[int]]) – The dimensions of the convolutional kernel. If an int is provided, the kernel will have equal dimensions.

  • strides (Union[None, int, Tuple[int]]) – The stride of the convolution. Default is 1.

  • padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.

  • out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.

  • use_bias (bool) – Whether to include a bias term. Default is True.

  • rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.

  • kernel_init (Callable) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution.

  • bias_init (Callable) – The initializer function for the bias term. Default is a normal distribution.

  • mask (Optional[jnp.ndarray]) – A mask for the kernel. Default is None, meaning no masking is applied.

  • precision (Optional[jax.lax.Precision]) – The numerical precision of the computation. Default is None, which uses the default precision.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

setup()#

Check and set up the configuration for the convolution transpose layer.

Raises:

ValueError – If kernel_size is not an integer or a tuple of integers.

class neurai.nn.layer.conv.ConvTranspose3d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: ConvTranspose

Convolutional transpose layer for 3D inputs.

Examples

from neurai.nn import ConvTranspose3d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = ConvTranspose3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – The number of output channels (i.e., the number of filters/kernels).

  • kernel_size (Union[int, Tuple[int], List[int]]) – The dimensions of the convolutional kernel (e.g., [3, 3, 3] for a 3D kernel). If a single integer is provided, it will be replicated for all three dimensions.

  • strides (Union[None, int, Tuple[int], List[int]]) – The stride of the convolution. Controls the step size at which the kernel is applied to the input. If a single integer is provided, it will be replicated for all three dimensions. Default is 1.

  • padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.

  • output_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.

  • use_bias (bool) – Whether to include a bias term in the convolution. Default is True.

  • rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.

  • kernel_init (Callable) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution initializer.

  • bias_init (Callable) – The initializer function for the bias term (if use_bias=True). Default is a normal distribution initializer.

  • mask (Optional[jnp.ndarray]) – A mask for the kernel. If provided, it will be element-wise multiplied with the kernel matrix. Default is None, meaning no masking is applied.

  • precision (Optional[lax.Precision]) – The numerical precision of the computation, used for high-precision calculations. Default is None, which uses the default precision.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

setup(self):

Check and set up the configuration for the convolution transpose layer.

Raises:

ValueError – If kernel_size is not an integer or a tuple of integers.

class neurai.nn.layer.conv.SNNConv2d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv2d

2D convolutional layer for SNN (Spiking Neural Network) inputs. SNNConv2d accepts 4-dimensional or 5-dimensional inputs. When the input is 5-dimensional, the batch and depth dimensions are merged and a 2-dimensional convolution is applied; the calculation for 4-dimensional inputs is identical to an ANN.

Examples

from neurai.nn import SNNConv2d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConv2d(features=3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – Number of output channels.

  • weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.

  • kernel_init (Callable, Optional) – Initialization function for the kernel parameters. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • use_bias (bool, Optional) – Whether to include bias in the layer. Default is False.

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.conv.SNNConv3d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv3d

3D convolutional layer for SNN (Spiking Neural Network) inputs.

Examples

from neurai.nn import SNNConv3d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConv3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
Parameters:
  • features (int) – Number of output channels.

  • weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.

  • kernel_init (Callable, Optional) – Initialization function for the kernel parameters. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.

  • use_bias (bool, Optional) – Whether to include bias in the layer. Default is False.

  • kernel_size (Sequence[int]) – The shape of the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.conv.SNNConvTranspose3d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=False, rhs_dilation=1, kernel_init=KaimingUniformIniter(key=None), bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: ConvTranspose3d

3D convolutional transpose layer for SNN (Spiking Neural Network) inputs.

Examples

from neurai.nn import SNNConvTranspose3d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConvTranspose3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
neurai.nn.layer.conv.canonicalize_padding(padding, rank)#

Canonicalizes conv padding to a jax.lax supported format.

Parameters:
  • padding (Union[str, int, Sequence[Union[int, Tuple[int, int]]]]) – The padding to be canonicalized.

  • rank (int) – The rank of the conv layer.

Returns:

Union[str, Sequence[Tuple[int, int]]] – The canonicalized padding.

Raises:

ValueError – If padding has an invalid format.
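
A hedged illustration of the expected behavior (the concrete return values below are assumptions inferred from the description, not verified outputs):

from neurai.nn.layer.conv import canonicalize_padding

canonicalize_padding("SAME", 2)            # strings presumably pass through unchanged
canonicalize_padding(1, 2)                 # presumably ((1, 1), (1, 1))
canonicalize_padding([(0, 1), (2, 2)], 2)  # already canonical: ((0, 1), (2, 2))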

neurai.nn.layer.conv.maybe_replicate(x, rep)#

Replicate an integer or a sequence of integers rep times.

Parameters:
  • x (Optional[Union[int, Sequence[int]]]) – The integer or sequence of integers to be replicated.

  • rep (int) – The number of times to replicate x.

Returns:

Tuple[int] – The replicated integer or sequence of integers.
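
A hedged illustration (the outputs shown are assumptions based on the description above):

from neurai.nn.layer.conv import maybe_replicate

maybe_replicate(3, 2)       # presumably (3, 3)
maybe_replicate((1, 2), 2)  # a sequence is presumably returned as-is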

class neurai.nn.layer.dropout.Dropout(rate, deterministic=None, rng_col='dropout', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

A layer that stochastically ignores a subset of inputs each training step.

In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).
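
For example, with rate=0.5 each surviving value is multiplied by 1 / (1 - 0.5) = 2, so the expected magnitude of the activations is unchanged.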

This layer is active only during training; in other circumstances it is a no-op.

Examples

from neurai.nn import Dropout
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
dropout = Dropout(rate=0.5, deterministic=False)
param = dropout.init(input, rngs={"dropout": 1})
out = dropout.run(param, input, rngs={"dropout": 1})
Parameters:
  • rate (float) – The dropout probability.

  • deterministic (Optional[bool], Optional) – If False, the inputs are scaled by 1 / (1 - rate) and masked. If True, no mask is applied, and the inputs are returned as is. Default is None.

  • rng_col (str, Optional) – The rng collection name to use when requesting an rng key. Default is ‘dropout’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Returns:

The masked inputs.

class neurai.nn.layer.dropout.SNNDropout3d(rate, deterministic=None, rng_col='dropout', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Dropout

3D dropout layer for SNN (Spiking Neural Network) inputs. Dropout is applied to the input with the mask preserved over the time dimension, meaning that if a neuron is dropped, it remains excluded for the entire duration.

Examples

from neurai.nn import SNNDropout3d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
dropout = SNNDropout3d(rate=0.5, deterministic=False)
param = dropout.init(input, rngs={"dropout": 1})
out = dropout.run(param, input, rngs={"dropout": 1})
Parameters:
  • rate (float) – The dropout probability.

  • deterministic (Optional[bool], Optional) – If False, the inputs are scaled by 1 / (1 - rate) and masked. If True, no mask is applied, and the inputs are returned as is. Default is None.

  • rng_col (str, Optional) – The rng collection name to use when requesting an rng key. Default is ‘dropout’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Returns:

The masked inputs.

class neurai.nn.layer.exodus.Exodus(size=1, pid=0, V_th=10.0, V_reset=0.0, stdp_tau=20.0, tau_triplet=110.0, dendritic_delay=1, area_name=None, surrogate_grad_fn=<neurai.grads.surrogate_grad.SingleExponential object>, id=0, neuron_start_id=1, record_v=False, step_mode=StepMode.MULTI, reset_type=ResetType.HARD, time_recurrent=False, before_reset_mem=None, param_dtype=<class 'jax.numpy.float32'>, V_rest=0.0, R=1.0, tau=20.0, min_v_mem=0.0, norm_input=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: SNNLIF

Exodus implementation of a spiking Leaky Integrate and Fire neuron. Does not simulate synaptic dynamics.

EXact computation Of Derivatives as Update to SLAYER. Paper work: [2205.10242v1] EXODUS: Stable and Efficient Training of Spiking Neural Networks (https://arxiv.org/abs/2205.10242v1). Code work: synsense/sinabs-exodus.

Examples

from neurai.nn import Exodus
from jax import random

exodus = Exodus(V_rest=0., V_reset=0., V_th=1.0, tau=1.0, record_v=False, min_v_mem=0)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 4))
param = exodus.init(input)
out = exodus.run(param, input)

Parameters:
  • V_rest (float, Optional) – The initial value of membrane potential. Default is 0.

  • V_reset (float, Optional) – The reset membrane potential. Default is 0.

  • V_th (float, Optional) – The firing threshold. Default is 1.0.

  • tau (Union[float, jnp.ndarray], Optional) – The membrane potential time constant.

  • record_v (bool, Optional) – Whether to record the membrane potential values. Setting it to True will output both spike and membrane potential values, while setting it to False will only output spike values. Default is False.

  • surrogate_grad_fn (Callable, Optional) – The function for calculating surrogate gradients of the heaviside step function in backward. Default is surrogate_grad.SingleExponential().

  • min_v_mem (float or None, Optional) – Lower bound for the membrane potential v_mem, clipped at every time step. Default is 0.0.

  • norm_input (bool, Optional) – If True, will normalize the inputs by tau. This helps when training time constants. Default is False.

  • max_num_spikes_per_bin (int, Optional) – Maximum number of spikes that a neuron can emit per time step.

  • activations (int, Optional) – Activations from the previous time step. Has to be contiguous. Default is 1.

  • step_mode (neurai.const.StepMode, Optional) – SNN LIF Neuron update mode, neurai.const.single for single-step update or neurai.const.multi for multi-step update. In single-step mode, the update process includes calculations for only one time step, while in multi-step mode, it calculates the update process for T time steps. Default is neurai.const.multi.

  • R (float) – Membrane resistance, by default 1.

  • reset_type (Union[None, neurai.const.ResetType]) – How to reset the membrane potential after the spike, with values of neurai.const.hard or neurai.const.soft. The default is neurai.const.hard.

  • time_recurrent (bool) – Selects the calculation mode of the LIF model when used in an SRNN structure.

  • before_reset_mem (bool) – Whether the membrane potential is clamped below the threshold after a neuron fires spikes.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • size (int, Optional) – The number of neurons in the population, by default 1.

  • stdp_tau (float, Optional) – Time constant of the presynaptic STDP trace, by default 20.

  • tau_triplet (float, Optional) – Time constant of long presynaptic trace, by default 110.

  • dendritic_delay (Union[int, jax.numpy.ndarray, Initializer, Callable], Optional) – The dendritic delay length, by default 1

  • pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.

  • area_name (str, Optional) – The name of the area to which the current neuron belongs, by default None.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

  • id (int) –

  • neuron_start_id (int) –

Raises:
  • ValueError – If the provided input shape is empty. If V_th is not greater than min_v_mem (if min_v_mem is specified).

  • Exception – If in_features or out_features have more than 3 dimensions (for the dense method).

get_alpha(template, tau)#

Calculates alpha from tau.

Parameters:
  • template (jnp.ndarray) – Specifies the shape of the returned alpha.

  • tau (float) – Membrane potential time constant.

Returns:

jnp.ndarray

init_neuron_state(input_shape)#

Initialize mem and spike for a single time step. theta is a scalar of data type float.

Parameters:

input_shape (jnp.ndarray) –

neurai.nn.layer.exodus.exodus_cpu_bwd(res, posterior_grad)#

Backward pass for the Exodus CPU forward computation. The number of posterior gradients matches the number of return values of the forward pass; if the forward pass returns more than one value, posterior_grad is a tuple.

Parameters:
  • res (Tuple) – Output of the Exodus CPU forward computation.

  • posterior_grad (Tuple) – posterior_grad[0] (jnp.ndarray): the gradient of the loss with respect to the output spike tensor of shape (neurons, Timestep); posterior_grad[1] (jnp.ndarray): the gradient of the loss with respect to the output membrane potential tensor of shape (neurons, Timestep).

Returns:

Tuple – The gradients of the loss with respect to the inputs of the Exodus CPU forward computation:

  • input_grad (jnp.ndarray) – The gradient of the loss with respect to the input spike tensor of shape (neurons, Timestep).

  • init_grad (jnp.ndarray) – The gradient of the loss with respect to the input membrane potential tensor of shape (neurons, Timestep).

  • alpha_grad (jnp.ndarray) – The gradient of the loss with respect to the input alpha tensor of shape (neurons,).

  • None * 4 – Placeholders to maintain consistency with the number of inputs in the forward computation.

neurai.nn.layer.exodus.exodus_cpu_fwd(input_spike, mem_initial, alpha, membrane_subtract, threshold, theta_low, spikes_num)#

Forward implementation of Exodus CPU function.

Parameters:
  • input_spike (jnp.ndarray) – Input spike tensor of shape (neurons, Timestep) representing the spikes received by each neuron at each timestep.

  • mem_initial (jnp.ndarray) – Initial membrane potential tensor of shape (neurons, Timestep) representing the starting membrane potential of each neuron.

  • alpha (jnp.ndarray) – Alpha tensor of shape (neurons, Timestep) representing the decay factor for each neuron’s membrane potential.

  • membrane_subtract (jnp.ndarray) – Tensor of shape (neurons, Timestep) representing the value to be subtracted from the membrane potential due to spikes.

  • threshold (float) – The threshold value for neuron firing.

  • theta_low (float) – Lower threshold for the membrane potential.

  • spikes_num (int) – The maximum number of spikes a neuron can emit per time step.

Returns:

  • Tuple

  • output_spike (jnp.ndarray) – The output spike tensor of shape (neurons, Timestep) representing the spikes emitted by each neuron at each timestep.

  • output_mem (jnp.ndarray) – The post-synaptic potential tensor of shape (neurons, Timestep) representing the membrane potential of each neuron after spike processing.

  • input_tuple (Tuple) – A tuple containing the input values needed for the backward pass during gradient computation.
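
A hedged call sketch; the argument shapes follow the descriptions above, and all values are illustrative only:

import jax.numpy as jnp
from neurai.nn.layer.exodus import exodus_cpu_fwd

neurons, timesteps = 4, 10
input_spike = jnp.zeros((neurons, timesteps))       # incoming spikes per neuron and step
mem_initial = jnp.zeros((neurons, timesteps))       # starting membrane potentials
alpha = jnp.full((neurons, timesteps), 0.9)         # per-step membrane decay factors
membrane_subtract = jnp.ones((neurons, timesteps))  # amount subtracted per emitted spike
output_spike, output_mem, input_tuple = exodus_cpu_fwd(
  input_spike, mem_initial, alpha, membrane_subtract, 1.0, 0.0, 1)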

class neurai.nn.layer.linear.Flatten(start_dim=1, end_dim=-1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Initializes a Flatten module that flattens a jnp.ndarray along the specified dimensions.

Examples

from neurai.nn import Flatten
from jax import random

flatten = Flatten()
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
param = flatten.init(input)
out = flatten.run(param, input)
Parameters:
  • start_dim (int, Optional) – The starting dimension to flatten, by default 1.

  • end_dim (int, Optional) – The ending dimension to flatten, by default -1. If -1, all dimensions starting from start_dim are flattened.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.linear.Linear(features, bias=True, param_dtype=<class 'jax.numpy.float32'>, w_initializer=KaimingUniformIniter(key=None), b_initializer=UniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

A linear transformation Layer.

This Layer applies a linear transformation to the input data, optionally adding a bias vector.

Examples

from neurai.nn import Linear
from jax import random

linear = Linear(10)
input = random.normal(random.PRNGKey(0), (1, 784))
param = linear.init(input)
out = linear.run(param, input)
Parameters:
  • features (int) – The number of output features.

  • bias (bool, Optional) – Whether to use a bias vector. Default is True.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers (default: float32).

  • w_initializer (Callable, Optional) – The initializer for the weight matrix. Default is KaimingUniformIniter().

  • b_initializer (Optional[Callable], Optional) – The initializer for the bias vector. Default is UniformIniter(-0.08, 0.08).

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.linear.SNNLinear3d(features, kernel_size=None, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv3d

3D linear layer for SNN (Spiking Neural Network) inputs.

Examples

from neurai.nn import SNNLinear3d
from jax import random

snn_linear = SNNLinear3d(features=10)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 5))
param = snn_linear.init(input)
out = snn_linear.run(param, input)
class neurai.nn.layer.linear.Sequential(layers, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

A Layer that applies a sequence of other Layers.

This Layer takes a variable number of other Layers as arguments, and applies them one by one to the input data.

Examples

from neurai.nn import Sequential, Linear
from jax import random

model = Sequential([
  Linear(1024),
  Linear(512),
  Linear(10)])
input = random.normal(random.PRNGKey(0), (1, 784))
param = model.init(input)
out = model.run(param, input)
print(out.shape)
Parameters:
  • layers (Sequence[Callable[..., Any]]) – The Layers to run sequentially.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.BatchNorm(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Batch Normalization layer.

Parameters:
  • use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • axis (int) – The feature (non-batch) axis of the input.

  • momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True

  • param_dtype (Any) – The dtype passed to parameter initializers (default: float32).

  • axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.

  • axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.

  • scale_init (Callable) – An initializer generating the original scaling matrix.

  • bias_init (Callable) – An initializer generating the original translation matrix.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
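
A minimal usage sketch, assuming BatchNorm is importable from neurai.nn and follows the same init/run API as the 1D/2D/3D subclasses below:

from neurai.nn import BatchNorm
from jax import random

bn = BatchNorm(use_running_average=False)
input = random.normal(random.PRNGKey(0), (10, 20))
param = bn.init(input)
out = bn.run(param, input)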

class neurai.nn.layer.normalization.BatchNorm1d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: BatchNorm

1D Batch Normalization layer.

Examples

from neurai.nn import BatchNorm1d
from jax import random

bn1 = BatchNorm1d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (10, 20))
param = bn1.init(input)
out = bn1.run(param, input)
Parameters:
  • use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • axis (int) – The feature (non-batch) axis of the input.

  • momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True

  • param_dtype (Any) – The dtype passed to parameter initializers (default: float32).

  • axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.

  • axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.

  • scale_init (Callable) – An initializer generating the original scaling matrix.

  • bias_init (Callable) – An initializer generating the original translation matrix.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.BatchNorm2d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: BatchNorm

2D Batch Normalization layer.

Examples

from neurai.nn import BatchNorm2d
from jax import random

bn2 = BatchNorm2d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
param = bn2.init(input)
out = bn2.run(param, input)
Parameters:
  • use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • axis (int) – The feature (non-batch) axis of the input.

  • momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True

  • param_dtype (Any) – The dtype passed to parameter initializers (default: float32).

  • axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.

  • axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.

  • scale_init (Callable) – An initializer generating the original scaling matrix.

  • bias_init (Callable) – An initializer generating the original translation matrix.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.BatchNorm3d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: BatchNorm

3D Batch Normalization layer.

Examples

from neurai.nn import BatchNorm3d
from jax import random

bn3 = BatchNorm3d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 16))
param = bn3.init(input)
out = bn3.run(param, input)
Parameters:
  • use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • axis (int) – The feature (non-batch) axis of the input.

  • momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True

  • param_dtype (Any) – The dtype passed to parameter initializers (default: float32).

  • axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.

  • axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.

  • scale_init (Callable) – An initializer generating the original scaling matrix.

  • bias_init (Callable) – An initializer generating the original translation matrix.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.LayerNorm(epsilon=1e-06, param_dtype=<class 'jax.numpy.float32'>, affine=True, bias_init=<function LayerNorm.<lambda>>, scale_init=<function LayerNorm.<lambda>>, reduction_axes=-1, feature_axes=-1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.

Examples

from neurai.nn import LayerNorm
from jax import random

layer_norm = LayerNorm()
input = random.normal(random.PRNGKey(0), (10, 28, 28, 1))
param = layer_norm.init(input)
out = layer_norm.run(param, input)
Parameters:
  • epsilon (float, Optional) – A small float added to variance to avoid dividing by zero. Default: 1e-6.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers (default: float32).

  • affine (bool, Optional) – A boolean value that when set to True, this Layer has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default: True.

  • bias_init (Callable, Optional) – Initializer for bias. Default: lambda _, shape, dtype: jnp.zeros(shape, dtype).

  • scale_init (Callable, Optional) – Initializer for scale. Default: lambda _, shape, dtype: jnp.ones(shape, dtype).

  • reduction_axes (Union[int, Any], Optional) – Axes for computing normalization statistics. Default: -1.

  • feature_axes (Union[int, Any], Optional) – Feature axes for learned bias and scaling. Default: -1.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.TdBatchNorm(use_running_average=None, axis=-2, epsilon=1e-05, momentum=0.1, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, alpha=1, v_th=0.2, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: BatchNorm2d

Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280. In short, batch statistics are additionally averaged over the time dimension when performing BN.

Examples

from neurai.nn import TdBatchNorm
from jax import random

tdbn = TdBatchNorm(use_running_average=False)
input = random.normal(random.PRNGKey(0), (10, 28, 28, 3, 3))
param = tdbn.init(input)
out = tdbn.run(param, input)
Parameters:
  • use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • axis (int) – The feature (non-batch) axis of the input.

  • momentum (float) – The value used for the running_mean and running_var computation. Default: 0.1

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True

  • param_dtype (Any) – The dtype passed to parameter initializers (default: float32).

  • v_th (float) – Membrane potential threshold, by default 0.2.

  • alpha (float) – An additional parameter which may change in a resblock.

  • axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.

  • axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.

  • scale_init (Callable) – An initializer generating the original scaling matrix.

  • bias_init (Callable) – An initializer generating the original translation matrix.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.normalization.TdLayer(layer, bn=None, use_running_average=None, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Converts a common layer to the time domain. The input tensor needs an additional time dimension, which in this case is the last dimension of the data. During the forward pass, the wrapped layer is applied to the data at each time step along that dimension. Link to related paper: https://arxiv.org/pdf/2011.05280.

Examples

from neurai.nn import Conv2d, TdBatchNorm, TdLayer
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 10))
model = TdLayer(Conv2d(3, kernel_size=(3, 3)),
                TdBatchNorm(), use_running_average=False)
param = model.init(input)
out = model.run(param, input)
Parameters:
  • layer (Module) – The layer needs to be converted.

  • bn (Module) – If Batch Normalization (BN) is needed, pass the BN layer in as this parameter.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

  • use_running_average (Optional[bool]) – Passed through to the bn layer; if True, the stored running statistics are used instead of batch statistics.

class neurai.nn.layer.normalization.WeightNorm(module, module_name='weight', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Weight normalization ( https://arxiv.org/abs/1602.07868).

Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by module_name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v').

Examples

from neurai.nn import WeightNorm, Conv2d
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
wn = WeightNorm(Conv2d(3, kernel_size=(3, 3)))
param = wn.init(input)
out = wn.run(param, input)
Parameters:
  • module (Module) – The containing module.

  • module_name (str, Optional) – The name of the weight parameter, by default ‘weight’.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

static norm_except_dim(v, pow)#

Calculate the L-norm of v.

Parameters:
  • v (jnp.ndarray) – The matrix for which the norm needs to be calculated.

  • pow (int) – The order of the norm.

Return type:

Array

Returns:

jnp.ndarray – The result of the norm.

static weight_norm(g, v)#

Computing weight param by weight_g and weight_v.

Parameters:
  • g (jnp.ndarray) – The weight_g.

  • v (jnp.ndarray) – The weight_v.

Return type:

Array

Returns:

jnp.ndarray – The calculated weight param.
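
For intuition, the paper's reparameterization recovers the weight as w = g * v / ||v||. A minimal sketch of that formula (an illustration, not the library's internal implementation):

import jax.numpy as jnp

def weight_norm_sketch(g, v):
  # w = g * v / ||v||, the reparameterization from https://arxiv.org/abs/1602.07868
  return g * v / jnp.linalg.norm(v)

weight_norm_sketch(jnp.asarray(2.0), jnp.asarray([3.0, 4.0]))  # [1.2, 1.6]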

class neurai.nn.layer.padding.ConstantPad(padding, constant_values=0.0, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Constant padding layer for neural networks.

Examples

from neurai.nn import ConstantPad
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ConstantPad(padding=1, constant_values=0.5) # pad = ConstantPad(padding=((2, 2), (2, 2)), constant_values=0.3)
param = pad.init(input)
out = pad.run(param, input)
Parameters:
  • padding (Union[int, Sequence[Tuple[int, int]]]) – Padding applied to the input. If an int, the same padding is used on all boundaries. Otherwise, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.

  • constant_values (float) – The constant value of the input tensor boundaries.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
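
For 2d data, ConstantPad presumably behaves like jnp.pad with mode='constant' applied to the spatial axes; a hedged equivalence sketch (the exact axis handling is an assumption):

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
# pad only the two spatial axes, leaving batch and channel untouched
y = jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), mode="constant", constant_values=0.3)
print(y.shape)  # (1, 32, 32, 3)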

class neurai.nn.layer.padding.ReflectionPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Reflection padding layer for neural networks.

Examples

from neurai.nn import ReflectionPad
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ReflectionPad(padding=2) # pad = ReflectionPad(padding=((2, 2), (2, 2)))
param = pad.init(input)
out = pad.run(param, input)
Parameters:
  • padding (Union[int, Sequence[Tuple[int, int]]]) – Padding applied to the input. If an int, the same padding is used on all boundaries. Otherwise, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.padding.ReplicationPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Replication padding layer for neural networks.

Examples

from neurai.nn import ReplicationPad
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ReplicationPad(padding=2) # pad = ReplicationPad(padding=((2, 2), (2, 2)))
param = pad.init(input)
out = pad.run(param, input)
Parameters:
  • padding (Union[int, Sequence[Tuple[int, int]]]) – Padding applied to the input. If an int, the same padding is used on all boundaries. Otherwise, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.padding.ZeroPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Zero padding layer for neural networks.

Examples

from neurai.nn import ZeroPad
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ZeroPad(padding=1) # pad = ZeroPad(padding=((1, 1), (1, 1)))
param = pad.init(input)
out = pad.run(param, input)
print(out.shape)
Parameters:
  • padding (Union[int, Sequence[Tuple[int, int]]]) – Padding applied to the input. If an int, the same padding is used on all boundaries. Otherwise, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.pooling.AvgPool(kernel_size, strides=None, padding='VALID', count_include_pad=True, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Pool

Pools the input by taking the average over a window.

Examples

from neurai.nn import AvgPool
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = AvgPool(kernel_size=(2, 2))
param = pool.init(input)
out = pool.run(param, input)
Parameters:
  • init_val (scalar) – Initial value for the reduction.

  • func (callable) – Reduction function to use.

  • kernel_size (Union[int, Sequence[int]]) – Window size for the pooling operation.

  • strides (Union[None, int, Sequence[int]]) – Strides for the pooling operation.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Padding for the pooling operation.

  • count_include_pad (bool) – Whether to include padded tokens in the average calculation (default: True).

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

func(y)#

Elementwise addition: \(x + y\).

Return type:

Array

class neurai.nn.layer.pooling.MaxPool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Pool

Pools the input by taking the maximum over a window.

Examples

from neurai.nn import MaxPool
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = MaxPool(kernel_size=(2, 1))
param = pool.init(input)
out = pool.run(param, input)
Parameters:
  • kernel_size (int, sequence of int) – An integer, or a sequence of integers defining the window to reduce over.

  • strides (int, sequence of int) – An integer, or a sequence of integers, representing the inter-window strides (default: (1, …, 1)).

  • padding (str, sequence of tuple) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

func(y)#

Elementwise maximum: \(\mathrm{max}(x, y)\)

For complex numbers, uses a lexicographic comparison on the (real, imaginary) pairs.

Return type:

Array

class neurai.nn.layer.pooling.MinPool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Pool

Pools the input by taking the minimum over a window.

Examples

from neurai.nn import MinPool
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = MinPool(kernel_size=(1, 2))
param = pool.init(input)
out = pool.run(param, input)
Parameters:
  • kernel_size (int, sequence of int) – An integer, or a sequence of integers defining the window to reduce over.

  • strides (int, sequence of int) – An integer, or a sequence of integers, representing the inter-window strides (default: (1, …, 1)).

  • padding (str, sequence of tuple) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

func(y)#

Elementwise minimum: \(\mathrm{min}(x, y)\)

For complex numbers, uses a lexicographic comparison on the (real, imaginary) pairs.

Return type:

Array

class neurai.nn.layer.pooling.Pool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Pooling layer for neural networks.

Parameters:
  • init_val (scalar) – Initial value for the reduction.

  • func (callable) – Reduction function to use.

  • kernel_size (Union[int, Sequence[int]]) – Window size for the pooling operation.

  • strides (Union[None, int, Sequence[int]]) – Strides for the pooling operation.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Padding for the pooling operation.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
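
The subclasses below (AvgPool, MaxPool, MinPool) supply init_val and func. Conceptually this maps onto jax.lax.reduce_window; a hedged sketch of the max-pooling case (an illustration, not the library's internal code):

import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # NHWC input
out = jax.lax.reduce_window(
  x, -jnp.inf, jax.lax.max,               # init_val and reduction func
  window_dimensions=(1, 2, 2, 1),
  window_strides=(1, 2, 2, 1),
  padding="VALID")
print(out.shape)                          # (1, 2, 2, 1)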

class neurai.nn.layer.pooling.SNNPool(features=1, kernel_size=(2, 2, 1), strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function SNNPool.<lambda>>, bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Conv3d

SNN pooling layer.

Examples

from neurai.nn import SNNPool
from jax import random

input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 6))
pool = SNNPool(kernel_size=(2, 2, 1))
param = pool.init(input)
out = pool.run(param, input)
Parameters:
  • features (int) – Number of output channels.

  • kernel_size (int, sequence of int) – The shape of the convolutional kernel.

  • kernel_col (str) – Specify whether the weight can be learned. The default is const, indicating that it is not learnable.

  • kernel_init (Callable, Optional) – The initializer for the convolutional kernel.

  • strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.

  • padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.

  • input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.

  • feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.

  • use_bias (bool, Optional) – Whether to add a bias term to the output. Default is False.

  • mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.

  • param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.

  • precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.

  • bias_init (Callable, Optional) – The initializer for the bias. Default is zero_func.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

bias_init(shape, dtype=<class 'int'>, func=<function <lambda>>)#

Initializes the VarianceScalingIniter.

Parameters:
  • scale (float) – the value to scale the variance.

  • mode (str) – indicates how to calculate the variance scaling factor.

  • distribution (str) – indicates the type of distribution to use.

  • in_axis (int) – indicates the input axis for computing the variance scaling factor, by default -2.

  • out_axis (int) – indicates the output axis for computing the variance scaling factor, by default -1.

kernel_init(shape, dtype)#

Initializes the VarianceScalingIniter.

Parameters:
  • scale (float) – the value to scale the variance.

  • mode (str) – indicates how to calculate the variance scaling factor.

  • distribution (str) – indicates the type of distribution to use.

  • in_axis (int) – indicates the input axis for computing the variance scaling factor, by default -2.

  • out_axis (int) – indicates the output axis for computing the variance scaling factor, by default -1.
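
The parameters listed above correspond to the VarianceScalingIniter constructor, and the resulting initializer is called with (shape, dtype) as documented for kernel_init. A minimal sketch of constructing one directly; the import path and the mode/distribution values here are assumptions for illustration, not the defaults used by SNNPool:

# NOTE: import path and argument values are assumptions, for illustration only
from neurai.initializer import VarianceScalingIniter
from jax import numpy as jnp

init = VarianceScalingIniter(scale=1.0, mode='fan_in', distribution='uniform', in_axis=-2, out_axis=-1)
kernel = init(shape=(3, 3, 1, 16), dtype=jnp.float32)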

class neurai.nn.layer.pooling.UpSampleNearest(size=None, scale_factor=None, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Applies nearest-neighbor upsampling to an input signal composed of several input channels.

Examples

from neurai.nn import UpSampleNearest
from jax import random

input = random.normal(random.PRNGKey(0), (1, 12, 1))
pool = UpSampleNearest(size=(1, 12, 1))
param = pool.init(input)
out = pool.run(param, input)
Parameters:
  • size (Union[int, Tuple[int, int], Tuple[int, int, int]]) – Output spatial sizes.

  • scale_factor (Union[int, Tuple[int, int], Tuple[int, int, int]]) – Multiplier for the spatial size. If given as a tuple, its length must match the number of spatial dimensions of the input; batch and channel dimensions are not included.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.recurrent.ALIFCell(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: SRNNCellBase

An SRNN cell whose neuron model is SNNALIF (adaptive leaky integrate-and-fire).

Parameters:
  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.recurrent.GRUCell(input_size, hidden_size, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: RNNCellBase

GRU cell.

Examples

from neurai.nn import RNN, GRUCell
from jax import random

input = random.normal(random.PRNGKey(0), (28, 512, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128))
param = rnn.init(input)
out = rnn.run(param, input)
Parameters:
  • input_size (int) – input size

  • hidden_size (int) – hidden layer size

  • kernel_init (Callable) – initializer function for the kernels that transform the input, by default KaimingUniformIniter()

  • recurrent_kernel_init (Callable) – initializer function for the kernels that transform the hidden state, by default KaimingUniformIniter()

  • bias_init (Callable) – initializer for the bias parameters

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

initial_state(batch_size)#

Constructs an initial state for this cell.

Parameters:

batch_size (Optional[int]) – the batch dimension used to construct the initial state
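
A minimal sketch of building the zero state for a batch before stepping the cell manually; the shape of the returned state is an assumption:

from neurai.nn import GRUCell

cell = GRUCell(input_size=28, hidden_size=128)
state = cell.initial_state(batch_size=512)  # zero state; shape (512, 128) is an assumption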

class neurai.nn.layer.recurrent.LIFCell(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: SRNNCellBase

An SRNN cell whose neuron model is SNNLIF (leaky integrate-and-fire).

Parameters:
  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

class neurai.nn.layer.recurrent.LSTMCell(input_size, hidden_size, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: RNNCellBase

LSTM cell.

Examples

from neurai.nn import RNN, LSTMCell
from jax import random

input = random.normal(random.PRNGKey(0), (28, 512, 28))
lstm = RNN(LSTMCell(input_size=28, hidden_size=128))
param = lstm.init(input)
out = lstm.run(param, input)
Parameters:
  • input_size (int) – input size

  • hidden_size (int) – hidden layer size

  • kernel_init (Callable) – initializer function for the kernels that transform the input, by default KaimingUniformIniter()

  • recurrent_kernel_init (Callable) – initializer function for the kernels that transform the hidden state, by default KaimingUniformIniter()

  • bias_init (Callable) – initializer for the bias parameters

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

initial_state(batch_size)#

Constructs an initial state for this cell.

Parameters:

batch_size (Optional[int]) – the batch dimension used to construct the initial state

class neurai.nn.layer.recurrent.RNN(cell, batch_first=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

The RNN module takes any RNNCellBase instance and applies it over a sequence.

Examples

from neurai.nn import RNN, GRUCell
from jax import random

input = random.normal(random.PRNGKey(0), (28, 512, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128))
param = rnn.init(input)
out = rnn.run(param, input)
Parameters:
  • cell (RNNCellBase) – an instance of RNNCellBase.

  • batch_first (bool) – Whether the batch is the first dimension of the input sequence. Default is False.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
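
The example above feeds time-major input with shape (time, batch, features). A minimal sketch of the batch-first variant, assuming only the input layout changes:

from neurai.nn import RNN, GRUCell
from jax import random

# batch-major input: (batch, time, features)
input = random.normal(random.PRNGKey(0), (512, 28, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128), batch_first=True)
param = rnn.init(input)
out = rnn.run(param, input)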

class neurai.nn.layer.recurrent.RNNCellBase(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

RNN cell base class. This class defines the basic functionality that every cell should implement.

Parameters:
  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

initial_state(batch_size)#

Constructs an initial state for this cell.

Parameters:

batch_size (Optional[int]) – the batch dimension used to construct the initial state

class neurai.nn.layer.recurrent.SRNN(neuron=None, dense_w=None, dense_r=None, out_dim=None, time_dim=-1, bias=False, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

Spiking Recurrent Neural Network (SRNN) module for sequential data processing.

Examples

from neurai.nn import SNNLIF, SRNN
from jax import random

input = random.normal(random.PRNGKey(0), (32, 28, 28))
snnlif = SNNLIF(tau=2.0)
srnn = SRNN(snnlif, out_dim=1024)
param = srnn.init(input)
out = srnn.run(param, input)
Parameters:
  • neuron (Callable) – The spiking neuron model used in the SRNN.

  • dense_w (Callable, Optional) – The dense weights for input transformation.

  • dense_r (Callable, Optional) – The dense weights for recurrent connections.

  • out_dim (int, Optional) – The output dimension of the SRNN.

  • time_dim (int, Optional) – The time dimension. Default is -1.

  • bias (bool, Optional) – Flag to determine if bias is used. Default is False.

  • kernel_init (Callable, Optional) – The initializer for kernel weights. Default is KaimingUniformIniter(scale=math.sqrt(5), distribution=’leaky_relu’).

  • recurrent_kernel_init (Callable, Optional) – The initializer for recurrent kernel weights. Default is KaimingUniformIniter(scale=math.sqrt(5), distribution=’leaky_relu’).

  • bias_init (Callable, Optional) – The initializer for bias. Default is zero_func.

  • param_dtype (Any, Optional) – The data type for parameters. Default is jnp.float32.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

Returns:

(spike, mem) or spike – If the neuron property record_v is True, returns both spike and mem; otherwise returns only spike.
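
A minimal sketch of the record_v behavior described above; passing record_v to the SNNLIF constructor is an assumption:

from neurai.nn import SNNLIF, SRNN
from jax import random

input = random.normal(random.PRNGKey(0), (32, 28, 28))
snnlif = SNNLIF(tau=2.0, record_v=True)  # record_v as a constructor argument is an assumption
srnn = SRNN(snnlif, out_dim=1024)
param = srnn.init(input)
spike, mem = srnn.run(param, input)  # with record_v=True, both spike and membrane potential are returned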

class neurai.nn.layer.recurrent.SRNNCellBase(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: Module

SRNN cell base class. This class defines the basic functionality that every cell should implement.

Parameters:
  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

initial_state(shape, neuron, dense_w, dense_r, param_dtype)#

Initialize the state of the SRNN cell.

Parameters:
  • shape (tuple) – Shape of the cell state.

  • neuron (Neuron) – The neuron model for the cell.

  • dense_w (Callable) – Function for the dense input weights.

  • dense_r (Callable) – Function for the dense recurrent weights.

  • param_dtype (Any) – Data type for cell parameters.

class neurai.nn.layer.slayer.Slayer(size=1, pid=0, V_th=10.0, V_reset=0.0, stdp_tau=20.0, tau_triplet=110.0, dendritic_delay=1, area_name=None, surrogate_grad_fn=<neurai.grads.surrogate_grad.SlayerPdf object>, id=0, neuron_start_id=1, V_rest=0.0, tau_res=10.0, tau_ref=1.0, time_step=1.0, time_window=100, kernel_window=38, record_v=False, computation_type=ComputationType.CUDA, scale_rho=2, tau_rho=1, scale_ref=2, full_ref_kernel=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#

Bases: SNNSRM

SLAYER (Spike Layer Error Reassignment) algorithm. The base description of the framework was published at NeurIPS 2018 (https://nips.cc/Conferences/2018/Schedule?showEvent=11157). The final paper is available at http://papers.nips.cc/paper/7415-slayer-spike-layer-error-reassignment-in-time.pdf and the arXiv preprint at https://arxiv.org/abs/1810.08646.

Examples

from neurai.config import set_platform
from neurai.nn import Slayer
from jax import random, numpy as jnp

set_platform("gpu")
slayer = Slayer(V_th=10.0, time_step=1.0, time_window=100, scale_rho=2, tau_rho=1)
input = random.randint(random.PRNGKey(0), shape=(1, 2, 34, 34, 100), minval=0, maxval=2).astype(jnp.float32)
param = slayer.init(input)
out = slayer.run(param, input)
Parameters:
  • V_rest (float) – the resting membrane potential, by default 0.

  • V_th (int, float) – neuron threshold, by default 10.0

  • tau_res (float) – membrane time constant, by default 10.

  • tau_ref (float) – neuron refractory time constant, by default 1.

  • time_step (float) – sampling time, by default 1.

  • time_window (int) – time length of sample, by default 100.

  • scale_rho (int, float) – spike function derivative scale factor, by default 2.

  • tau_rho (int, float) – spike function derivative time constant (relative to theta), by default 1.

  • scale_ref (int, float) – neuron refractory response scaling (relative to theta), by default 2.

  • full_ref_kernel (bool, Optional) – Whether to use the high-resolution refractory kernel (not intended for use in practice), by default False.

  • record_v (bool) – Whether to record the membrane potential values. Setting it to True will output both spike and membrane potential values, while setting it to False will only output spike values.

  • surrogate_grad_fn (Callable) – the function for calculating surrogate gradients, by default surrogate_grad.SlayerPdf().

  • computation_type (str) – The computation backend. Available types are 'cuda' and 'for_loop'; Slayer currently supports only these two. Defaults to ComputationType.CUDA.

  • kernel_window (int) – The maximum number of elements in the response kernel, computed from tau_res, by default 38.

  • V_reset (float, Optional) – Reset voltage of this neuron layer. If it is not None, the voltage is reset to V_reset after firing a spike. If set to None, V_th is subtracted from the voltage after firing a spike, by default 0.

  • size (int, Optional) – The number of neurons in the population, by default 1.

  • stdp_tau (float, Optional) – Time constant of the presynaptic STDP trace, by default 20.

  • tau_triplet (float, Optional) – Time constant of long presynaptic trace, by default 110.

  • dendritic_delay (Union[int, jax.numpy.ndarray, Initializer, Callable], Optional) – The dendritic delay length, by default 1

  • pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.

  • area_name (str, Optional) – The name of the area to which the current neuron belongs, by default None.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel .

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

  • id (int) –

  • neuron_start_id (int) –

refractory_kernel(key=0, kernel_type='alpha')#

Calculate the refractory kernel. The refractory kernel is given by the following formula:

alpha kernel:

\[\nu(t) = -\theta \frac{t}{ \tau} \exp(1 - \frac{t}{\tau})\]

exp kernel:

\[\nu(t) = \frac{\theta}{ \tau} \exp(1 - \frac{t}{\tau})\]
Returns:

list

response_kernel(key=0, kernel_type='alpha')#

Calculate the PSP kernel. The PSP kernel is given by the following formula:

alpha kernel:

\[\epsilon(t) = \frac{t}{ \tau} \exp(1 - \frac{t}{\tau})\]

exp kernel:

\[\epsilon(t) = \frac{1}{ \tau} \exp(1 - \frac{t}{\tau})\]
Returns:

list
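
To make the kernel shapes concrete, a minimal JAX sketch that evaluates the alpha-form kernels from the formulas above on a discrete time grid. The parameter values mirror the documented defaults (tau_res=10, tau_ref=1, V_th=10, time_step=1, kernel_window=38); the sampling convention (t starting at 0 with step time_step) is an assumption:

from jax import numpy as jnp

# Discrete time grid: kernel_window samples at time_step = 1 (sampling convention assumed)
t = jnp.arange(38) * 1.0
tau_res, tau_ref, theta = 10.0, 1.0, 10.0  # documented defaults

# alpha PSP kernel: eps(t) = (t / tau) * exp(1 - t / tau)
eps = (t / tau_res) * jnp.exp(1 - t / tau_res)

# alpha refractory kernel: nu(t) = -theta * (t / tau) * exp(1 - t / tau)
nu = -theta * (t / tau_ref) * jnp.exp(1 - t / tau_ref)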

neurai.nn.layer.slayer.psp_forward_jvp(primals, tangents)#

Compute the Jacobian-vector product (JVP) for the custom JAX primitive psp_forward.

Parameters:
  • primals – The primal inputs for the psp_forward primitive.

  • tangents – The tangents for the JVP computation.

Returns:

tuple: A tuple containing the JVP for psp_forward and the JVP for the membrane potential.
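
psp_forward_jvp follows JAX's standard custom-differentiation pattern: a primitive paired with a JVP rule that maps primal inputs and tangents to a primal output and its tangent. A generic sketch of that pattern using jax.custom_jvp; the stand-in function and its derivative are illustrative only, not the actual neurai primitive:

import jax
from jax import numpy as jnp

@jax.custom_jvp
def psp(x):
  # stand-in for the real psp_forward primitive (illustrative only)
  return jnp.tanh(x)

@psp.defjvp
def psp_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  y = psp(x)
  # propagate the tangent through the stand-in derivative: d/dx tanh(x) = 1 - tanh(x)^2
  y_dot = (1 - jnp.tanh(x) ** 2) * x_dot
  return y, y_dot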

Module contents#