脉冲神经网络构建和训练

作者: Liuhui & Jiangbo

与主流深度学习构建和训练过程类似,SNN模型的构建和训练如下:

  1. 定义数据集

  • 您可以选择数据集,例如:

from neurai.datasets.dataloader import DataLoader
from neurai.datasets.n_mnist import NMNIST

# Length of the spike sequence: number of time steps per sample.
time_windows = 50

# NOTE(review): the original passed train=False although this is the training
# set; train=True selects the actual N-MNIST training split.
# DATASETS_DIR must be defined (or imported) before this snippet runs.
train_data = NMNIST(DATASETS_DIR, train=True, sampling_time=1.0, sample_length=time_windows, download=False)
# drop_last=True keeps every batch at a fixed size of 32, which avoids
# recompilation when batches are fed to jit-compiled functions.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, enable_jit=True, drop_last=True)
  • 或者随机生成数据,例如:

import jax
from jax import numpy as jnp

# (The original repeated the two import lines above verbatim; deduplicated.)

# PRNG key for reproducible random spike generation.
seed = jax.random.PRNGKey(2023)
# Firing probability of each input element per time step.
proportion = 0.8
# (batch, channels, height, width) — matches the N-MNIST event-frame shape.
data_shape = (32, 2, 34, 34)
# Binary spike tensor sampled from a Bernoulli distribution, cast to float32.
spike_data = jnp.asarray(jax.random.bernoulli(seed, p=proportion, shape=data_shape).astype(jnp.float32))
  2. 定义SNN模型

from neurai.nn.module import Module
from neurai.nn.neuron import SNNLIF
from neurai.nn.layer import Linear, Flatten
from neurai.grads import surrogate_grad
import neurai.const as nc


class SNNMLP(Module):
  """Minimal spiking MLP: flatten -> linear(10, no bias) -> LIF spiking layer."""

  def setup(self):
    # Single-step LIF neuron with a rectangular surrogate gradient so the
    # non-differentiable spike function can be trained by backprop.
    self.lif = SNNLIF(step_mode=nc.single, surrogate_grad_fn=surrogate_grad.Rectangular())
    self.linear = Linear(10, bias=False)
    self.flatten = Flatten()

  def __call__(self, input):
    # Flatten the event frame, project to 10 logits, then emit spikes.
    return self.lif(self.linear(self.flatten(input)))
  3. 定义损失函数、优化器函数和评估函数,例如:

# Build the model and initialize its parameters from a sample input batch.
model = SNNMLP()
ps = model.init(input=spike_data)

from neurai.opt import adam, apply_updates
from neurai.nn.loss import mse_loss
from neurai.grads import grad

# Adam optimizer and its state, initialized from the model parameters.
opt = adam(learning_rate=0.001)
opt_state = opt.init(ps)
# Rebind `model` to the apply function, called as model(params, input)
# (see its use in loss_f below).
model = model.run

def loss_f(model_param, batch, label):
  """MSE loss over the time-averaged spike output of the model.

  Parameters:
      model_param (dict): Model parameter dictionary.
      batch (tuple): Batch tuple; batch[0] holds the input spikes with time
          as the trailing axis.
      label: One-hot targets matching the model's output shape.

  Returns:
      tuple: (loss_value, predict) where predict is the output averaged over
      all `time_windows` time steps.
  """
  # Average the per-time-step outputs to get a firing-rate prediction.
  step_outputs = (model(model_param, batch[0][..., step]) for step in range(time_windows))
  predict = sum(step_outputs) / time_windows
  return mse_loss(predict, label), predict
  4. 训练

from jax import jit
from jax import numpy as jnp, nn as jnn

@jit
def train_update(model_param, batch, opt_state):
  """
  Run one jit-compiled training step: gradients, metrics, and parameter update.

  Parameters:
      model_param (dict): Parameter Dictionary of the Model.
      batch (tuple): Tuple unpacking e.g. input, desiredClass, label = batch;
          batch[-1] holds the integer class labels.
      opt_state: Current optimizer state (from `opt.init` or a previous step).

  Returns:
      tuple: (grads, loss_value, accuracy, ps, opt_state) where accuracy is
      the COUNT of correct predictions in the batch (not a ratio) and ps are
      the updated parameters.
  """
  # NOTE(review): with has_aux=True and return_fun_value=True, grad_fn
  # appears to return (grads, loss, aux) — matches the unpack below; confirm
  # against neurai.grads.grad documentation.
  grad_fn = grad(loss_f, has_aux=True, return_fun_value=True)
  label = jnn.one_hot(batch[-1], 10)
  grads, loss_value, predict = grad_fn(model_param, batch, label)
  # Number of correct argmax predictions; the caller divides by sample count.
  accuracy = (jnp.argmax(predict, 1) == batch[-1]).sum()
  updates, opt_state = opt.update(grads, opt_state, model_param)
  ps = apply_updates(model_param, updates)
  return grads, loss_value, accuracy, ps, opt_state


from tqdm import tqdm

# Number of passes over the training set.
Epochs = 5
for ep in range(Epochs):
  # Running epoch totals: correct predictions, summed loss, samples seen.
  train_acc = 0
  train_loss = 0
  trainnum = 0
  train_loop = tqdm(train_loader, desc="Train")
  train_loop.set_description(f"Epoch: [{ep+1} / {Epochs}]")
  # The original unpacked `enumerate` into an unused index and discarded the
  # returned gradients into an unused name; both are dropped here.
  for batch in train_loop:
    _, loss_value, accuracy, ps, opt_state = train_update(ps, batch, opt_state)
    train_acc += accuracy
    train_loss += loss_value
    # batch[1] is assumed to hold per-sample targets, so its length is the
    # batch size — TODO confirm against the DataLoader's batch layout.
    trainnum += len(batch[1])
    # Report running averages (per-sample accuracy and loss) on the bar.
    train_loop.set_postfix(accuracy=train_acc / trainnum, loss=train_loss / trainnum)