neurai.nn package

Subpackages

Submodules

class neurai.nn.delay_spike.DelaySpike(max_delay_step=0, spike_size=1, init_data=None, init_val=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)

Bases: Module

Maintains a matrix that stores the spikes of every time step and manages spike emission at each time step.

Parameters:
  • max_delay_step (Union[int, jnp.ndarray, np.ndarray], Optional) – Maximum delay, in time steps, of spikes transmitted between the pre- and postsynaptic neuron groups, by default 0.

  • spike_size (int, Optional) – The size of the spike matrix, by default 1.

  • init_data (jnp.ndarray, Optional) – Optional initial data for the spike matrix, by default None.

  • init_val (bool, Optional) – Whether to initialize the matrix using init_data, by default False.

  • pid (int, Optional) – The process id of the neuron, used in multi-process simulations, by default 0.

  • name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.

  • parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.

  • frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.

reset()

Reset the DelaySpike instance to its initial state.

This function sets the data matrix to all zeros and resets current_spiking_index to 0.
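
The docstring does not spell out how the spike matrix is indexed. One common design for this kind of delay store is a ring buffer with `max_delay_step + 1` rows, where each new time step overwrites the oldest row. The pure-Python sketch below illustrates that idea only; `RingDelayBuffer`, `push`, `get`, and `current_index` are hypothetical names, and DelaySpike's actual indexing may differ.

```python
# Hypothetical ring-buffer sketch of a spike delay store (not the neurai API).
class RingDelayBuffer:
    """Keeps the spikes of the last max_delay_step + 1 time steps."""

    def __init__(self, max_delay_step=0, spike_size=1):
        # One extra row so that a delay of 0 (the current step) is also readable
        self.num_rows = max_delay_step + 1
        self.spike_size = spike_size
        self.reset()

    def reset(self):
        # Analogous to DelaySpike.reset(): zero the matrix and rewind the index
        self.data = [[False] * self.spike_size for _ in range(self.num_rows)]
        self.current_index = 0

    def push(self, spikes):
        # Advance the write position (wrapping around) and overwrite the oldest row
        if len(spikes) != self.spike_size:
            raise ValueError("spikes must have length spike_size")
        self.current_index = (self.current_index + 1) % self.num_rows
        self.data[self.current_index] = list(spikes)

    def get(self, delay):
        # Return the spikes that were pushed `delay` steps ago
        if not 0 <= delay < self.num_rows:
            raise ValueError("delay must be in [0, max_delay_step]")
        return self.data[(self.current_index - delay) % self.num_rows]


buf = RingDelayBuffer(max_delay_step=2, spike_size=3)
buf.push([True, False, False])
buf.push([False, True, False])
print(buf.get(0))  # [False, True, False]
print(buf.get(1))  # [True, False, False]
```

Because each step overwrites a single row, memory stays at (max_delay_step + 1) × spike_size entries no matter how long the simulation runs.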

neurai.nn.functional.flatten(input, start_dim=1, end_dim=-1)

Flattens a jnp.ndarray over the dimensions from start_dim to end_dim.

Parameters:
  • input (jnp.ndarray) – The input data.

  • start_dim (int, Optional) – The first dimension to flatten, by default 1.

  • end_dim (int, Optional) – The ending dimension to flatten, by default -1. If -1, all dimensions starting from start_dim are flattened.

Return type:

Array

Returns:

jnp.ndarray – The flattened data.
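
The effect of start_dim and end_dim on the output shape can be worked out without any array library. The helper below is a hypothetical illustration (`flattened_shape` is not a neurai function). It assumes that end_dim is inclusive, as the end_dim=-1 description implies, and that negative indices count from the last dimension, as in NumPy.

```python
from math import prod


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: output shape after merging dims start_dim..end_dim (inclusive)."""
    ndim = len(shape)
    for d in (start_dim, end_dim):
        if not -ndim <= d < ndim:
            raise ValueError(f"dimension {d} is out of range for a {ndim}-d input")
    # Normalize negative indices, e.g. -1 -> ndim - 1
    start, end = start_dim % ndim, end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = prod(shape[start:end + 1])  # the merged dimensions collapse into one
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 3, 4, 5)))        # (2, 60): first dimension kept
print(flattened_shape((2, 3, 4, 5), 1, 2))  # (2, 12, 5)
print(flattened_shape((2, 3, 4, 5), 0))     # (120,)
```

With the defaults, every dimension except the first (typically the batch dimension) is merged, which is the usual way to pass convolutional features to a dense layer.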

neurai.nn.functional.interpolate(input, size=None, scale_factor=None, mode='bilinear', align_corners=False)

Interpolate the input tensor to a specified size or scale.

Parameters:
  • input (jnp.ndarray) – The input tensor to be interpolated.

  • size (tuple or int, Optional) – The target size for interpolation. If provided as a single integer, it is interpreted as (size, size). Either size or scale_factor must be provided.

  • scale_factor (float or tuple, Optional) – The factor by which the input should be scaled. If provided as a single float, it’s interpreted as (scale_factor, scale_factor). Either size or scale_factor must be provided.

  • mode (str, Optional) – The interpolation mode, by default 'bilinear'. Supported modes are nearest, linear, bilinear, bicubic, and trilinear.

  • align_corners (bool, Optional) – Whether to align the corners of the input and output grids when using bilinear interpolation, by default False. Only applicable when mode is bilinear.

Returns:

jnp.ndarray – The interpolated tensor.

Raises:

ValueError – If neither size nor scale_factor is provided, or if an unsupported interpolation mode is specified.

Examples

from jax import numpy as jnp
from neurai.nn.functional import interpolate

input = jnp.array([[[[1.0, 2.0], [3.0, 4.0]]]])
output = interpolate(input, size=(4, 4), mode='bilinear', align_corners=False)
print(output)

Note

  • When mode is nearest, the function uses nearest-neighbor interpolation.

  • When mode is linear, bilinear, bicubic, or trilinear, the function uses linear interpolation.

  • align_corners is relevant only for bilinear interpolation; it specifies whether the corner points of the input and output grids are aligned. Setting it to True preserves the values at the grid corners, which can give more accurate results when corner alignment matters, but it is not suitable for every use case.
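
The nearest mode can be illustrated on the 2 × 2 example input. The pure-Python sketch below resizes a single 2-D channel using the common source-index rule floor(i · in_size / out_size); `resize_nearest_2d` is a hypothetical name, and the exact rounding convention neurai uses for nearest-neighbor sampling is an assumption here.

```python
def resize_nearest_2d(image, out_h, out_w):
    """Hypothetical nearest-neighbor resize of one 2-D channel given as nested lists."""
    in_h, in_w = len(image), len(image[0])
    # Each output pixel (i, j) copies the input pixel at the scaled-down index
    return [
        [image[(i * in_h) // out_h][(j * in_w) // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]


image = [[1.0, 2.0], [3.0, 4.0]]
for row in resize_nearest_2d(image, 4, 4):
    print(row)
# [1.0, 1.0, 2.0, 2.0]
# [1.0, 1.0, 2.0, 2.0]
# [3.0, 3.0, 4.0, 4.0]
# [3.0, 3.0, 4.0, 4.0]
```

Every output pixel copies exactly one input value, so no new values appear; the linear modes instead blend neighboring values.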

class neurai.nn.loss.SpikeLoss(model=None, start_time=0, end_time=100, positive=60, negative=10, time_step=1.0, time_windows=100, psp_fn=None, **psp_fn_args)

Bases: object

This class defines several spike-based loss functions that can be used to optimize a spiking neural network (SNN).

Parameters:
  • model (Callable, Optional) – The neural network model, by default None.

  • start_time (int, Optional) – Start index (startID) of the target region, by default 0.

  • end_time (int, Optional) – Stop index (stopID) of the target region, by default 100.

  • positive (int, Optional) – Desired spike count within the target region for the true class, by default 60.

  • negative (int, Optional) – Desired spike count within the target region for the false classes, by default 10.

  • time_step (float, Optional) – Sampling time step, by default 1.0.

  • time_windows (int, Optional) – Time length of each sample, by default 100.

  • psp_fn (Callable, Optional) – Function that calculates the error from the difference between the actual spike activity (spikeOut) and the desired spike activity (spikeDesired), by default None.

  • psp_fn_args – Additional keyword arguments passed to psp_fn.

numSpikes(predict, target, **kwargs)

Calculates the spike loss based on the number of spikes within a target region. For classification tasks, a decision is typically made from the number of output spikes during an interval rather than from the precise timing of the spikes. To handle such cases, the error signal during the interval can be defined as:

\[e^{(n_l)}(t) := \int_{T_{int}} S^{(n_l)}(\tau)\,d\tau - \int_{T_{int}} \hat{S}(\tau)\,d\tau, \quad t \in T_{int}\]

and zero outside the interval \(T_{int}\).

Parameters:
  • predict (jnp.ndarray) – The output spike train of the network.

  • target (jnp.ndarray) – One-hot encoded desired class. The time dimension should have size 1, and the remaining dimensions should match those of predict.

  • kwargs (Any) – Any additional keyword arguments, such as numSpikesScale.
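
For a single output neuron, the interval error above reduces to a count difference that is held constant over \(T_{int}\). The sketch below follows the formula literally with 0/1 spike samples. `num_spikes_error` is a hypothetical helper, end_time is assumed to be exclusive, and any numSpikesScale factor or psp_fn processing that SpikeLoss may apply is not modeled.

```python
def num_spikes_error(spike_train, is_target_class, start_time=0, end_time=100,
                     positive=60, negative=10):
    """Hypothetical per-neuron version of the interval error e(t)."""
    # With 0/1 samples, the integral of S(t) over T_int is the spike count in [start_time, end_time)
    actual = sum(spike_train[start_time:end_time])
    # The integral of the desired train is the desired count for this class
    desired = positive if is_target_class else negative
    diff = float(actual - desired)
    # e(t) equals the count difference inside T_int and is zero outside it
    return [diff if start_time <= t < end_time else 0.0 for t in range(len(spike_train))]


# 10 time steps, target window [2, 8) holds 4 spikes, desired count 3
err = num_spikes_error([1, 0, 1, 1, 0, 1, 1, 0, 0, 1], True,
                       start_time=2, end_time=8, positive=3, negative=1)
print(err)  # [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

A positive error means the neuron fired too often inside the window; spikes outside the window do not contribute.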

spikeTime(ps, batch_data)

Calculates the spike loss based on spike timing. Consider a loss function for the network over the time interval t ∈ [0, T], defined as:

\[E:= \int_0^T L(S^{(n_l)}(t), \hat{S}(t)) d{t} = \frac{1}{2}\int_0^T (e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t)))^2 d{t}\]

where \(\hat{S}(t)\) is the target spike train, \(L(S^{(n_l)}(t), \hat{S}(t))\) is the loss at time instance \(t\), and \(e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t))\) is the error signal at the final layer. For brevity, the error signal is written as \(e^{(n_l)}(t)\) from here on.

To learn a target spike train \(\hat{S}(t)\), an error signal of the following form is used:

\[e^{(n_l)}(t) := \varepsilon(t) * (S^{(n_l)}(t) - \hat{S}(t))\]

where \(*\) denotes convolution. The resulting loss is similar to the van Rossum distance between the output and desired spike trains. The kernel \(\varepsilon(t)\) is

\[\varepsilon(t) = \frac{t}{\tau_s} e^{1 - \frac{t}{\tau_s}} \Theta(t)\]

where \(\Theta(t)\) is the Heaviside step function and \(\tau_s\) is the kernel time constant.

Parameters:
  • ps (list) – The current values of the network parameters.

  • batch_data (tuple) – The input, output, and label data of a mini-batch.
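
The error signal and loss above can be evaluated numerically on short spike trains. The sketch below samples \(\varepsilon(t)\) every dt, convolves it causally with \(S^{(n_l)}(t) - \hat{S}(t)\), and approximates the integral by a sum. `alpha_kernel` and `spike_time_loss` are hypothetical names. Treating each 0/1 sample as a unit impulse and choosing tau_s = 2.0 are assumptions of the sketch, not documented SpikeLoss behavior.

```python
import math


def alpha_kernel(tau_s, length, dt=1.0):
    # epsilon(t) = (t / tau_s) * exp(1 - t / tau_s) for t >= 0 (Theta(t) = 1), sampled every dt
    return [(k * dt / tau_s) * math.exp(1 - k * dt / tau_s) for k in range(length)]


def spike_time_loss(actual, target, tau_s=2.0, dt=1.0):
    # actual/target: 0/1 spike trains; each 1 is treated as a unit impulse
    n = len(actual)
    kernel = alpha_kernel(tau_s, n, dt)
    diff = [a - b for a, b in zip(actual, target)]
    # Causal discrete convolution: e[t] = sum_{k <= t} kernel[k] * diff[t - k]
    error = [sum(kernel[k] * diff[t - k] for k in range(t + 1)) for t in range(n)]
    # E = 1/2 * integral of e(t)^2 dt, approximated by a Riemann sum
    return 0.5 * sum(e * e for e in error) * dt


print(spike_time_loss([0, 1, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0]))      # 0.0
print(spike_time_loss([0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]) > 0)  # True
```

In this short example, moving the target spike one step further from the output spike increases the loss, which is the distance-like behavior the van Rossum comparison refers to.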

neurai.nn.loss.binary_cross_entropy(pred, label, weight=None, reduction='mean')

Calculates the binary cross-entropy loss between predicted probabilities and true labels.

Parameters:
  • pred (jnp.ndarray) – The predicted probabilities, between 0 and 1.

  • label (jnp.ndarray) – The true labels, between 0 and 1.

  • weight (jnp.ndarray, Optional) – Manual rescaling weight; it must match the shape of pred.

  • reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> pred = jnp.array([0.5, 0.6, 0.7, 0.8, 0.9])
>>> labels = jnp.array([0, 1, 0, 1, 0])
>>> binary_cross_entropy(pred, labels, reduction='mean')
array(0.9867)
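
The documented output 0.9867 is what the textbook formula \(-[y \log p + (1 - y)\log(1 - p)]\), averaged over the elements, gives for this input. The pure-Python reference below reproduces it; `bce_reference` is a hypothetical name, and the epsilon clipping that guards against log(0) is an assumption rather than documented neurai behavior.

```python
import math


def bce_reference(pred, label, weight=None, reduction="mean", eps=1e-12):
    """Hypothetical reference for binary cross-entropy on probabilities."""
    losses = []
    for i, (p, y) in enumerate(zip(pred, label)):
        p = min(max(p, eps), 1.0 - eps)  # clip so that log() never sees 0
        loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        # Optional per-element rescaling, applied before the reduction
        losses.append(loss * (weight[i] if weight is not None else 1.0))
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError("reduction must be 'mean', 'sum', or 'none'")


print(round(bce_reference([0.5, 0.6, 0.7, 0.8, 0.9], [0, 1, 0, 1, 0]), 4))  # 0.9867
```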

neurai.nn.loss.hinge_loss(pred, label)

Calculates the hinge loss between predicted and true labels.

Parameters:
  • pred (jnp.ndarray) – The predicted scores as a numpy array.

  • label (jnp.ndarray) – The true labels (-1 or 1) as a numpy array.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> hinge_loss(pred, label)
array(0.26666667)
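
The example value follows from averaging \(\max(0, 1 - y \cdot \hat{y})\) over the three elements: the terms 0.2, 0.6, and 0 sum to 0.8, and 0.8 / 3 ≈ 0.2667. A minimal pure-Python reference, assuming labels in {-1, 1} and a mean reduction (`hinge_reference` is a hypothetical name), is:

```python
def hinge_reference(pred, label):
    """Hypothetical reference: mean of max(0, 1 - y * score)."""
    # Correct predictions with margin >= 1 contribute zero loss
    return sum(max(0.0, 1.0 - y * p) for p, y in zip(pred, label)) / len(pred)


print(round(hinge_reference([0.8, -0.4, 1.2], [1, -1, 1]), 6))  # 0.266667
```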

neurai.nn.loss.huber_loss(logits, labels, reduction='mean', delta=1.0)

Calculates the huber loss between predicted and true labels.

Parameters:
  • logits (jnp.ndarray) – The predicted labels as a numpy array.

  • labels (jnp.ndarray) – The true labels as a numpy array.

  • reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.

  • delta (float, Optional) – The threshold parameter for huber loss. Default is 1.0.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> huber_loss(pred, label)
Array(0.07333334, dtype=float32)
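
With delta = 1.0, all three residuals in the example (0.2, 0.6, 0.2) fall in the quadratic region, so the loss is (0.02 + 0.18 + 0.02) / 3 ≈ 0.0733. The sketch below assumes the standard Huber definition and reproduces only the default reduction='mean'; `huber_reference` is a hypothetical name.

```python
def huber_reference(pred, label, delta=1.0):
    """Hypothetical reference for the Huber loss with a mean reduction."""
    losses = []
    for p, y in zip(pred, label):
        r = abs(p - y)
        # Quadratic for small residuals, linear with slope delta for large ones
        losses.append(0.5 * r * r if r <= delta else delta * (r - 0.5 * delta))
    return sum(losses) / len(losses)


print(round(huber_reference([0.8, -0.4, 1.2], [1, -1, 1]), 6))  # 0.073333
```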

neurai.nn.loss.l1_loss(logits, labels, reduction='mean')

Calculates the L1 loss between predicted logits and true labels.

Parameters:
  • logits (jnp.ndarray) – The predicted logits as a numpy array.

  • labels (jnp.ndarray) – The true labels as a numpy array.

  • reduction (str, Optional) – The reduction type for the loss calculation. Possible values are ‘mean’, ‘sum’, or ‘none’. Default is ‘mean’.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> logits = np.array([0.8, -0.4, 1.2])
>>> labels = np.array([1, -1, 1])
>>> l1_loss(logits, labels)
array(0.33333334)
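
The example averages the absolute residuals 0.2, 0.6, and 0.2, giving 1.0 / 3 ≈ 0.3333. A pure-Python equivalent of the default reduction='mean' (hypothetical `l1_reference`) is:

```python
def l1_reference(logits, labels):
    """Hypothetical reference: mean absolute error."""
    return sum(abs(p - y) for p, y in zip(logits, labels)) / len(logits)


print(round(l1_reference([0.8, -0.4, 1.2], [1, -1, 1]), 6))  # 0.333333
```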

neurai.nn.loss.mse_loss(pred, label)

Calculates the mean squared error (MSE) loss between predicted and true labels.

Parameters:
  • pred (jnp.ndarray) – The predicted labels as a numpy array.

  • label (jnp.ndarray) – The true labels as a numpy array.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> pred = np.array([1, 2, 3])
>>> label = np.array([2, 4, 6])
>>> mse_loss(pred, label)
array(4.66666667)
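
Here the residuals are 1, 2, and 3, so the mean of their squares is 14 / 3 ≈ 4.6667. A pure-Python equivalent, assuming a mean over all elements (`mse_reference` is a hypothetical name), is:

```python
def mse_reference(pred, label):
    """Hypothetical reference: mean of squared residuals."""
    return sum((p - y) ** 2 for p, y in zip(pred, label)) / len(pred)


print(round(mse_reference([1, 2, 3], [2, 4, 6]), 6))  # 4.666667
```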

neurai.nn.loss.sigmoid_binary_cross_entropy(pred, label)

Calculates the sigmoid binary cross-entropy loss between predicted logits and true labels.

Parameters:
  • pred (jnp.ndarray) – The predicted logits (values before the sigmoid) as a numpy array.

  • label (jnp.ndarray) – The true labels as a numpy array.

Returns:

jnp.ndarray – The calculated loss value.
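
This function has no example. Assuming pred holds logits, as the name suggests, a numerically stable way to evaluate \(-[y \log \sigma(x) + (1 - y)\log(1 - \sigma(x))]\) is \(\max(x, 0) - x y + \log(1 + e^{-|x|})\), which avoids overflow for large |x|. The sketch below (hypothetical `sigmoid_bce_reference`, averaged over elements) uses that form; whether neurai averages, sums, or uses the same formulation is not documented here.

```python
import math


def sigmoid_bce_reference(logits, labels):
    """Hypothetical reference for sigmoid binary cross-entropy on logits."""
    # Stable rewrite of the loss: max(x, 0) - x*y + log(1 + exp(-|x|))
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
              for x, y in zip(logits, labels)]
    return sum(losses) / len(losses)


print(sigmoid_bce_reference([1000.0], [0]))  # 1000.0
```

For x = 1000 and y = 0 the naive form would need log(1 - σ(1000)), which is log(0) in floating point, while the stable form returns 1000.0 directly.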

neurai.nn.loss.smooth_l1_loss(logits, labels, reduction='mean', beta=1.0)

Calculates the smooth L1 loss between predicted logits and true labels.

Parameters:
  • logits (jnp.ndarray) – The predicted logits as a JAX NumPy array.

  • labels (jnp.ndarray) – The true labels as a JAX NumPy array.

  • reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.

  • beta (float, Optional) – The threshold parameter for smooth L1 loss. Default is 1.0.

Returns:

jnp.ndarray – The calculated loss value.

Examples

>>> logits = jnp.array([0.8, -0.4, 1.2])
>>> labels = jnp.array([1, -1, 1])
>>> smooth_l1_loss(logits, labels)
array(0.07333334)
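
With beta = 1.0, every residual in the example is below beta, so each term is \(0.5 r^2 / \beta\) and the mean is again 0.22 / 3 ≈ 0.0733. The hypothetical `smooth_l1_reference` below assumes the standard definition (quadratic below beta, \(|r| - 0.5\beta\) above it) and the default mean reduction.

```python
def smooth_l1_reference(logits, labels, beta=1.0):
    """Hypothetical reference for the smooth L1 loss with a mean reduction."""
    losses = []
    for p, y in zip(logits, labels):
        r = abs(p - y)
        # Quadratic (scaled by 1/beta) below beta, shifted absolute error above it
        losses.append(0.5 * r * r / beta if r < beta else r - 0.5 * beta)
    return sum(losses) / len(losses)


print(round(smooth_l1_reference([0.8, -0.4, 1.2], [1, -1, 1]), 6))  # 0.073333
```

Under these standard definitions, smooth L1 with parameter beta equals huber_loss with delta = beta divided by beta, so the two coincide when beta = delta = 1, which is why both documented examples print 0.0733.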

neurai.nn.loss.softmax_cross_entropy(pred, label)

Calculates the softmax cross-entropy loss between predicted logits and one-hot true labels.

Parameters:
  • pred (jnp.ndarray) – The predicted logits (unnormalized scores) as a numpy array.

  • label (jnp.ndarray) – The one-hot encoded true labels as a numpy array.

Returns:

jnp.ndarray – The calculated loss value.

Examples

from jax import numpy as jnp
from neurai.nn.loss import softmax_cross_entropy

pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
softmax_cross_entropy(pred, label)

Expected output: Array(1.0599941, dtype=float32)
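
The documented value 1.0599941 can be reproduced by hand, which indicates that pred is treated as unnormalized logits and that the per-row losses are averaged over the batch. The pure-Python reference below makes that computation explicit; `softmax_ce_reference` is a hypothetical name used for illustration, not part of neurai.

```python
import math


def softmax_ce_reference(logits, one_hot):
    """Hypothetical reference: mean over rows of -sum_c y_c * log softmax(z)_c."""
    losses = []
    for row, target in zip(logits, one_hot):
        m = max(row)
        # log-sum-exp with the row maximum subtracted for numerical stability
        lse = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(-sum(y * (z - lse) for z, y in zip(row, target)))
    return sum(losses) / len(losses)


pred = [[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]]
label = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
print(round(softmax_ce_reference(pred, label), 5))  # 1.05999
```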

neurai.nn.loss.sparse_softmax_cross_entropy(pred, label)

Calculates the softmax cross-entropy loss between predicted logits and integer class labels, without requiring one-hot labels.

Parameters:
  • pred (jnp.ndarray) – The predicted logits (unnormalized scores) as a numpy array.

  • label (jnp.ndarray) – The true class indices (integers, not one-hot encoded) as a numpy array.

Returns:

jnp.ndarray – The calculated loss value.

Examples

from jax import numpy as jnp
from neurai.nn.loss import sparse_softmax_cross_entropy

pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([0, 1, 2])  # integer class indices, not one-hot
sparse_softmax_cross_entropy(pred, label)

Expected output: Array(1.0599941, dtype=float32)
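
With integer labels, only the log-probability of the true class is needed per row, so the loss is \(\operatorname{logsumexp}(z) - z_c\) averaged over the batch. The hypothetical `sparse_softmax_ce_reference` below assumes that mean reduction. With class indices [0, 1, 2] it gives the same value as the one-hot softmax_cross_entropy example, because those one-hot rows encode exactly these indices.

```python
import math


def sparse_softmax_ce_reference(logits, class_ids):
    """Hypothetical reference: mean over rows of logsumexp(z) - z[c]."""
    losses = []
    for row, c in zip(logits, class_ids):
        m = max(row)
        lse = m + math.log(sum(math.exp(z - m) for z in row))
        # -log softmax(z)_c only needs the logit of the true class c
        losses.append(lse - row[c])
    return sum(losses) / len(losses)


pred = [[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]]
print(round(sparse_softmax_ce_reference(pred, [0, 1, 2]), 5))  # 1.05999
```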

class neurai.nn.module.SetupState(value)

Bases: IntEnum

An enumeration.

Module contents