neurai.nn.layer package#
Submodules#
- class neurai.nn.layer.activate.Acitivate(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Module
A base class for activation functions. This class inherits from Module and defines a __call__ method that applies an activation function to the inputs.
- Parameters:
activate_fun (Callable) – The activation function to use. This should be a callable object that takes an input tensor and returns the tensor with the applied activation function.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Celu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the celu activation function. This class inherits from Acitivate and uses jax.nn.celu as the default activation function.
Examples
from neurai.nn import Celu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Celu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.celu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(alpha=1.0)#
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
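As a quick numerical check of the piecewise definition above, the formula can be reproduced element-wise with jnp.where; a minimal sketch using jax.nn.celu directly (outside the module wrapper):
import jax.numpy as jnp
from jax import nn
x = jnp.array([-2.0, -0.5, 0.0, 1.5])
alpha = 1.0
# piecewise CELU, mirroring the definition above
manual = jnp.where(x > 0, x, alpha * (jnp.exp(x / alpha) - 1))
assert jnp.allclose(nn.celu(x, alpha=alpha), manual)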
- class neurai.nn.layer.activate.Elu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the elu activation function. This class inherits from Acitivate and uses jax.nn.elu as the default activation function.
Examples
from neurai.nn import Elu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Elu()(x, alpha=0.5)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.elu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(alpha=1.0)#
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
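Since alpha only scales the negative branch of the definition above, changing it leaves positive inputs untouched; a minimal sketch using jax.nn.elu directly:
import jax.numpy as jnp
from jax import nn
x = jnp.array([-1.0, 2.0])
print(nn.elu(x, alpha=1.0))  # [exp(-1) - 1, 2.0] ~ [-0.632, 2.0]
print(nn.elu(x, alpha=0.5))  # negative branch halved: ~ [-0.316, 2.0]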
- class neurai.nn.layer.activate.Gelu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the gelu activation function. This class inherits from Acitivate and uses jax.nn.gelu as the default activation function.
Examples
from neurai.nn import Gelu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Gelu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.gelu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(approximate=True)#
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
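The two formulations are numerically close; a minimal sketch comparing them with jax.nn.gelu directly:
import jax.numpy as jnp
from jax import nn
x = jnp.linspace(-3.0, 3.0, 7)
exact = nn.gelu(x, approximate=False)  # erf-based form
approx = nn.gelu(x, approximate=True)  # tanh approximation
print(jnp.max(jnp.abs(exact - approx)))  # small, but nonzero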
- class neurai.nn.layer.activate.Glu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the glu activation function. This class inherits from Acitivate and uses jax.nn.glu as the default activation function.
Examples
from neurai.nn import Glu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = Glu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.glu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.HardSigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the hard_sigmoid activation function. This class inherits from Acitivate and uses jax.nn.hard_sigmoid as the default activation function.
Examples
from neurai.nn import HardSigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = HardSigmoid()(x)
Expected output:
Array([[0.33333334, 0.8333334 ],
       [1.        , 0.        ]], dtype=float32)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_sigmoid.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.HardSilu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the hard_silu activation function. This class inherits from Acitivate and uses jax.nn.hard_silu as the default activation function.
Examples
from neurai.nn import HardSilu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardSilu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_silu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.HardSwish(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the hard_swish activation function. This class inherits from Acitivate and uses jax.nn.hard_swish as the default activation function.
Examples
from neurai.nn import HardSwish
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardSwish()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_swish.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.HardTanh(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the hard_tanh activation function. This class inherits from Acitivate and uses jax.nn.hard_tanh as the default activation function.
Examples
from neurai.nn import HardTanh
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = HardTanh()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.hard_tanh.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.LeakyRelu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the leaky relu activation function. This class inherits from Acitivate and uses jax.nn.leaky_relu as the default activation function.
Examples
from neurai.nn import LeakyRelu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LeakyRelu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.leaky_relu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(negative_slope=0.01)#
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
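For example, with the default negative_slope the negative branch is scaled by 0.01; a minimal sketch using jax.nn.leaky_relu directly:
import jax.numpy as jnp
from jax import nn
x = jnp.array([-4.0, -1.0, 0.0, 2.0])
print(nn.leaky_relu(x))                      # [-0.04, -0.01, 0., 2.]
print(nn.leaky_relu(x, negative_slope=0.2))  # [-0.8, -0.2, 0., 2.]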
- class neurai.nn.layer.activate.LogSigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the log_sigmoid activation function. This class inherits from Acitivate and uses jax.nn.log_sigmoid as the default activation function.
Examples
from neurai.nn import LogSigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSigmoid()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.log_sigmoid.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.LogSoftmax(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the log_softmax activation function. This class inherits from Acitivate and uses jax.nn.log_softmax as the default activation function.
Examples
from neurai.nn import LogSoftmax
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSoftmax()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.log_softmax.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(axis=-1, where=None, initial=None)#
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters:
x (Any) – Input array.
axis (Union[int, Tuple[int, ...], None]) – The axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where (Optional[Any]) – Elements to include in the log_softmax.
initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type:
jnp.ndarray
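Because log_softmax is the logarithm of softmax, exponentiating its output recovers a distribution; a minimal sketch using jax.nn.log_softmax directly:
import jax.numpy as jnp
from jax import nn
x = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
log_p = nn.log_softmax(x, axis=-1)
print(jnp.exp(log_p).sum(axis=-1))  # each row sums to 1: [1. 1.]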
- class neurai.nn.layer.activate.LogSumexp(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the logsumexp activation function. This class inherits from Acitivate and uses jax.nn.logsumexp as the default activation function.
Examples
from neurai.nn import LogSumexp
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = LogSumexp()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.logsumexp.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(axis=None, b=None, keepdims=False, return_sign=False)#
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of scipy.special.logsumexp(). Original docstring below.
- Parameters:
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Returns:
res (ndarray) – The result, np.log(np.sum(np.exp(a))), calculated in a numerically more stable way. If b is given, then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res, containing +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
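The shifted formulation matters for large inputs, where the naive expression overflows; a minimal sketch using jax.scipy.special.logsumexp directly:
import jax.numpy as jnp
from jax.scipy.special import logsumexp
a = jnp.array([1000.0, 1000.0])
print(jnp.log(jnp.sum(jnp.exp(a))))  # inf: exp(1000) overflows in float32
print(logsumexp(a))                  # ~1000.693 = 1000 + log(2), computed stably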
- class neurai.nn.layer.activate.Mish(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the mish activation function. This class inherits from Acitivate and uses mish as the default activation function.
Examples
from neurai.nn import Mish
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Mish()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is mish.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Relu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the relu activation function. This class inherits from Acitivate and uses jax.nn.relu as the default activation function.
Examples
from neurai.nn import Relu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Relu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.relu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Relu6(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the relu6 activation function. This class inherits from Acitivate and uses jax.nn.relu6 as the default activation function.
Examples
from neurai.nn import Relu6
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Relu6()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.relu6.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Selu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the selu activation function. This class inherits from Acitivate and uses jax.nn.selu as the default activation function.
Examples
from neurai.nn import Selu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Selu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.selu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun()#
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
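The fixed \(\lambda\) and \(\alpha\) constants above can be verified against jax.nn.selu; a minimal sketch:
import jax.numpy as jnp
from jax import nn
lam, alpha = 1.0507009873554805, 1.6732632423543772
x = jnp.array([-1.0, 0.5, 2.0])
# piecewise SELU, mirroring the definition above
manual = lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
assert jnp.allclose(nn.selu(x), manual)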
- class neurai.nn.layer.activate.Sigmoid(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the sigmoid activation function. This class inherits from Acitivate and uses jax.nn.sigmoid as the default activation function.
Examples
from neurai.nn import Sigmoid
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]], dtype=jnp.float32)
y = Sigmoid()(x)
Expected output:
Array([[0.26894143, 0.8807971 ],
       [0.95257413, 0.01798621]], dtype=float32)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.sigmoid.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Silu(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the silu activation function. This class inherits from Acitivate and uses jax.nn.silu as the default activation function.
Examples
from neurai.nn import Silu
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Silu()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.silu.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.SoftSign(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the soft_sign activation function. This class inherits from Acitivate and uses jax.nn.soft_sign as the default activation function.
Examples
from neurai.nn import SoftSign
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = SoftSign()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.soft_sign.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Softmax(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the softmax activation function. This class inherits from Acitivate and uses jax.nn.softmax as the default activation function.
Examples
from neurai.nn import Softmax
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softmax()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.softmax.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun(axis=-1, where=None, initial=None)#
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
x (Any) – Input array.
axis (Union[int, Tuple[int, ...], None]) – The axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where (Optional[Any]) – Elements to include in the softmax.
initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type:
jnp.ndarray
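As the definition implies, the outputs along the chosen axis form a probability distribution; a minimal sketch using jax.nn.softmax directly:
import jax.numpy as jnp
from jax import nn
x = jnp.array([[-1.0, 2.0], [3.0, -4.0]])
p = nn.softmax(x, axis=-1)
print(p.sum(axis=-1))  # each row sums to 1: [1. 1.]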
- class neurai.nn.layer.activate.Softmin(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the softmin activation function. This class inherits from Acitivate and uses softmin as the default activation function.
Examples
from neurai.nn import Softmin
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softmin()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is softmin.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Softplus(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the softplus activation function. This class inherits from Acitivate and uses jax.nn.softplus as the default activation function.
Examples
from neurai.nn import Softplus
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Softplus()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.softplus.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.activate.Tanh(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Acitivate
A class for the tanh activation function. This class inherits from Acitivate and uses jax.nn.tanh as the default activation function.
Examples
from neurai.nn import Tanh
from jax import numpy as jnp
x = jnp.array([[-1, 2], [3, -4]])
y = Tanh()(x)
- Parameters:
activate_fun (Callable) – The activation function to use. Default is jax.nn.tanh.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- activate_fun()#
Compute hyperbolic tangent element-wise.
LAX-backend implementation of numpy.tanh(). Original docstring below.
Equivalent to np.sinh(x)/np.cosh(x) or -1j * np.tan(1j*x).
- Parameters:
x (array_like) – Input array.
- Returns:
y (ndarray) – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar.
- class neurai.nn.layer.attention.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=VarianceScalingIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Module
Class for an embedding layer in a neural network.
Examples
from neurai.nn import Embed
from jax import numpy as jnp
input = jnp.asarray([9, 6, 5, 7, 8, 8, 9, 2, 8])
embed = Embed(num_embeddings=10, features=12)
param = embed.init(input)
out = embed.run(param, input)
- Parameters:
num_embeddings (int) – The size of the vocabulary or number of distinct items to embed.
features (int) – The dimensionality of the embedding.
dtype (Optional[Any], Optional) – The dtype of the input, by default None.
param_dtype (Any, Optional) – The dtype of the embedding parameters, by default jnp.float32.
embedding_init (Callable[..., jnp.ndarray], Optional) – The initializer for the embedding parameters, by default default_embed_init.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- attend(query)#
Attend over the embedding using a query array.
- Parameters:
query (jnp.ndarray) – Array with the last dimension equal to the feature depth (features) of the embedding.
- Returns:
jnp.ndarray – An array with the final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.
- Raises:
TypeError – If the input query and the embedding have different data types.
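Building on the init/run pattern from the example above, a hedged sketch of the weight-sharing use case; the exact calling convention for attend on a bound module is an assumption:
from neurai.nn import Embed
from jax import numpy as jnp
ids = jnp.asarray([1, 4, 2])
embed = Embed(num_embeddings=10, features=12)
param = embed.init(ids)
vectors = embed.run(param, ids)  # shape (3, 12): one embedding per id
# scoring each embedded vector against the whole table (tied weights);
# how attend is invoked here is assumed, not confirmed by the source
logits = embed.attend(vectors)   # expected shape (3, 10)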
- neurai.nn.layer.attention.canonicalize_dtype(*args, dtype=None, inexact=True)#
Canonicalize an optional dtype to the definitive dtype.
If the dtype is None, this function will infer the dtype from the input arguments using jnp.result_type(). If dtype is not None, it will be returned unmodified, or an exception will be raised if the dtype is invalid.
- Parameters:
*args (JAX array compatible values) – Input values to infer the dtype from. None values are ignored.
dtype (Optional dtype override) – If specified, the input arguments are cast to the specified dtype instead, and dtype inference is disabled.
inexact (bool) – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don’t work directly on integers, such as taking a mean.
- Returns:
The dtype that *args should be cast to.
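A minimal sketch of the inference behavior described above; the import path follows this module's documentation, and the promotion of integer inputs to float32 is the expected result of inexact=True:
from neurai.nn.layer.attention import canonicalize_dtype
from jax import numpy as jnp
# dtype=None with inexact=True: integer inputs are promoted to a float dtype
print(canonicalize_dtype(jnp.arange(3)))                     # float32
# an explicit dtype disables inference and is returned as-is
print(canonicalize_dtype(jnp.arange(3), dtype=jnp.float16))  # float16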
- neurai.nn.layer.attention.masked_softmax(X, valid_lens)#
Perform softmax operation by masking elements on the last axis.
- Parameters:
X (np.ndarray) – The input 3D tensor to be softmaxed.
valid_lens (np.ndarray) – The 1D or 2D tensor containing the valid length for each sequence.
- Returns:
np.ndarray – The softmaxed tensor.
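A hedged sketch of the masking behavior, assuming valid_lens counts the valid key positions on the last axis for each batch element:
from neurai.nn.layer.attention import masked_softmax
from jax import numpy as jnp
X = jnp.ones((2, 2, 4))         # (batch, queries, keys)
valid_lens = jnp.array([2, 3])  # keep 2 keys in batch 0, 3 keys in batch 1
# positions beyond the valid length should receive (near-)zero probability
print(masked_softmax(X, valid_lens))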
- neurai.nn.layer.attention.promote_dtype(*args, dtype=None, inexact=True)#
Promotes input arguments to a specified or inferred dtype.
All args are cast to the same dtype. See canonicalize_dtype for how this dtype is determined.
promote_dtype is mostly a convenience wrapper around neurai.numpy.promote_types. The differences are that it automatically casts all inputs to the inferred dtype, allows inference to be overridden by a forced dtype, and optionally checks that the resulting dtype is inexact.
- Parameters:
*args (Tuple) – JAX array compatible values. None values are returned as is.
dtype (Optional dtype override.) – If specified the arguments are cast to the specified dtype instead and dtype inference is disabled.
inexact (bool) – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don’t work directly on integers, such as taking a mean.
- Returns:
List ([jnp.ndarray]) – The arguments cast to arrays of the same dtype.
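A minimal sketch of the default promotion, assuming the import path shown in this module:
from neurai.nn.layer.attention import promote_dtype
from jax import numpy as jnp
ints = jnp.arange(3)     # int32
scale = jnp.float32(2.0)
# with inexact=True (the default), both come back as a common float dtype
x, s = promote_dtype(ints, scale)
print(x.dtype, s.dtype)  # float32 float32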
- neurai.nn.layer.attention.sequence_mask(X, valid_len, value=0)#
Create a mask tensor for a given sequence length.
- neurai.nn.layer.attention.transpose_output(X, num_heads)#
Transpose the output tensor after parallel computation of multiple attention heads.
- Parameters:
X (np.ndarray) – The input tensor to be transposed.
num_heads (int) – The number of attention heads.
- Returns:
np.ndarray – The transposed tensor.
- neurai.nn.layer.attention.transpose_qkv(X, num_heads)#
Transposition for parallel computation of multiple attention heads.
- Parameters:
X (np.ndarray) – The input tensor to be transposed.
num_heads (int) – The number of attention heads.
- Returns:
np.ndarray – The transposed tensor.
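The two helpers are inverses of each other; a hedged sketch, assuming the conventional (batch, seq_len, num_hiddens) layout where num_hiddens is divisible by num_heads:
from neurai.nn.layer.attention import transpose_output, transpose_qkv
from jax import numpy as jnp
X = jnp.ones((2, 5, 8))                          # (batch, seq_len, num_hiddens)
heads = transpose_qkv(X, num_heads=4)            # split the hidden dim across 4 heads
restored = transpose_output(heads, num_heads=4)  # undo the transposition
print(restored.shape)                            # expected: (2, 5, 8)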
- class neurai.nn.layer.conv.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Module
Applies a convolution to the inputs.
- Parameters:
features (int) – Number of convolution filters (output channels).
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- check_kernel_dimensions(kernel, size)#
Check if the kernel dimensions match the expected size.
- Parameters:
kernel – The convolutional kernel to check.
size – The expected number of kernel dimensions.
- Raises:
ValueError – If kernel does not have the expected size.
- class neurai.nn.layer.conv.Conv1d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Conv
1-dimensional convolution layer.
Examples
from neurai.nn import Conv1d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 5, 1))
conv = Conv1d(features=4, kernel_size=(3,))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of convolution filters (output channels).
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Raises:
ValueError – If kernel_size is not a single integer or a one-element sequence of integers.
- class neurai.nn.layer.conv.Conv2d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Conv
Convolutional layer for 2D inputs.
Examples
from neurai.nn import Conv2d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 1))
conv = Conv2d(3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of convolution filters (output channels).
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Raises:
ValueError – If kernel_size is not an integer or sequence of 2 integers.
- class neurai.nn.layer.conv.Conv3d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Conv
3D convolution layer.
Examples
from neurai.nn import Conv3d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = Conv3d(3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of convolution filters (output channels).
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is True.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Raises:
ValueError – If kernel_size is not an integer or sequence of 3 integers.
- class neurai.nn.layer.conv.ConvTranspose(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Module
Applies a transposed convolution to the inputs.
- Parameters:
features (int) – The number of output channels (i.e., the number of filters/kernels).
kernel_size (Sequence[int]) – The dimensions of the convolutional kernel (e.g., [3, 3] for a 2D kernel).
strides (Union[None, int, Sequence[int]], Optional) – The stride of the convolution. Controls the step size at which the kernel is applied to the input. Default is 1, which means no overlapping.
padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.
out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.
use_bias (bool, Optional) – Whether to include a bias term in the convolution. Default is True.
rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.
kernel_init (Callable, Optional) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution initializer.
bias_init (Callable, Optional) – The initializer function for the bias term (if use_bias=True). Default is a normal distribution initializer.
mask (Optional[jnp.ndarray], Optional) – A mask for the kernel. If provided, it will be element-wise multiplied with the kernel matrix. Default is None, meaning no masking is applied.
precision (Optional[lax.Precision], Optional) – The numerical precision of the computation, used for high-precision calculations. Default is None, which uses the default precision.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- __call__(self, input: jnp.ndarray) → jnp.ndarray#
Performs the transposed convolution operation on the input.
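For intuition, the output length per spatial dimension is expected to follow the usual transposed-convolution size relation; this formula is an assumption based on the standard convention, not stated by the source:
\[L_{out} = (L_{in} - 1) \cdot \mathrm{stride} - 2 \cdot \mathrm{padding} + \mathrm{dilation} \cdot (k - 1) + \mathrm{out\_padding} + 1\]
A hedged sketch checking it with ConvTranspose1d, reusing the init/run pattern from the examples below:
from neurai.nn import ConvTranspose1d
from jax import random
x = random.normal(random.PRNGKey(0), (1, 5, 4))  # (batch, length, channels)
conv = ConvTranspose1d(features=10, kernel_size=(3,), strides=2)
param = conv.init(x)
out = conv.run(param, x)
# per the relation above: (5 - 1) * 2 - 0 + 1 * (3 - 1) + 0 + 1 = 11
print(out.shape)  # expected: (1, 11, 10)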
- class neurai.nn.layer.conv.ConvTranspose1d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: ConvTranspose
One-dimensional transposed convolution layer.
Examples
from neurai.nn import ConvTranspose1d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 5, 4))
conv = ConvTranspose1d(features=10, kernel_size=(1,))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – The number of output channels.
kernel_size (Union[int, Tuple[int]]) – The dimensions of the convolutional kernel.
strides (Union[None, int, Tuple[int]]) – The stride of the convolution. Default is 1.
padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.
out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.
use_bias (bool) – Whether to include a bias term. Default is True.
rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.
kernel_init (Callable) – The initializer function for the convolutional kernel. Default is a normal distribution.
bias_init (Callable) – The initializer function for the bias term. Default is a normal distribution.
mask (Optional[jnp.ndarray]) – A mask for the kernel. Default is None.
precision (Optional[jax.lax.Precision]) – The numerical precision of the computation. Default is None.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
- Raises:
ValueError – If kernel_size is not an integer or a tuple of integers.
- class neurai.nn.layer.conv.ConvTranspose2d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: ConvTranspose
2D transposed convolution layer.
Examples
from neurai.nn import ConvTranspose2d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 1))
conv = ConvTranspose2d(features=3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – The number of output channels.
kernel_size (Union[int, Tuple[int]]) – The dimensions of the convolutional kernel. If an int is provided, the kernel will have equal dimensions.
strides (Union[None, int, Tuple[int]]) – The stride of the convolution. Default is 1, which means no overlapping.
padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.
out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.
use_bias (bool) – Whether to include a bias term. Default is True.
rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.
kernel_init (Callable) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution.
bias_init (Callable) – The initializer function for the bias term. Default is a normal distribution.
mask (Optional[jnp.ndarray]) – A mask for the kernel. Default is None, meaning no masking is applied.
precision (Optional[jax.lax.Precision]) – The numerical precision of the computation. Default is None, which uses the default precision.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- setup()#
Check and set up the configuration for the transposed convolution layer.
- class neurai.nn.layer.conv.ConvTranspose3d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=True, rhs_dilation=1, kernel_init=<function normal.<locals>.init>, bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: ConvTranspose
Transposed convolution layer for 3D inputs.
Examples
from neurai.nn import ConvTranspose3d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = ConvTranspose3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – The number of output channels (i.e., the number of filters/kernels).
kernel_size (Union[int, Tuple[int], List[int]]) – The dimensions of the convolutional kernel (e.g., [3, 3, 3] for a 3D kernel). If a single integer is provided, it will be replicated for all three dimensions.
strides (Union[None, int, Tuple[int], List[int]]) – The stride of the convolution. Controls the step size at which the kernel is applied to the input. If a single integer is provided, it will be replicated for all three dimensions. Default is 1, which means no overlapping.
padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Zero-padding will be added to both sides of the input. Default is 0.
out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Additional size added to one side of the output shape. Default is 0.
use_bias (bool) – Whether to include a bias term in the convolution. Default is True.
rhs_dilation (Union[None, int, Sequence[int]]) – Spacing between kernel elements. Default is 1.
kernel_init (Callable) – The initializer function for the convolutional kernel (weight matrix). Default is a normal distribution initializer.
bias_init (Callable) – The initializer function for the bias term (if use_bias=True). Default is a normal distribution initializer.
mask (Optional[jnp.ndarray]) – A mask for the kernel. If provided, it will be element-wise multiplied with the kernel matrix. Default is None, meaning no masking is applied.
precision (Optional[lax.Precision]) – The numerical precision of the computation, used for high-precision calculations. Default is None, which uses the default precision.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- setup()#
Check and set up the configuration for the transposed convolution layer.
- class neurai.nn.layer.conv.SNNConv2d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases: Conv2d
2D convolutional layer for SNN (Spiking Neural Network) inputs. SNNConv2d accepts 4-dimensional or 5-dimensional inputs: when the input is 5-dimensional, the batch and depth dimensions are merged and a 2-dimensional convolution is applied; 4-dimensional inputs are processed exactly as in an ANN.
Examples
from neurai.nn import SNNConv2d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConv2d(features=3, kernel_size=(3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of output channels.
weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.
kernel_init (Callable, Optional) – Initialization function for the kernel parameters. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
use_bias (bool, Optional) – Whether to include bias in the layer. Default is False.
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count groups. Default is 1.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.conv.SNNConv3d(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Conv3d
3D convolutional layer for SNN (Spiking Neural Network) inputs.
Examples
from neurai.nn import SNNConv3d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConv3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of output channels.
weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.
kernel_init (Callable, Optional) – Initialization function for the kernel parameters. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
use_bias (bool, Optional) – Whether to include bias in the layer. Default is False.
kernel_size (Sequence[int]) – The shape of the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count. Default is 1.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
bias_init (Callable, Optional) – The initializer for the bias. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’uniform_no_variance’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.conv.SNNConvTranspose3d(features, kernel_size, strides=1, padding=0, out_padding=0, use_bias=False, rhs_dilation=1, kernel_init=KaimingUniformIniter(key=None), bias_init=<function normal.<locals>.init>, mask=None, precision=None, param_dtype=<class 'jax.numpy.float32'>, weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
ConvTranspose3d
3D transposed convolutional layer for SNN (Spiking Neural Network) inputs.
Examples
from neurai.nn import SNNConvTranspose3d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 6, 28, 28, 1))
conv = SNNConvTranspose3d(features=3, kernel_size=(3, 3, 3))
param = conv.init(input)
out = conv.run(param, input)
- Parameters:
features (int) – Number of output channels.
weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.
use_bias (bool, Optional) – Whether to include bias in the layer. Default is False.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel. Default is KaimingUniformIniter with scale=math.sqrt(5) and distribution=’leaky_relu’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]]) –
out_padding (Union[int, Tuple[int, int], Sequence[Tuple[int, int]]]) –
bias_init (Callable) –
param_dtype (Any) –
- neurai.nn.layer.conv.canonicalize_padding(padding, rank)#
Canonicalizes conv padding to a jax.lax supported format.
- neurai.nn.layer.conv.maybe_replicate(x, rep)#
Replicate an integer or a sequence of integers rep times.
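A sketch of the behavior that summary describes, under assumed semantics (not the actual implementation): an int is repeated rep times, while a sequence is returned unchanged.
def maybe_replicate_sketch(x, rep):
  # assumed semantics: broadcast a scalar to one value per dimension
  return (x,) * rep if isinstance(x, int) else tuple(x)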
- class neurai.nn.layer.dropout.Dropout(rate, deterministic=None, rng_col='dropout', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).
This layer is active only during training, i.e. when deterministic is False. In other circumstances it is a no-op.
Examples
from neurai.nn import Dropout
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
dropout = Dropout(rate=0.5, deterministic=False)
param = dropout.init(input, rngs={"dropout": 1})
out = dropout.run(param, input, rngs={"dropout": 1})
- Parameters:
rate (float) – The dropout probability.
deterministic (Optional[bool], Optional) – If False, the inputs are scaled by 1 / (1 - rate) and masked. If True, no mask is applied, and the inputs are returned as is. Default is None.
rng_col (str, Optional) – The rng collection name to use when requesting an rng key. Default is ‘dropout’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Returns:
The masked inputs.
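The scaling rule is easy to see in a few lines of plain JAX; this is an illustrative sketch of inverted dropout, not the module's internals.
import jax
import jax.numpy as jnp
rate = 0.5
key = jax.random.PRNGKey(0)
x = jnp.ones((4,))
keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)  # survival mask
y = jnp.where(keep, x / (1.0 - rate), 0.0)  # survivors scaled by 1 / (1 - rate)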
- class neurai.nn.layer.dropout.SNNDropout3d(rate, deterministic=None, rng_col='dropout', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Dropout
3D dropout layer for SNN (Spiking Neural Network) inputs. The dropout mask is shared across the time dimension, so a neuron that is dropped stays dropped for the entire time window (see the sketch after this entry).
Examples
from neurai.nn import SNNDropout3d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
dropout = SNNDropout3d(rate=0.5, deterministic=False)
param = dropout.init(input, rngs={"dropout": 1})
out = dropout.run(param, input, rngs={"dropout": 1})
- Parameters:
rate (float) – The dropout probability.
deterministic (Optional[bool], Optional) – If False, the inputs are scaled by 1 / (1 - rate) and masked. If True, no mask is applied, and the inputs are returned as is. Default is None.
rng_col (str, Optional) – The rng collection name to use when requesting an rng key. Default is ‘dropout’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Returns:
The masked inputs.
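A sketch of a time-shared mask, assuming the time axis is the last axis: one mask is drawn per neuron and broadcast over every time step (illustrative only, not the module's internals).
import jax
import jax.numpy as jnp
rate = 0.5
x = jnp.ones((1, 28, 28, 3, 6))  # (..., T), assumed layout
keep = jax.random.bernoulli(jax.random.PRNGKey(0), 1.0 - rate, x.shape[:-1] + (1,))
y = jnp.where(keep, x / (1.0 - rate), 0.0)  # the trailing 1 broadcasts across T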
- class neurai.nn.layer.exodus.Exodus(size=1, pid=0, V_th=10.0, V_reset=0.0, stdp_tau=20.0, tau_triplet=110.0, dendritic_delay=1, area_name=None, surrogate_grad_fn=<neurai.grads.surrogate_grad.SingleExponential object>, id=0, neuron_start_id=1, record_v=False, step_mode=StepMode.MULTI, reset_type=ResetType.HARD, time_recurrent=False, before_reset_mem=None, param_dtype=<class 'jax.numpy.float32'>, V_rest=0.0, R=1.0, tau=20.0, min_v_mem=0.0, norm_input=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
SNNLIF
Exodus implementation of a spiking Leaky Integrate and Fire neuron. Does not simulate synaptic dynamics.
EXODUS (EXact computation Of Derivatives as Update to SLAYER). Paper: EXODUS: Stable and Efficient Training of Spiking Neural Networks (https://arxiv.org/abs/2205.10242v1). Reference implementation: synsense/sinabs-exodus.
Examples
from neurai.nn import Exodus
from jax import random
exodus = Exodus(V_rest=0., V_reset=0., V_th=1.0, tau=1.0, record_v=False, min_v_mem=0)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 4))
param = exodus.init(input)
out = exodus.run(param, input)
- Parameters:
V_rest (float, Optional) – The initial value of membrane potential. Default is 0.
V_reset (float, Optional) – The reset membrane potential. Default is 0.
V_th (float, Optional) – The firing threshold. Default is 1.0.
tau (Union[float, jnp.ndarray], Optional) – The membrane potential time constant.
record_v (bool, Optional) – Whether to record the membrane potential values. Setting it to True will output both spike and membrane potential values, while setting it to False will only output spike values. Default is False.
surrogate_grad_fn (Callable, Optional) – The function for calculating surrogate gradients of the heaviside step function in backward. Default is surrogate_grad.SingleExponential().
min_v_mem (float or None, Optional) – Lower bound for the membrane potential v_mem, clipped at every time step. Default is None.
norm_input (bool, Optional) – If True, will normalize the inputs by tau. This helps when training time constants. Default is False.
max_num_spikes_per_bin (int, Optional) – Maximum number of spikes that a neuron can emit per time step.
activations (int, Optional) – Activations from the previous time step. Has to be contiguous. Default is 1.
step_mode (neurai.const.StepMode, Optional) – SNN LIF Neuron update mode, neurai.const.single for single-step update or neurai.const.multi for multi-step update. In single-step mode, the update process includes calculations for only one time step, while in multi-step mode it calculates the update process for T time steps. Default is neurai.const.multi.
R (float) – membrane resistance, by default 1.
reset_type (Union[None, neurai.const.ResetType]) – How to reset the membrane potential after the spike, with values of neurai.const.hard or neurai.const.soft. The default is neurai.const.hard.
time_recurrent (bool) – Determines the computation mode of the LIF model within an SRNN structure.
before_reset_mem (bool) – Whether the membrane potential is clamped below the threshold after a neuron fires a spike.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
size (int, Optional) – The number of neurons in the population, by default 1.
stdp_tau (float, Optional) – STDP trace time constant, by default 20.
tau_triplet (float, Optional) – Time constant of long presynaptic trace, by default 110.
dendritic_delay (Union[int, jax.numpy.ndarray, Initializer, Callable], Optional) – The dendritic delay length, by default 1
pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.
area_name (str, Optional) – The name of the area to which the current neuron belongs. By default None.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
id (int) –
neuron_start_id (int) –
- Raises:
ValueError – If the provided input shape is empty. If V_th is not greater than min_v_mem (if min_v_mem is specified).
Exception – If in_features or out_features have more than 3 dimensions (for the dense method).
- get_alpha(template, tau)#
Calculates alpha from tau.
- Parameters:
template (jnp.Array) – Specifies the shape of alpha.
tau (float) – Membrane potential time constant.
- Returns:
jnp.Array
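For orientation only: many LIF implementations derive the decay factor as alpha = exp(-dt / tau). Whether get_alpha uses exactly this convention is an assumption here, not something the summary above states.
import jax.numpy as jnp
def alpha_from_tau(template, tau, dt=1.0):
  # hypothetical convention: exponential decay over a step of length dt
  return jnp.full(template.shape, jnp.exp(-dt / tau))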
- init_neuron_state(input_shape)#
Initialize mem and spike for a single time step. theta is a scalar of data type float.
- Parameters:
input_shape (jnp.ndarray) – The shape of the input used to initialize the neuron state.
- neurai.nn.layer.exodus.exodus_cpu_bwd(res, posterior_grad)#
Backward pass for the Exodus CPU forward computation. The number of elements in posterior_grad matches the number of values returned by the forward pass; when the forward pass returns more than one value, posterior_grad is a tuple.
- Parameters:
res (Tuple) – Values saved by the Exodus CPU forward computation for use in the backward pass.
posterior_grad (Tuple) – posterior_grad[0], jnp.ndarray. The gradient of the loss with respect to the output spike tensor of shape (neurons, Timestep). posterior_grad[1], jnp.ndarray. The gradient of the loss with respect to the output membrane potential tensor of shape (neurons, Timestep).
- Returns:
Tuple – The gradients of the loss with respect to the inputs of the Exodus CPU forward computation:
input_grad (jnp.ndarray) – The gradient of the loss with respect to the input spike tensor of shape (neurons, Timestep).
init_grad (jnp.ndarray) – The gradient of the loss with respect to the input membrane potential tensor of shape (neurons, Timestep).
alpha_grad (jnp.ndarray) – The gradient of the loss with respect to the input alpha tensor of shape (neurons,).
None * 4 – Elements with None to maintain consistency with the number of inputs in the forward computation.
- neurai.nn.layer.exodus.exodus_cpu_fwd(input_spike, mem_initial, alpha, membrane_subtract, threshold, theta_low, spikes_num)#
Forward implementation of Exodus CPU function.
- Parameters:
input_spike (jnp.ndarray) – Input spike tensor of shape (neurons, Timestep) representing the spikes received by each neuron at each timestep.
mem_initial (jnp.ndarray) – Initial membrane potential tensor of shape (neurons, Timestep) representing the starting membrane potential of each neuron.
alpha (jnp.ndarray) – Alpha tensor of shape (neurons, Timestep) representing the decay factor for each neuron’s membrane potential.
membrane_subtract (jnp.ndarray) – Tensor of shape (neurons, Timestep) representing the value to be subtracted from the membrane potential due to spikes.
threshold (float) – The threshold value for neuron firing.
theta_low (float) – Lower threshold for the membrane potential.
spikes_num (int) – The maximum number of spikes a neuron can emit per time step.
- Returns:
Tuple
output_spike (jnp.ndarray) – The output spike tensor of shape (neurons, Timestep) representing the spikes emitted by each neuron at each timestep.
output_mem (jnp.ndarray) – The post-synaptic potential tensor of shape (neurons, Timestep) representing the membrane potential of each neuron after spike processing.
input_tuple (Tuple) – A tuple containing the input values needed for the backward pass during gradient computation.
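Based only on the inputs documented above, the per-timestep semantics can be sketched as follows; the update order is an assumption, not the actual kernel.
import jax.numpy as jnp
def lif_forward_sketch(input_spike, mem_initial, alpha, membrane_subtract, threshold):
  v = mem_initial[:, 0]
  spikes, mems = [], []
  for t in range(input_spike.shape[1]):
    v = alpha[:, t] * v + input_spike[:, t]  # leaky integration
    s = (v >= threshold).astype(v.dtype)  # fire where the threshold is crossed
    v = v - s * membrane_subtract[:, t]  # reset by subtraction
    spikes.append(s)
    mems.append(v)
  return jnp.stack(spikes, 1), jnp.stack(mems, 1)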
- class neurai.nn.layer.linear.Flatten(start_dim=1, end_dim=-1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Initializes a Flatten module that flattens a jnp.ndarray along the specified dimensions.
Examples
from neurai.nn import Flatten
from jax import random
flatten = Flatten()
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
param = flatten.init(input)
out = flatten.run(param, input)
- Parameters:
start_dim (int, Optional) – The starting dimension to flatten, by default 1.
end_dim (int, Optional) – The ending dimension to flatten, by default -1. If -1, all dimensions starting from start_dim are flattened.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
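For the defaults (start_dim=1, end_dim=-1) the result is equivalent to a plain reshape that collapses every axis after the batch axis; a quick sketch:
import jax.numpy as jnp
x = jnp.ones((1, 28, 28, 3))
flat = x.reshape(x.shape[0], -1)  # shape (1, 2352)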
- class neurai.nn.layer.linear.Linear(features, bias=True, param_dtype=<class 'jax.numpy.float32'>, w_initializer=KaimingUniformIniter(key=None), b_initializer=UniformIniter(key=None), parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
A linear transformation Layer.
This Layer applies a linear transformation to the input data, optionally adding a bias vector.
Examples
from neurai.nn import Linear
from jax import random
linear = Linear(10)
input = random.normal(random.PRNGKey(0), (1, 784))
param = linear.init(input)
out = linear.run(param, input)
- Parameters:
features (int) – The number of output features.
bias (bool, Optional) – Whether to use a bias vector. Default is True.
param_dtype (Any, Optional) – The dtype passed to parameter initializers (default: float32).
w_initializer (Callable, Optional) – The initializer for the weight matrix. Default is KaimingUniformIniter().
b_initializer (Optional[Callable], Optional) – The initializer for the bias vector. Default is UniformIniter(-0.08, 0.08).
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
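The transformation itself is y = x @ W + b; a sketch with hypothetical zero-initialized parameters, not the module's initializers:
import jax.numpy as jnp
x = jnp.ones((1, 784))
W = jnp.zeros((784, 10))  # (in_features, features)
b = jnp.zeros((10,))
y = x @ W + b  # shape (1, 10)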
- class neurai.nn.layer.linear.SNNLinear3d(features, kernel_size=None, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=KaimingUniformIniter(key=None), bias_init=KaimingUniformIniter(key=None), weight_scale=1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Conv3d
3D linear layer for SNN (Spiking Neural Network) inputs.
Examples
from neurai.nn import SNNLinear3d
from jax import random
snn_linear = SNNLinear3d(features=10)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 5))
param = snn_linear.init(input)
out = snn_linear.run(param, input)
- Parameters:
features (int) – Number of output channels.
kernel_size (Sequence[int]) – The shape of the convolutional kernel. Default is None.
weight_scale (float, Optional) – Scale factor for the kernel weights. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is False.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]]) –
feature_group_count (int) –
param_dtype (Any) –
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) –
kernel_init (Callable) –
bias_init (Callable) –
- class neurai.nn.layer.linear.Sequential(layers, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
A Layer that applies a sequence of other Layers.
This Layer takes a variable number of other Layers as arguments, and applies them one by one to the input data.
Examples
from neurai.nn import Sequential, Linear
from jax import random
model = Sequential([Linear(1024), Linear(512), Linear(10)])
input = random.normal(random.PRNGKey(0), (1, 784))
param = model.init(input)
out = model.run(param, input)
print(out.shape)
- Parameters:
layers (Sequence[Callable[..., Any]]) – The Layers to run sequentially.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
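Conceptually, Sequential just folds the input through its layers in order; a sketch with plain callables:
def apply_sequential(layers, x):
  for layer in layers:
    x = layer(x)  # each layer's output feeds the next
  return x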
- class neurai.nn.layer.normalization.BatchNorm(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Batch Normalization layer.
- Parameters:
use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
axis (int) – The batch axis of the input.
momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99
epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5
affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True
param_dtype (Any) – The dtype passed to parameter initializers (default: float32).
axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
scale_init (Callable) – An initializer generating the original scaling matrix.
bias_init (Callable) – An initializer generating the original translation matrix.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
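The underlying arithmetic, sketched in plain JAX (illustrative only, not the module's internals): normalize with batch statistics, then apply the learned affine transform.
import jax.numpy as jnp
def batch_norm_sketch(x, scale, bias, eps=1e-5):
  mean = x.mean(axis=0)  # statistics over the batch axis
  var = x.var(axis=0)
  return (x - mean) / jnp.sqrt(var + eps) * scale + bias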
- class neurai.nn.layer.normalization.BatchNorm1d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
BatchNorm
1D Batch Normalization layer.
Examples
from neurai.nn import BatchNorm1d
from jax import random
bn1 = BatchNorm1d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (10, 20))
param = bn1.init(input)
out = bn1.run(param, input)
- Parameters:
use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
axis (int) – The batch axis of the input.
momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99
epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5
affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True
param_dtype (Any) – The dtype passed to parameter initializers (default: float32).
axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
scale_init (Callable) – An initializer generating the original scaling matrix.
bias_init (Callable) – An initializer generating the original translation matrix.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.normalization.BatchNorm2d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
BatchNorm
2D Batch Normalization layer.
Examples
from neurai.nn import BatchNorm2d
from jax import random
bn2 = BatchNorm2d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
param = bn2.init(input)
out = bn2.run(param, input)
- Parameters:
use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
axis (int) – The batch axis of the input.
momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99
epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5
affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True
param_dtype (Any) – The dtype passed to parameter initializers (default: float32).
axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
scale_init (Callable) – An initializer generating the original scaling matrix.
bias_init (Callable) – An initializer generating the original translation matrix.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.normalization.BatchNorm3d(use_running_average=None, axis=-1, epsilon=1e-05, momentum=0.99, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
BatchNorm
3D Batch Normalization layer.
Examples
from neurai.nn import BatchNorm3d
from jax import random
bn3 = BatchNorm3d(use_running_average=False)
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 16))
param = bn3.init(input)
out = bn3.run(param, input)
- Parameters:
use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
axis (int) – The batch axis of the input.
momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99
epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5
affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True
param_dtype (Any) – The dtype passed to parameter initializers (default: float32).
axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
scale_init (Callable) – An initializer generating the original scaling matrix.
bias_init (Callable) – An initializer generating the original translation matrix.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.normalization.LayerNorm(epsilon=1e-06, param_dtype=<class 'jax.numpy.float32'>, affine=True, bias_init=<function LayerNorm.<lambda>>, scale_init=<function LayerNorm.<lambda>>, reduction_axes=-1, feature_axes=-1, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Layer normalization (https://arxiv.org/abs/1607.06450).
LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
Examples
from neurai.nn import LayerNorm
from jax import random
layer_norm = LayerNorm()
input = random.normal(random.PRNGKey(0), (10, 28, 28, 1))
param = layer_norm.init(input)
out = layer_norm.run(param, input)
- Parameters:
epsilon (float, Optional) – A small float added to variance to avoid dividing by zero. Default: 1e-6.
param_dtype (Any, Optional) – The dtype passed to parameter initializers (default: float32).
affine (bool, Optional) – A boolean value that when set to True, this Layer has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default: True.
bias_init (Callable, Optional) – Initializer for bias. Default: lambda _, shape, dtype: jnp.zeros(shape, dtype).
scale_init (Callable, Optional) – Initializer for scale. Default: lambda _, shape, dtype: jnp.ones(shape, dtype).
reduction_axes (Union[int, Any], Optional) – Axes for computing normalization statistics. Default: -1.
feature_axes (Union[int, Any], Optional) – Feature axes for learned bias and scaling. Default: -1.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
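In contrast to batch normalization, the statistics here are computed per example over the reduction axes; a sketch for the default reduction_axes=-1 (illustrative only):
import jax.numpy as jnp
def layer_norm_sketch(x, scale, bias, eps=1e-6):
  mean = x.mean(axis=-1, keepdims=True)  # per-example statistics
  var = x.var(axis=-1, keepdims=True)
  return (x - mean) / jnp.sqrt(var + eps) * scale + bias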
- class neurai.nn.layer.normalization.TdBatchNorm(use_running_average=None, axis=-2, epsilon=1e-05, momentum=0.1, affine=True, param_dtype=<class 'jax.numpy.float32'>, axis_name=None, axis_index_groups=None, scale_init=<function BatchNorm.<lambda>>, bias_init=<function BatchNorm.<lambda>>, alpha=1, v_th=0.2, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
BatchNorm2d
Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280. In short, batch statistics are also averaged over the time dimension when performing BN (see the sketch after the parameter list).
Examples
from neurai.nn import TdBatchNorm
from jax import random
tdbn = TdBatchNorm(use_running_average=False)
input = random.normal(random.PRNGKey(0), (10, 28, 28, 3, 3))
param = tdbn.init(input)
out = tdbn.run(param, input)
- Parameters:
use_running_average (Optional[bool]) – If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
axis (int) – The batch axis of the input.
momentum (float) – The value used for the running_mean and running_var computation. Default: 0.1
epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5
affine (bool) – A boolean value that when set to True, this Layer has learnable affine parameters. Default: True
param_dtype (Any) – The dtype passed to parameter initializers (default: float32).
v_th (float) – membrane potential threshold, by default 0.2
alpha (float) – An additional scaling parameter that may change in residual blocks.
axis_name (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this Layer is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
scale_init (Callable) – An initializer generating the original scaling matrix.
bias_init (Callable) – An initializer generating the original translation matrix.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
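To illustrate what averaging over the time domain means, a sketch of the statistics computation with a trailing time axis (assumed (batch, H, W, C, T) layout; not the module's internals):
import jax.numpy as jnp
x = jnp.ones((10, 28, 28, 3, 3))  # (batch, H, W, C, T), assumed layout
mean = x.mean(axis=(0, 1, 2, 4), keepdims=True)  # batch, spatial, and time axes together
var = x.var(axis=(0, 1, 2, 4), keepdims=True)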
- class neurai.nn.layer.normalization.TdLayer(layer, bn=None, use_running_average=None, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Converts an ordinary layer to the time domain. The input tensor must carry an additional time dimension, which in this case is the last dimension of the data; during the forward pass the wrapped layer is applied once per time step along that dimension (a sketch follows the parameter list). Link to related paper: https://arxiv.org/pdf/2011.05280.
Examples
from neurai.nn import Conv2d, TdBatchNorm, TdLayer
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 10))
model = TdLayer(Conv2d(3, kernel_size=(3, 3)), TdBatchNorm(), use_running_average=False)
param = model.init(input)
out = model.run(param, input)
- Parameters:
layer (Module) – The layer needs to be converted.
bn (Module) – If Batch Normalization (BN) is needed, the BN layer should be passed in together as a parameter.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
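The per-step application can be sketched with a plain function over the trailing time axis, assuming the wrapped layer is a simple callable (illustrative only):
import jax.numpy as jnp
def td_apply_sketch(layer_fn, x):  # x has shape (..., T)
  outs = [layer_fn(x[..., t]) for t in range(x.shape[-1])]
  return jnp.stack(outs, axis=-1)  # restack along the time axis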
- class neurai.nn.layer.normalization.WeightNorm(module, module_name='weight', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Weight normalization (https://arxiv.org/abs/1602.07868).
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by module_name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v').
Examples
from neurai.nn import WeightNorm, Conv2d
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
wn = WeightNorm(Conv2d(3, kernel_size=(3, 3)))
param = wn.init(input)
out = wn.run(param, input)
- Parameters:
module (Module) – The containing module.
module_name (str, Optional) – The name of weight parameter, by default is ‘weight’.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
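The standard weight-norm reparameterization, written out as a sketch of the usual formulation (weight_g and weight_v mirror the parameter names above):
import jax.numpy as jnp
def weight_norm_sketch(weight_v, weight_g):
  return weight_g * weight_v / jnp.linalg.norm(weight_v)  # magnitude times unit direction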
- class neurai.nn.layer.padding.ConstantPad(padding, constant_values=0.0, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Constant padding layer for neural networks.
Examples
from neurai.nn import ConstantPad
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ConstantPad(padding=1, constant_values=0.5)
# pad = ConstantPad(padding=((2, 2), (2, 2)), constant_values=0.3)
param = pad.init(input)
out = pad.run(param, input)
- Parameters:
padding (Union[int, Sequence[Tuple[int, int]]]) – Padding to apply to the input. If an int, the same padding is used on all boundaries. If a tuple, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.
constant_values (float) – The constant value of the input tensor boundaries.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
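The effect corresponds to jnp.pad in 'constant' mode on the spatial axes; a sketch assuming an NHWC input:
import jax.numpy as jnp
x = jnp.ones((1, 28, 28, 3))
y = jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='constant', constant_values=0.5)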
- class neurai.nn.layer.padding.ReflectionPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Reflection padding layer for neural networks.
Examples
from neurai.nn import ReflectionPad
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ReflectionPad(padding=2)
# pad = ReflectionPad(padding=((2, 2), (2, 2)))
param = pad.init(input)
out = pad.run(param, input)
- Parameters:
padding (Union[int, Sequence[Tuple[int, int]]]) – Padding to apply to the input. If an int, the same padding is used on all boundaries. If a tuple, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
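Reflection padding mirrors the input at the boundary; jnp.pad's 'reflect' mode shows the same behavior, sketched for an NHWC input:
import jax.numpy as jnp
x = jnp.ones((1, 28, 28, 3))
y = jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='reflect')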
- class neurai.nn.layer.padding.ReplicationPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Replication padding layer for neural networks.
Examples
from neurai.nn import ReplicationPad
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ReplicationPad(padding=2)
# pad = ReplicationPad(padding=((2, 2), (2, 2)))
param = pad.init(input)
out = pad.run(param, input)
- Parameters:
padding (Union[int, Sequence[Tuple[int, int]]]) – Padding to apply to the input. If an int, the same padding is used on all boundaries. If a tuple, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
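Replication padding repeats the edge value outward; jnp.pad's 'edge' mode shows the same behavior, sketched for an NHWC input:
import jax.numpy as jnp
x = jnp.ones((1, 28, 28, 3))
y = jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='edge')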
- class neurai.nn.layer.padding.ZeroPad(padding, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Zero padding layer for neural networks.
Examples
from neurai.nn import ZeroPad
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pad = ZeroPad(padding=1)
# pad = ZeroPad(padding=((1, 1), (1, 1)))
param = pad.init(input)
out = pad.run(param, input)
print(out.shape)
- Parameters:
padding (Union[int, Sequence[Tuple[int, int]]]) – Padding to apply to the input. If an int, the same padding is used on all boundaries. If a tuple, uses (padding_left, padding_right) for 1d data, ((padding_top, padding_bottom), (padding_left, padding_right)) for 2d data, and ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)) for 3d data.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.pooling.AvgPool(kernel_size, strides=None, padding='VALID', count_include_pad=True, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Pool
Pools the input by taking the average over a window.
Examples
from neurai.nn import AvgPool
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = AvgPool(kernel_size=(2, 2))
param = pool.init(input)
out = pool.run(param, input)
- Parameters:
init_val (scalar) – Initial value for the reduction.
func (callable) – Reduction function to use.
kernel_size (Union[int, Sequence[int]]) – Window size for the pooling operation.
strides (Union[None, int, Sequence[int]]) – Strides for the pooling operation.
padding (Union[str, Sequence[Tuple[int, int]]]) – Padding for the pooling operation.
count_include_pad (bool) – Whether to include padded values in the average calculation. Default is True.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
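The pooling semantics can be expressed with jax.lax.reduce_window (an illustrative sketch, not the module's code): sum each window, then divide by the window size.
import jax
import jax.numpy as jnp
x = jnp.ones((1, 28, 28, 3))
summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
avg = summed / 4.0  # kernel_size (2, 2) gives 4 elements per window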
- class neurai.nn.layer.pooling.MaxPool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Pool
Pools the input by taking the maximum over a window.
Examples
from neurai.nn import MaxPool
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = MaxPool(kernel_size=(2, 1))
param = pool.init(input)
out = pool.run(param, input)
- Parameters:
kernel_size (int, sequence of int) – An integer, or a sequence of integers defining the window to reduce over.
strides (int, sequence of int) – An integer, or a sequence of integers, representing the inter-window strides (default: (1, …, 1)).
padding (str, sequence of tuple) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- func(y)#
Elementwise maximum: \(\mathrm{max}(x, y)\)
For complex numbers, uses a lexicographic comparison on the (real, imaginary) pairs.
- class neurai.nn.layer.pooling.MinPool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Pool
Pools the input by taking the minimum over a window.
Examples
from neurai.nn import MinPool
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3))
pool = MinPool(kernel_size=(1, 2))
param = pool.init(input)
out = pool.run(param, input)
- Parameters:
kernel_size (int, sequence of int) – An integer, or a sequence of integers defining the window to reduce over.
strides (int, sequence of int) – An integer, or a sequence of integers, representing the inter-window strides (default: (1, …, 1)).
padding (str, sequence of tuple) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- func(y)#
Elementwise minimum: \(\mathrm{min}(x, y)\)
For complex numbers, uses a lexicographic comparison on the (real, imaginary) pairs.
- class neurai.nn.layer.pooling.Pool(kernel_size, strides=None, padding='VALID', parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Pooling layer for neural networks.
- Parameters:
init_val (scalar) – Initial value for the reduction.
func (callable) – Reduction function to use.
kernel_size (Union[int, Sequence[int]]) – Window size for the pooling operation.
strides (Union[None, int, Sequence[int]]) – Strides for the pooling operation.
padding (Union[str, Sequence[Tuple[int, int]]]) – Padding for the pooling operation.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.pooling.SNNPool(features=1, kernel_size=(2, 2, 1), strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=False, mask=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function SNNPool.<lambda>>, bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Conv3d
SNN pooling layer.
Examples
from neurai.nn import SNNPool
from jax import random
input = random.normal(random.PRNGKey(0), (1, 28, 28, 3, 6))
pool = SNNPool(kernel_size=(2, 2, 1))
param = pool.init(input)
out = pool.run(param, input)
- Parameters:
features (int) – Number of output channels.
kernel_size (int, sequence of int) – The shape of the convolutional kernel.
kernel_col (str) – Specifies whether the weight is learnable. The default is const, indicating that it is not learnable.
kernel_init (Callable, Optional) – The initializer for the convolutional kernel.
strides (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, representing the inter-window strides. Default is 1.
padding (Union[str, Tuple[int, int], Sequence[Tuple[int, int]]], Optional) – Either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. Default is ‘SAME’.
input_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Default is 1. Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation (Union[None, int, Sequence[int]], Optional) – An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Default is 1. Convolution with kernel dilation is also known as ‘atrous convolution’.
feature_group_count (int, Optional) – If specified, divides the input features into feature_group_count. Default is 1.
use_bias (bool, Optional) – Whether to add a bias term to the output. Default is False.
mask (Optional[jnp.ndarray], Optional) – The optional mask of the weights. Default is None.
param_dtype (Any, Optional) – The dtype passed to parameter initializers. Default is jnp.float32.
precision (PrecisionLike, Optional) – The numerical precision of the computation. See jax.lax.Precision for details. Default is None.
bias_init (Callable, Optional) – The initializer for the bias. Default is zero_func.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- bias_init(shape, dtype=<class 'int'>, func=<function <lambda>>)#
Initializes the VarianceScalingIniter.
- Parameters:
scale (float) – the value to scale the variance.
mode (str) – indicates how to calculate the variance scaling factor.
distribution (str) – indicates the type of distribution to use.
in_axis (int) – indicates the input axis for computing the variance scaling factor, by default -2.
out_axis (int) – indicates the output axis for computing the variance scaling factor, by default -1.
- kernel_init(shape, dtype)#
Initializes the VarianceScalingIniter.
- Parameters:
scale (float) – the value to scale the variance.
mode (str) – indicates how to calculate the variance scaling factor.
distribution (str) – indicates the type of distribution to use.
in_axis (int) – indicates the input axis for computing the variance scaling factor, by default -2.
out_axis (int) – indicates the output axis for computing the variance scaling factor, by default -1.
- class neurai.nn.layer.pooling.UpSampleNearest(size=None, scale_factor=None, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Applies a nearest neighbor upsampling to an input signal composed of several input channels.
Examples
from neurai.nn import UpSampleNearest
from jax import random
input = random.normal(random.PRNGKey(0), (1, 12, 1))
pool = UpSampleNearest(size=(1, 12, 1))
param = pool.init(input)
out = pool.run(param, input)
- Parameters:
size (Union[int, Tuple[int, int], Tuple[int, int, int]]) – Output spatial sizes.
scale_factor (Union[int, Tuple[int, int], Tuple[int, int, int]]) – Multiplier for spatial size. Has to match input size if it is a tuple. The dimension of this value does not include batch size and channel.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
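Nearest-neighbor upsampling by an integer factor is just repetition along the spatial axes; a sketch (the module's size/scale_factor handling may differ):
import jax.numpy as jnp
x = jnp.ones((1, 12, 1))
y = jnp.repeat(x, 2, axis=1)  # scale factor 2 on the spatial axis, giving shape (1, 24, 1)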
- class neurai.nn.layer.recurrent.ALIFCell(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
SRNNCellBase
An SRNN cell whose neuron model is SNNALIF.
- Parameters:
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.recurrent.GRUCell(input_size, hidden_size, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
RNNCellBase
GRU cell.
Examples
from neurai.nn import RNN, GRUCell
from jax import random
input = random.normal(random.PRNGKey(0), (28, 512, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128))
param = rnn.init(input)
out = rnn.run(param, input)
- Parameters:
input_size (int) – input size
hidden_size (int) – hidden layer size
kernel_init (Callable) – initializer function for the kernels that transform the input, by default KaimingUniformIniter()
recurrent_kernel_init (Callable) – initializer function for the kernels that transform the hidden state, by default KaimingUniformIniter()
bias_init (Callable) – initializer for the bias parameters
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
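For reference, the standard GRU update equations, sketched in JAX; the parameter layout inside GRUCell is not shown here, and the weight names below are hypothetical.
import jax
import jax.numpy as jnp
def gru_step_sketch(x, h, Wz, Uz, Wr, Ur, Wn, Un):
  z = jax.nn.sigmoid(x @ Wz + h @ Uz)  # update gate
  r = jax.nn.sigmoid(x @ Wr + h @ Ur)  # reset gate
  n = jnp.tanh(x @ Wn + r * (h @ Un))  # candidate hidden state
  return (1.0 - z) * n + z * h  # new hidden state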
- class neurai.nn.layer.recurrent.LIFCell(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
SRNNCellBase
An SRNN cell whose neuron model is SNNLIF.
- Parameters:
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.recurrent.LSTMCell(input_size, hidden_size, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
RNNCellBase
LSTM cell.
Examples
from neurai.nn import RNN, LSTMCell
from jax import random
input = random.normal(random.PRNGKey(0), (28, 512, 28))
lstm = RNN(LSTMCell(input_size=28, hidden_size=128))
param = lstm.init(input)
out = lstm.run(param, input)
- Parameters:
input_size (int) – input size
hidden_size (int) – hidden layer size
kernel_init (Callable) – initializer function for the kernels that transform the input, by default KaimingUniformIniter()
recurrent_kernel_init (Callable) – initializer function for the kernels that transform the hidden state, by default KaimingUniformIniter()
bias_init (Callable) – initializer for the bias parameters
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.recurrent.RNN(cell, batch_first=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
The RNN module takes any RNNCellBase instance and applies it over a sequence.
Examples
from neurai.nn import RNN, GRUCell
from jax import random
input = random.normal(random.PRNGKey(0), (28, 512, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128))
param = rnn.init(input)
out = rnn.run(param, input)
- Parameters:
cell (RNNCellBase) – an instance of RNNCellBase.
batch_first (bool) – Whether the batch is the first dimension, by default False (see the sketch after this parameter list).
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
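The class example above feeds a (time, batch, features) array. A minimal sketch of the batch_first=True case, assuming the flag simply swaps the time and batch axes of the expected layout (that shape interpretation is an assumption, not confirmed by the docstring):
from neurai.nn import RNN, GRUCell
from jax import random
# Assumed layout with batch_first=True: (batch, time, features).
input = random.normal(random.PRNGKey(0), (512, 28, 28))
rnn = RNN(GRUCell(input_size=28, hidden_size=128), batch_first=True)
param = rnn.init(input)
out = rnn.run(param, input)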
- class neurai.nn.layer.recurrent.RNNCellBase(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
RNN cell base class. This class defines the basic functionality that every cell should implement.
- Parameters:
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- class neurai.nn.layer.recurrent.SRNN(neuron=None, dense_w=None, dense_r=None, out_dim=None, time_dim=-1, bias=False, kernel_init=KaimingUniformIniter(key=None), recurrent_kernel_init=KaimingUniformIniter(key=None), bias_init=<function zero_func>, param_dtype=<class 'jax.numpy.float32'>, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
Spiking Recurrent Neural Network (SRNN) module for sequential data processing.
Examples
from neurai.nn import SNNLIF, SRNN
from jax import random
input = random.normal(random.PRNGKey(0), (32, 28, 28))
snnlif = SNNLIF(tau=2.0)
srnn = SRNN(snnlif, out_dim=1024)
param = srnn.init(input)
out = srnn.run(param, input)
- Parameters:
neuron (Callable) – The spiking neuron model used in the SRNN.
dense_w (Callable, Optional) – The dense weights for input transformation.
dense_r (Callable, Optional) – The dense weights for recurrent connections.
out_dim (int, Optional) – The output dimension of the layer, by default None.
time_dim (int, Optional) – The time dimension. Default is -1.
bias (bool, Optional) – Flag to determine if bias is used. Default is False.
kernel_init (Callable, Optional) – The initializer for kernel weights. Default is KaimingUniformIniter(scale=math.sqrt(5), distribution='leaky_relu').
recurrent_kernel_init (Callable, Optional) – The initializer for recurrent kernel weights. Default is KaimingUniformIniter(scale=math.sqrt(5), distribution='leaky_relu').
bias_init (Callable, Optional) – The initializer for bias. Default is zero_func.
param_dtype (Any, Optional) – The data type for parameters. Default is jnp.float32.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- Returns:
(spike, mem) or spike – If the neuron's record_v property is True, returns both the spike output and the membrane potential; otherwise returns only the spike output (see the sketch below).
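A minimal sketch of the record_v case, assuming SNNLIF accepts a record_v flag analogous to the one documented on Slayer below (that keyword on SNNLIF is an assumption, not confirmed here):
from neurai.nn import SNNLIF, SRNN
from jax import random
input = random.normal(random.PRNGKey(0), (32, 28, 28))
# record_v=True on the neuron is assumed to switch run to a (spike, mem) return.
srnn = SRNN(SNNLIF(tau=2.0, record_v=True), out_dim=1024)
param = srnn.init(input)
spike, mem = srnn.run(param, input)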
- class neurai.nn.layer.recurrent.SRNNCellBase(parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
Module
SRNN cell base class. This class defines the basic functionality that every SRNN cell should implement.
- Parameters:
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- initial_state(shape, neuron, dense_w, dense_r, param_dtype)#
Initialize the state of the SRNN cell.
- Parameters:
shape (tuple) – Shape of the cell state.
neuron (Neuron) – The neuron model for the cell.
dense_w (Callable) – Function for dense weights.
dense_r (Callable) – Function for recurrent weights.
param_dtype (Any) – Data type for cell parameters.
- class neurai.nn.layer.slayer.Slayer(size=1, pid=0, V_th=10.0, V_reset=0.0, stdp_tau=20.0, tau_triplet=110.0, dendritic_delay=1, area_name=None, surrogate_grad_fn=<neurai.grads.surrogate_grad.SlayerPdf object>, id=0, neuron_start_id=1, V_rest=0.0, tau_res=10.0, tau_ref=1.0, time_step=1.0, time_window=100, kernel_window=38, record_v=False, computation_type=ComputationType.CUDA, scale_rho=2, tau_rho=1, scale_ref=2, full_ref_kernel=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)#
Bases:
SNNSRM
SLAYER (Spike Layer Error Reassignment) algorithm. The base description of the framework was published at NeurIPS 2018 (https://nips.cc/Conferences/2018/Schedule?showEvent=11157). The final paper is available at http://papers.nips.cc/paper/7415-slayer-spike-layer-error-reassignment-in-time.pdf, and the arXiv preprint at https://arxiv.org/abs/1810.08646.
Examples
from neurai.config import set_platform
from neurai.nn import Slayer
from jax import random, numpy as jnp
set_platform("gpu")
slayer = Slayer(V_th=10.0, time_step=1.0, time_window=100, scale_rho=2, tau_rho=1)
input = random.randint(random.PRNGKey(0), shape=(1, 2, 34, 34, 100), minval=0, maxval=2).astype(jnp.float32)
param = slayer.init(input)
out = slayer.run(param, input)
- Parameters:
V_rest (float) – the resting membrane potential, by default 0.
tau_res (float) – membrane time constant, by default 10.
tau_ref (float) – neuron refractory time constant, by default 1.
time_step (float) – sampling time, by default 1.
time_window (int) – time length of sample, by default 100.
scale_rho (int, float) – spike function derivative scale factor, by default 2.
tau_rho (int, float) – spike function derivative time constant (relative to theta), by default 1.
scale_ref (int, float) – neuron refractory response scaling (relative to theta), by default 2.
full_ref_kernel (bool, Optional) – Use the high-resolution refractory kernel (not intended for practical use), by default False.
record_v (bool) – Whether to record membrane potential values. If True, the layer outputs both spike and membrane potential values; if False, only spike values (see the sketch after this parameter list).
surrogate_grad_fn (Callable) – the function for calculating surrogate gradients, by default surrogate_grad.SlayerPdf().
computation_type (str) – The computation backend; available types are cuda and for_loop, and Slayer currently supports only these two. Defaults to cuda.
kernel_window (int) – The maximum number of elements in the response kernel, computed from tau_res, by default 38.
V_reset (float, Optional) – Reset voltage of this neuron layer. If not None, the voltage is reset to V_reset after firing a spike; if None, V_th is subtracted from the voltage after firing a spike. By default 0.
size (int, Optional) – The number of neurons in the population, by default 1.
stdp_tau (float, Optional) – STDP time constant, by default 20.
tau_triplet (float, Optional) – Time constant of long presynaptic trace, by default 110.
dendritic_delay (Union[int, jax.numpy.ndarray, Initializer, Callable], Optional) – The dendritic delay length, by default 1.
pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.
area_name (str, Optional) – The name of the area to which the current neuron belongs, by default None.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
id (int) – by default 0.
neuron_start_id (int) – by default 1.
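The record_v flag changes what run returns. A minimal sketch, assuming the (spike, mem) return described in the record_v parameter above (the exact return structure is an assumption):
from neurai.config import set_platform
from neurai.nn import Slayer
from jax import random, numpy as jnp
set_platform("gpu")
# record_v=True: run is assumed to output membrane potentials alongside spikes.
slayer = Slayer(V_th=10.0, time_window=100, record_v=True)
input = random.randint(random.PRNGKey(0), shape=(1, 2, 34, 34, 100), minval=0, maxval=2).astype(jnp.float32)
param = slayer.init(input)
spike, mem = slayer.run(param, input)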
- refractory_kernel(key=0, kernel_type='alpha')#
Calculate the refractory kernel. The refractory kernel is given by the following formula:
alpha kernel:
\[\nu(t) = -\theta \frac{t}{\tau} \exp(1 - \frac{t}{\tau})\]
exp kernel:
\[\nu(t) = \frac{\theta}{\tau} \exp(1 - \frac{t}{\tau})\]
- Returns:
list
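For intuition, a standalone sketch evaluating both refractory kernels with plain jax.numpy (theta, tau, and the time grid are illustrative values, not the library's internal defaults):
from jax import numpy as jnp
theta, tau = 10.0, 1.0   # illustrative V_th and tau_ref
t = jnp.arange(0.0, 38.0)   # time grid; 38 echoes the default kernel_window
nu_alpha = -theta * (t / tau) * jnp.exp(1.0 - t / tau)   # alpha refractory kernel
nu_exp = (theta / tau) * jnp.exp(1.0 - t / tau)   # exp refractory kernel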
- response_kernel(key=0, kernel_type='alpha')#
Calculate the PSP kernel. The PSP kernel is given by the following formula:
alpha kernel:
\[\epsilon(t) = \frac{t}{\tau} \exp(1 - \frac{t}{\tau})\]
exp kernel:
\[\epsilon(t) = \frac{1}{\tau} \exp(1 - \frac{t}{\tau})\]
- Returns:
list
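The same kind of standalone sketch for the PSP kernels (tau plays the role of tau_res; values are illustrative):
from jax import numpy as jnp
tau = 10.0   # illustrative tau_res
t = jnp.arange(0.0, 38.0)
eps_alpha = (t / tau) * jnp.exp(1.0 - t / tau)   # alpha PSP kernel
eps_exp = (1.0 / tau) * jnp.exp(1.0 - t / tau)   # exp PSP kernel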
- neurai.nn.layer.slayer.psp_forward_jvp(primals, tangents)#
Compute the Jacobian-vector product (JVP) for the custom JAX primitive psp_forward.
- Parameters:
primals – The primal inputs for the psp_forward primitive.
tangents – The tangents for the JVP computation.
- Returns:
tuple: A tuple containing the JVP for psp_forward and the JVP for the membrane potential.
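For readers unfamiliar with the mechanism, this is the general shape of a custom-JVP rule in JAX. The sketch below is generic: the forward body and tangent rule are stand-ins, not neurai's psp_forward math.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def psp_demo(x):
    # stand-in forward pass; the real primitive filters spikes
    # through the PSP response kernel
    return jnp.tanh(x)

@psp_demo.defjvp
def psp_demo_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = psp_demo(x)
    # tangent rule: d/dx tanh(x) = 1 - tanh(x)^2
    return y, (1.0 - y ** 2) * x_dot

# jax.jvp dispatches to the rule registered above
y, y_dot = jax.jvp(psp_demo, (jnp.ones(3),), (jnp.ones(3),))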