neurai.nn package
Subpackages
- neurai.nn.conn package
- neurai.nn.layer package
- Submodules
Acitivate
Celu
Elu
Gelu
Glu
HardSigmoid
HardSilu
HardSwish
HardTanh
LeakyRelu
LogSigmoid
LogSoftmax
LogSumexp
Mish
Relu
Relu6
Selu
Sigmoid
Silu
SoftSign
Softmax
Softmax2D
Softmin
Softplus
Tanh
Embed
canonicalize_dtype()
masked_softmax()
promote_dtype()
sequence_mask()
transpose_output()
transpose_qkv()
Conv
Conv1d
Conv2d
Conv3d
ConvTranspose
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
SNNConv2d
SNNConv3d
SNNConvTranspose3d
canonicalize_padding()
maybe_replicate()
Dropout
SNNDropout3d
Exodus
exodus_cpu_bwd()
exodus_cpu_fwd()
Flatten
Linear
SNNLinear3d
Sequential
BatchNorm
BatchNorm1d
BatchNorm2d
BatchNorm3d
LayerNorm
TdBatchNorm
TdLayer
WeightNorm
ConstantPad
ReflectionPad
ReplicationPad
ZeroPad
AvgPool
MaxPool
MinPool
Pool
SNNPool
UpSampleNearest
ALIFCell
GRUCell
LIFCell
LSTMCell
RNN
RNNCellBase
SRNN
SRNNCellBase
Slayer
psp_forward_jvp()
- Module contents
- Submodules
- neurai.nn.neuron package
- neurai.nn.rlayer package
- neurai.nn.synapse package
Submodules
- class neurai.nn.delay_spike.DelaySpike(max_delay_step=0, spike_size=1, init_data=None, init_val=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)
Bases: Module
Maintains a matrix holding the spikes of every time step and manages spike emission at each time step.
- Parameters:
max_delay_step (Union[int, jnp.ndarray, np.ndarray], Optional) – The maximum delay step for spikes transmitted between the pre- and postsynaptic neuron groups, by default 0.
spike_size (int, Optional) – The size of the spike matrix, by default 1.
init_data (jnp.ndarray, Optional) – Optional initial data for the spike matrix, by default None.
init_val (bool, Optional) – Whether to initialize the matrix using init_data, by default False.
pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- reset()
Reset the DelaySpike instance to its initial state.
This sets the data matrix to all zeros and current_spiking_index to 0.
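Bookkeeping of this kind (spikes retained for up to max_delay_step steps) is commonly implemented as a ring buffer with max_delay_step + 1 rows. The sketch below assumes that design; the class RingDelayBuffer and its push/delayed methods are hypothetical illustrations, not the neurai DelaySpike API. Only the reset behaviour mirrors what is documented here.

```python
import jax.numpy as jnp


class RingDelayBuffer:
    """Hypothetical ring buffer of spikes (not the neurai DelaySpike API)."""

    def __init__(self, max_delay_step=0, spike_size=1):
        self.max_delay_step = max_delay_step
        # One row per retained time step: delays 0..max_delay_step.
        self.data = jnp.zeros((max_delay_step + 1, spike_size), dtype=bool)
        self.current_spiking_index = 0

    def push(self, spikes):
        """Store the spikes of the current step, overwriting the oldest row."""
        self.current_spiking_index = (self.current_spiking_index + 1) % self.data.shape[0]
        self.data = self.data.at[self.current_spiking_index].set(spikes)

    def delayed(self, delay):
        """Return the spikes emitted `delay` steps ago (0 means the latest step)."""
        if not 0 <= delay <= self.max_delay_step:
            raise ValueError("delay must lie in [0, max_delay_step]")
        return self.data[(self.current_spiking_index - delay) % self.data.shape[0]]

    def reset(self):
        """Zero the spike matrix and rewind the index, like the documented reset()."""
        self.data = jnp.zeros_like(self.data)
        self.current_spiking_index = 0
```

Because rows are overwritten in place, memory stays fixed at (max_delay_step + 1) × spike_size regardless of simulation length.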
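- neurai.nn.functional.flatten(input, start_dim=1, end_dim=-1)
Flattens a jnp.ndarray along the specified range of dimensions.
- Parameters:
input (jnp.ndarray) – The input array.
start_dim (int, Optional) – The first dimension to flatten, by default 1.
end_dim (int, Optional) – The last dimension to flatten, by default -1.
- Return type:
jnp.ndarray
- Returns: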
jnp.ndarray – The flattened data.
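A minimal sketch of the flattening rule in plain JAX, assuming that end_dim is inclusive and that negative indices count from the end, as in PyTorch's torch.flatten. The function name flatten_dims is hypothetical; this is not the neurai implementation.

```python
import math

import jax.numpy as jnp


def flatten_dims(x, start_dim=1, end_dim=-1):
    """Merge dimensions start_dim..end_dim (inclusive) of x into one dimension."""
    start = start_dim % x.ndim  # normalize negative indices
    end = end_dim % x.ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = math.prod(x.shape[start:end + 1])
    return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])
```

With the defaults, a batch of shape (N, C, H, W) becomes (N, C·H·W), which is the usual input layout for a Linear layer.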
- neurai.nn.functional.interpolate(input, size=None, scale_factor=None, mode='bilinear', align_corners=False)
Interpolate the input tensor to a specified size or scale.
- Parameters:
input (jnp.ndarray) – The input tensor to be interpolated.
size (tuple or int, Optional) – The target size for interpolation. If provided as a single integer, it’s interpreted as (size, size). Either ‘size’ or ‘scale_factor’ must be provided.
scale_factor (float or tuple, Optional) – The factor by which the input should be scaled. If provided as a single float, it’s interpreted as (scale_factor, scale_factor). Either size or scale_factor must be provided.
mode (str, Optional) – The interpolation mode. Supported modes are nearest, linear, bilinear, bicubic, and trilinear. By default bilinear.
align_corners (bool, Optional) – Whether to align the corners of the input and output when using bilinear interpolation. Only applicable when mode is bilinear. By default False.
- Returns:
jnp.ndarray – The interpolated tensor.
- Raises:
ValueError – If neither size nor scale_factor is provided, or if an unsupported interpolation mode is specified.
Examples
from jax import numpy as jnp
from neurai.nn.functional import interpolate
input = jnp.array([[[[1.0, 2.0], [3.0, 4.0]]]])  # shape (1, 1, 2, 2)
output = interpolate(input, size=(4, 4), mode='bilinear', align_corners=False)
print(output)
Note
When mode is nearest, the function uses nearest-neighbor interpolation. When mode is linear, bilinear, or trilinear, it uses linear interpolation along each interpolated axis; bicubic uses cubic interpolation. align_corners is relevant only for bilinear interpolation. When True, the centers of the corner pixels of the input and output grids are aligned, so the corner values are preserved exactly; when False, pixels are treated as unit areas and sampled at their centers (the half-pixel convention). Neither setting is universally better; choose the one that matches the convention used elsewhere in your pipeline.
- class neurai.nn.loss.SpikeLoss(model=None, start_time=0, end_time=100, positive=60, negative=10, time_step=1.0, time_windows=100, psp_fn=None, **psp_fn_args)
Bases: object
This class provides several spike-based loss functions that can be used to train spiking neural networks (SNNs).
- Parameters:
model (Callable) – The neural network model.
start_time (int, Optional) – Start of the target region (startID), by default 0.
end_time (int, Optional) – End of the target region (stopID), by default 100.
positive (int, Optional) – The desired spike count within the target region for the true class, by default 60.
negative (int, Optional) – The desired spike count within the target region for the false classes, by default 10.
time_windows (int, Optional) – The time length of a sample, by default 100.
psp_fn (Callable) – Function that calculates the error from the difference between the actual spike activity (spikeOut) and the desired spike activity (spikeDesired).
psp_fn_args – Additional keyword arguments passed to psp_fn.
- numSpikes(predict, target, **kwargs)
Calculates the spike loss based on the number of spikes within a target region. For classification tasks, a decision is typically made based on the number of output spikes during an interval rather than on the precise timing of the spikes. To handle such cases, the error signal during the interval can be defined as:
\[e^{(n_l)}(t) := \int_{T_{int}} S^{(n_l)}(\tau)\,d\tau - \int_{T_{int}} \hat{S}(\tau)\,d\tau, \quad t \in T_{int}\] and zero outside the interval \(T_{int}\).
- Parameters:
predict (jnp.ndarray) – The output spike tensor.
target (jnp.ndarray) – The one-hot encoded desired class. Its time dimension should be 1, and the remaining dimensions should match those of predict.
kwargs (Any) – Any additional keyword arguments, such as numSpikesScale.
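A simplified sketch of the count-based error in JAX. It assumes time is the last axis of predict, takes the target region as the half-open window [start_time, end_time), and reduces the error with 0.5·Σe², omitting any normalization the library may apply. The name num_spikes_loss is hypothetical; neurai's actual axis layout and scaling may differ.

```python
import jax.numpy as jnp


def num_spikes_loss(predict, target, start_time=0, end_time=100, positive=60, negative=10):
    """Hypothetical count-based spike loss.

    predict: spikes with time on the LAST axis (assumption).
    target:  one-hot class labels with a time axis of length 1.
    """
    # Integral of S over T_int: the spike count inside the target window.
    actual = jnp.sum(predict[..., start_time:end_time], axis=-1, keepdims=True)
    # Integral of S_hat over T_int: `positive` for the true class, `negative` otherwise.
    desired = jnp.where(target == 1, positive, negative)
    error = actual - desired
    return 0.5 * jnp.sum(error ** 2)
```

Only the spike count inside the window matters here, so moving a spike within the window leaves the loss unchanged.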
- spikeTime(ps, batch_data)
Calculates the spike loss based on spike timing. Consider a loss function for the network over the time interval \(t \in [0, T]\), defined as:
\[E := \int_0^T L(S^{(n_l)}(t), \hat{S}(t))\,dt = \frac{1}{2}\int_0^T \left(e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t))\right)^2 dt\]where \(\hat{S}(t)\) is the target spike train, \(L(S^{(n_l)}(t), \hat{S}(t))\) is the loss at time instant \(t\), and \(e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t))\) is the error signal at the final layer. For brevity, the error signal is written as \(e^{(n_l)}(t)\) from here on.
To learn a target spike train \(\hat{S}(t)\), an error signal of the following form is used:
\[e^{(n_l)}(t) := \varepsilon(t) * (S^{(n_l)}(t) - \hat{S}(t))\]where \(*\) denotes temporal convolution and \(\varepsilon(t)\) is the kernel
\[\varepsilon(t) = \frac{t}{\tau_s} e^{1 - \frac{t}{\tau_s}} \Theta(t)\]with \(\Theta(t)\) the Heaviside step function. The resulting loss is similar to the van Rossum distance between the output and desired spike trains.
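A 1-D sketch of the spike-time loss: the difference between the output and target spike trains is causally convolved with the sampled kernel ε(t), and the integral for E is approximated by a Riemann sum. The value tau_s=5.0, the step dt, and the helper names srm_kernel and spike_time_loss are assumptions for illustration, not neurai parameters.

```python
import jax.numpy as jnp


def srm_kernel(tau_s, length, dt=1.0):
    """Sample eps(t) = (t / tau_s) * exp(1 - t / tau_s) * Theta(t) at t = 0, dt, 2*dt, ..."""
    t = jnp.arange(length) * dt  # all t >= 0, so Theta(t) = 1
    return (t / tau_s) * jnp.exp(1.0 - t / tau_s)


def spike_time_loss(spike_out, spike_desired, tau_s=5.0, dt=1.0):
    """E ~ 0.5 * sum_t e(t)^2 * dt with e = eps * (S - S_hat), for 1-D spike trains."""
    diff = spike_out - spike_desired
    kernel = srm_kernel(tau_s, diff.shape[0], dt)
    # Full convolution truncated to the first T samples gives causal filtering.
    error = jnp.convolve(diff, kernel, mode="full")[: diff.shape[0]]
    return 0.5 * jnp.sum(error ** 2) * dt
```

Because ε(t) spreads each spike over several time steps, a spike that is slightly late is penalized less than one that is far from its target. This graded penalty is what makes the loss resemble the van Rossum distance.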
- neurai.nn.loss.binary_cross_entropy(pred, label, weight=None, reduction='mean')
Calculates the binary cross-entropy loss between predicted probabilities and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted probabilities, with values between 0 and 1.
label (jnp.ndarray) – The true labels, with values between 0 and 1.
weight (jnp.ndarray, Optional) – Manual rescaling weight; it should match the shape of pred.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> from jax import numpy as jnp
>>> from neurai.nn.loss import binary_cross_entropy
>>> probs = jnp.array([0.5, 0.6, 0.7, 0.8, 0.9])
>>> labels = jnp.array([0, 1, 0, 1, 0])
>>> binary_cross_entropy(probs, labels, reduction='mean')
array(0.9867)
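A sketch of element-wise binary cross-entropy, -(y·log p + (1 - y)·log(1 - p)), with the documented weight and reduction options. The eps clipping is an assumption added to keep the logarithms finite, and the library may treat boundary probabilities differently. The name bce is hypothetical.

```python
import jax.numpy as jnp


def bce(pred, label, weight=None, reduction="mean", eps=1e-7):
    """Hypothetical binary cross-entropy on probabilities in [0, 1]."""
    p = jnp.clip(pred, eps, 1.0 - eps)  # keep log(p) and log(1 - p) finite
    loss = -(label * jnp.log(p) + (1 - label) * jnp.log(1 - p))
    if weight is not None:
        loss = loss * weight  # element-wise manual rescaling
    reducers = {"mean": jnp.mean, "sum": jnp.sum, "none": lambda v: v}
    if reduction not in reducers:
        raise ValueError(f"unsupported reduction: {reduction!r}")
    return reducers[reduction](loss)
```

For the example inputs above, the per-element losses sum to about 4.9337, and their mean (about 0.9867) matches the documented output.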
- neurai.nn.loss.hinge_loss(pred, label)
Calculates the hinge loss between predicted scores and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted scores as an array.
label (jnp.ndarray) – The true labels, encoded as -1 or 1.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import hinge_loss
>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> hinge_loss(pred, label)
array(0.26666667)
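The hinge loss averages max(0, 1 - y·f(x)) over the elements, assuming labels in {-1, 1}. This sketch reproduces the example value above: the per-element terms are 0.2, 0.6, and 0, and their mean is about 0.2667. The name hinge is hypothetical and is not the library's code.

```python
import jax.numpy as jnp


def hinge(pred, label):
    """Hypothetical mean hinge loss; label is expected in {-1, +1}."""
    # A correctly classified element with margin >= 1 contributes zero loss.
    return jnp.mean(jnp.maximum(0.0, 1.0 - label * pred))
```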
- neurai.nn.loss.huber_loss(logits, labels, reduction='mean', delta=1.0)
Calculates the Huber loss between predicted values and true labels.
- Parameters:
logits (jnp.ndarray) – The predicted values as an array.
labels (jnp.ndarray) – The true labels as a numpy array.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
delta (float, Optional) – The threshold parameter for huber loss. Default is 1.0.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import huber_loss
>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> huber_loss(pred, label)
Array(0.07333334, dtype=float32)
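The Huber loss is quadratic for small residuals and linear beyond delta: 0.5·d² if |d| ≤ delta, and delta·(|d| - 0.5·delta) otherwise. Here is a sketch with the documented reduction options (the name huber is hypothetical). In the example above, all residuals (0.2, 0.6, 0.2) fall below delta = 1, so only the quadratic branch applies.

```python
import jax.numpy as jnp


def huber(pred, target, delta=1.0, reduction="mean"):
    """Hypothetical Huber loss for illustration."""
    abs_diff = jnp.abs(pred - target)
    quadratic = 0.5 * abs_diff ** 2  # used where |d| <= delta
    linear = delta * (abs_diff - 0.5 * delta)  # used where |d| > delta
    loss = jnp.where(abs_diff <= delta, quadratic, linear)
    reducers = {"mean": jnp.mean, "sum": jnp.sum, "none": lambda v: v}
    if reduction not in reducers:
        raise ValueError(f"unsupported reduction: {reduction!r}")
    return reducers[reduction](loss)
```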
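- neurai.nn.loss.l1_loss(logits, labels, reduction='mean')
Calculates the L1 loss between predicted logits and true labels.
- Parameters:
logits (jnp.ndarray) – The predicted logits as an array.
labels (jnp.ndarray) – The true labels as an array.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.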
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import l1_loss
>>> logits = np.array([0.8, -0.4, 1.2])
>>> labels = np.array([1, -1, 1])
>>> l1_loss(logits, labels)
array(0.33333334)
- neurai.nn.loss.mse_loss(pred, label)
Calculates the mean squared error (MSE) loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import mse_loss
>>> pred = np.array([1, 2, 3])
>>> label = np.array([2, 4, 6])
>>> mse_loss(pred, label)
array(4.66666667)
- neurai.nn.loss.sigmoid_binary_cross_entropy(pred, label)
Calculates the sigmoid binary cross-entropy loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
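The documentation does not say whether pred holds probabilities or logits, but sigmoid-based binary cross-entropy conventionally takes logits. Assuming logits and a mean reduction, the numerically stable form max(x, 0) - x·y + log(1 + e^(-|x|)) equals -(y·log σ(x) + (1 - y)·log(1 - σ(x))) and avoids overflow in exp for large |x|. The name sigmoid_bce is hypothetical.

```python
import jax.numpy as jnp


def sigmoid_bce(logits, label):
    """Hypothetical sigmoid binary cross-entropy on logits, averaged over elements."""
    # log1p(exp(-|x|)) never overflows, unlike a naive log(1 + exp(-x)).
    loss = jnp.maximum(logits, 0.0) - logits * label + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    return jnp.mean(loss)
```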
- neurai.nn.loss.smooth_l1_loss(logits, labels, reduction='mean', beta=1.0)
Calculates the smooth L1 loss between predicted logits and true labels.
- Parameters:
logits (jnp.ndarray) – The predicted logits as a JAX NumPy array.
labels (jnp.ndarray) – The true labels as a JAX NumPy array.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
beta (float, Optional) – The threshold parameter for smooth L1 loss. Default is 1.0.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> from jax import numpy as jnp
>>> from neurai.nn.loss import smooth_l1_loss
>>> logits = jnp.array([0.8, -0.4, 1.2])
>>> labels = jnp.array([1, -1, 1])
>>> smooth_l1_loss(logits, labels)
array(0.07333334)
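Smooth L1 and Huber loss differ only in scaling: smooth L1 with threshold beta equals the Huber loss with delta = beta, divided by beta. At beta = 1 they coincide, which is why the smooth_l1_loss and huber_loss examples report the same value. Here is a sketch assuming beta > 0 (the name smooth_l1 is hypothetical):

```python
import jax.numpy as jnp


def smooth_l1(pred, target, beta=1.0, reduction="mean"):
    """Hypothetical smooth L1 loss; assumes beta > 0."""
    abs_diff = jnp.abs(pred - target)
    # Quadratic (scaled by 1/beta) inside the threshold, shifted L1 outside it.
    loss = jnp.where(abs_diff < beta, 0.5 * abs_diff ** 2 / beta, abs_diff - 0.5 * beta)
    reducers = {"mean": jnp.mean, "sum": jnp.sum, "none": lambda v: v}
    if reduction not in reducers:
        raise ValueError(f"unsupported reduction: {reduction!r}")
    return reducers[reduction](loss)
```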
- neurai.nn.loss.softmax_cross_entropy(pred, label)
Calculates the softmax cross-entropy loss between predicted logits and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted logits as an array.
label (jnp.ndarray) – The one-hot encoded true labels as an array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
from jax import numpy as jnp
from neurai.nn.loss import softmax_cross_entropy
pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
softmax_cross_entropy(pred, label)
Expected output: Array(1.0599941, dtype=float32)
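The expected output above is consistent with treating pred as unnormalized logits and averaging -Σ label·log softmax(pred) over the batch. The per-row losses are about 0.940, 1.390, and 0.850, with a mean of about 1.060. Here is a sketch under that assumption (softmax_ce is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def softmax_ce(logits, one_hot_labels):
    """Hypothetical softmax cross-entropy: batch mean of -sum(y * log_softmax(x))."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # stable log-sum-exp internally
    return jnp.mean(-jnp.sum(one_hot_labels * log_probs, axis=-1))
```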
- neurai.nn.loss.sparse_softmax_cross_entropy(pred, label)
Calculates the softmax cross-entropy loss between predicted logits and true labels given as integer class indices rather than one-hot vectors.
- Parameters:
pred (jnp.ndarray) – The predicted logits as an array.
label (jnp.ndarray) – The true class indices as an integer array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
from jax import numpy as jnp
from neurai.nn.loss import sparse_softmax_cross_entropy
pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([0, 1, 2])  # class index per row instead of a one-hot matrix
sparse_softmax_cross_entropy(pred, label)
Expected output: Array(1.0599941, dtype=float32)
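With integer class indices, the loss can be computed by gathering each row's log-probability at its true class, with no one-hot matrix needed. Here is a sketch assuming pred holds logits and the loss is averaged over the batch (sparse_softmax_ce is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def sparse_softmax_ce(logits, class_ids):
    """Hypothetical sparse softmax cross-entropy with integer labels of shape (batch,)."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick log p[i, class_ids[i]] for every row i.
    picked = jnp.take_along_axis(log_probs, class_ids[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)
```

Gathering avoids materializing a (batch, num_classes) one-hot matrix, which matters when the number of classes is large.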