脉冲神经网络的构建和训练
与主流深度学习模型的构建和训练过程类似，SNN 模型的构建和训练步骤如下：
定义数据集
您可以加载现成的数据集，例如 N-MNIST：
from neurai.datasets.dataloader import DataLoader
from neurai.datasets.n_mnist import NMNIST

# Number of time steps per sample; loss_f later iterates over exactly this many steps.
time_windows = 50
# NOTE: DATASETS_DIR is not defined in this snippet -- set it to the directory that
# holds the N-MNIST files. With download=False the files presumably must already exist.
# This loader feeds the training loop, so load the training split (train=True);
# loading train=False here would silently train on the test split.
train_data = NMNIST(
    DATASETS_DIR,
    train=True,
    sampling_time=1.0,
    sample_length=time_windows,
    download=False,
)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, enable_jit=True, drop_last=True)
或者随机生成脉冲数据，例如：
import jax
from jax import numpy as jnp

# Fixed PRNG key so the generated spike pattern is reproducible.
seed = jax.random.PRNGKey(2023)
# Probability that each entry is a spike (1.0) rather than silence (0.0).
proportion = 0.8
# Presumably (batch, channels, height, width) for a single time step -- TODO confirm.
data_shape = (32, 2, 34, 34)
# bernoulli() yields a boolean JAX array; casting to float32 already gives a JAX array,
# so no extra jnp.asarray wrapper is needed.
spike_data = jax.random.bernoulli(seed, p=proportion, shape=data_shape).astype(jnp.float32)
定义SNN模型
from neurai.nn.module import Module
from neurai.nn.neuron import SNNLIF
from neurai.nn.layer import Linear, Flatten
from neurai.grads import surrogate_grad
import neurai.const as nc
class SNNMLP(Module):
    """Single-layer spiking MLP: flatten -> bias-free Linear(10) -> LIF neurons.

    The LIF layer is configured with ``step_mode=nc.single``, and the model is
    called once per time step by the training code, so each call handles one
    input frame.
    """

    def setup(self):
        # Spiking nonlinearity; Rectangular surrogate gradient makes it trainable.
        self.lif = SNNLIF(step_mode=nc.single, surrogate_grad_fn=surrogate_grad.Rectangular())
        # Projection to 10 outputs, without a bias term.
        self.linear = Linear(10, bias=False)
        # Collapses the non-batch dimensions before the dense projection.
        self.flatten = Flatten()

    def __call__(self, input):
        """Return the LIF layer's output for one input frame.

        The parameter is named ``input`` because callers pass it by keyword
        (``model.init(input=...)``).
        """
        return self.lif(self.linear(self.flatten(input)))
初始化模型参数，并定义优化器和损失函数（准确率在训练步骤中计算），例如：
from neurai.opt import adam, apply_updates
from neurai.nn.loss import mse_loss
from neurai.grads import grad

# Build the network and initialise its parameter tree, using spike_data as the
# sample input.
model = SNNMLP()
ps = model.init(input=spike_data)

# Adam optimiser; opt_state is its state, threaded through every training step.
opt = adam(learning_rate=0.001)
opt_state = opt.init(ps)

# From here on `model` is the functional entry point model.run(params, x),
# which is the form loss_f calls.
model = model.run
def loss_f(model_param, batch, label):
    """MSE loss on the network output averaged over ``time_windows`` steps.

    The model is applied once per time step to ``batch[0][..., t]`` (time on the
    last axis). The per-step outputs are averaged and compared with ``label``
    via ``mse_loss``.

    Returns:
        (loss_value, predict): the loss, plus the time-averaged output as the
        auxiliary value consumed by ``grad(..., has_aux=True)``.
    """
    frames = batch[0]
    # Plain Python loop over time; when traced inside a jitted caller it is
    # unrolled into time_windows model calls.
    summed = sum((model(model_param, frames[..., t]) for t in range(time_windows)), 0.0)
    predict = summed / time_windows
    return mse_loss(predict, label), predict
训练
from jax import jit
from jax import numpy as jnp, nn as jnn
@jit
def train_update(model_param, batch, opt_state):
    """Run one jitted optimisation step of the SNN on a single batch.

    Parameters:
        model_param (dict): parameter tree of the model.
        batch (tuple): one batch from the data loader. ``batch[0]`` holds the
            input spikes (time on the last axis, as indexed by ``loss_f``) and
            ``batch[-1]`` the integer class labels.
        opt_state: current optimiser state, as returned by ``opt.init`` or a
            previous ``opt.update``.

    Returns:
        tuple: ``(grads, loss_value, accuracy, ps, opt_state)``, where
        ``accuracy`` is the *number* of correctly classified samples in the
        batch (not a ratio), ``ps`` is the updated parameter tree and
        ``opt_state`` is the updated optimiser state.
    """
    # Gradient of loss_f w.r.t. model_param; has_aux lets loss_f also return
    # its predictions, and return_fun_value returns the loss value as well.
    grad_fn = grad(loss_f, has_aux=True, return_fun_value=True)
    labels = batch[-1]
    # 10 classes, matching the model's Linear(10) output layer.
    one_hot_labels = jnn.one_hot(labels, 10)
    grads, loss_value, predict = grad_fn(model_param, batch, one_hot_labels)
    # Count of correct predictions; the caller divides by the sample count.
    accuracy = (jnp.argmax(predict, axis=1) == labels).sum()
    updates, opt_state = opt.update(grads, opt_state, model_param)
    ps = apply_updates(model_param, updates)
    return grads, loss_value, accuracy, ps, opt_state
from tqdm import tqdm

Epochs = 5
for ep in range(Epochs):
    train_acc = 0
    train_loss = 0
    trainnum = 0
    num_batches = 0
    # Pass the epoch label directly instead of creating the bar with a
    # placeholder description and overwriting it.
    train_loop = tqdm(train_loader, desc=f"Epoch: [{ep+1} / {Epochs}]")
    for batch in train_loop:
        # ps and opt_state are rebound each step so the next batch uses the
        # updated parameters and optimiser state; gradients are not needed here.
        _, loss_value, accuracy, ps, opt_state = train_update(ps, batch, opt_state)
        train_acc += accuracy
        train_loss += loss_value
        # Labels live in batch[-1], the same convention train_update uses.
        trainnum += len(batch[-1])
        num_batches += 1
        # accuracy is a per-batch count of correct samples -> divide by samples.
        # loss_value is a single scalar per batch (grad needs a scalar
        # objective) -> average over batches (assumes mse_loss is a mean
        # reduction; TODO confirm). float() makes tqdm show plain numbers
        # instead of JAX array reprs.
        train_loop.set_postfix(
            accuracy=float(train_acc / trainnum),
            loss=float(train_loss / num_batches),
        )