Brain Simulation with Support for Learnable Training

Author: MamieZhu

This article explains how to build a brain simulation network and train it on a classification task. Here, a brain simulation network means a network whose connections use static synapses.

Defining the Network

First, set the platform with neurai.config.set_platform. The available platforms are cpu, gpu, and apu, though the current version of apu does not support brain simulation learning. Then define a network by creating a subclass of neurai.nn.snet.SNet:

from neurai.config import set_platform
set_platform(platform='gpu')
from neurai import nn
from neurai.initializer import UniformIniter
import jax.numpy as jnp

class SimNet(nn.SNet):

  def setup(self):
    self.inputsgen = nn.InputTransmitter(size=784, batch_first=False)
    self.pop0 = nn.ExpLIF(size=784, V_rest=0., V_th=0.4, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
    self.pop1 = nn.ExpLIF(size=10, V_rest=0., V_th=3, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)

    self.pop0_conn = nn.StaticSynapse(self.inputsgen, self.pop0, conn=nn.One2One(), weight=UniformIniter())
    self.pop_conn = nn.StaticSynapse(self.pop0, self.pop1, conn=nn.FixedProb(1), weight=UniformIniter())

class Network(nn.Module):
  def setup(self):
    self.net = nn.SNetLayer(SimNet)

  def __call__(self, input=None, t=10):
    output, _ = self.net(input, t, output={"ExpLIF_1":["spike"]})
    output = jnp.sum(output["ExpLIF_1"]["spike"], axis=0)
    return output

  def accuracy(self, predict, target):
    return jnp.mean(jnp.argmax(predict, axis=1) == target)

In the network above, the neuron and synapse definitions are written in the setup function. The NeurAI framework supports a variety of neuron models (see Neuron Models) as well as a variety of synapse models (see Synapse Models).

When building a learnable brain simulation network, you need to define an InputTransmitter neuron that receives the dataset input and relays it to the other neuron populations.

In addition, a brain simulation model must be given an output argument that specifies which data to return. output is a dictionary: each key is the name of a neuron population to read out. If a custom name was supplied when the population was defined, use that name; otherwise the default name is the class name plus "_index" (here, ExpLIF_1). Each value is a list selecting spikes and/or membrane voltages as the output; the valid entries are spike and v.
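For instance, here is a minimal sketch of reading out a population under a custom name; passing name as a constructor keyword and the label "out" are illustrative assumptions rather than part of the example above:

class NamedNet(nn.SNet):

  def setup(self):
    self.inputsgen = nn.InputTransmitter(size=784, batch_first=False)
    # Hypothetical custom name for the readout population.
    self.readout = nn.ExpLIF(size=10, V_rest=0., V_th=3, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08, name="out")
    self.conn = nn.StaticSynapse(self.inputsgen, self.readout, conn=nn.FixedProb(1), weight=UniformIniter())

# The custom name then becomes the dictionary key, and both spikes and voltages are requested:
# output, _ = self.net(input, t, output={"out": ["spike", "v"]})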

A brain simulation returns two values. The first is the requested output: a dictionary containing, according to the output specification, the spikes or voltages over the t time steps. The second is the value of any monitors; if the model monitors some intermediate variables, they can be read from this value, and otherwise it can be ignored.
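As a sketch, here is a variant of the __call__ method above that requests both spikes and voltages and keeps the second return value; the variable name mon is ours, and nothing is monitored in this example:

# A drop-in variant of Network.__call__ that reads both return values
# (mon stays empty unless monitors are defined in the model):
def __call__(self, input=None, t=10):
  output, mon = self.net(input, t, output={"ExpLIF_1": ["spike", "v"]})
  spikes = output["ExpLIF_1"]["spike"]  # spike trains over the t time steps
  voltages = output["ExpLIF_1"]["v"]    # membrane voltage traces over the t time steps
  return jnp.sum(spikes, axis=0)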

Downloading the Dataset

The training and test datasets can be downloaded through neurai.datasets.mnist.MNIST into the path given by the root parameter.

from neurai import datasets

train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)

Loading the Dataset

By passing a few parameters (such as dataset, batch_size, and shuffle), neurai.datasets.dataloader.DataLoader provides iterators over train_data and test_data.

train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)

Initializing the Network

Use the init function to initialize the model.

from neurai.opt import adam, apply_updates

net = Network()
param = net.init(input=jnp.ones((784,), jnp.float32))

hybrid_opt = adam(0.01)
opt_ann_state = hybrid_opt.init(param['param'])

When training with a brain simulation network, if the simulation duration is not fixed in the model definition, pass the same parameter t to both the model's init function and its run function.
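A minimal sketch of this rule follows; the value t=100 is our assumption, chosen to match the 100 time steps produced by the input expansion in the next section:

t = 100  # simulation duration, kept identical at init and run time
param = net.init(input=jnp.ones((784,), jnp.float32), t=t)
# ... later, when executing the model, pass the same duration:
# out = net.run(param, input=encoded_input, t=t)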

Defining the Forward Function

from neurai.datasets.transforms import PoissonEncoder
from neurai.nn.loss import softmax_cross_entropy
import jax

encoder = PoissonEncoder()

def model_run(param_p, param_f, x_data, y_data):
  param = {"param": param_p, "frozen_param": param_f}
  x_buffer = jax.vmap(lambda x: net.run(param, input=encoder(jnp.tile(jnp.array(x.flatten(), jnp.float32),(100,1)))))(x_data/255.)
  model_output = jnp.squeeze(jnp.asarray(x_buffer))
  model_loss = softmax_cross_entropy(model_output, jax.nn.one_hot(y_data, num_classes=10))
  return model_loss, model_output

Two special treatments are used here. First, the input data is expanded in time so that every one of the t time steps receives input, and is then passed through the encoder; this step is optional, depending on your needs. Second, the model run is wrapped in vmap, which parallelizes over the batch dimension of the dataset and improves runtime performance.
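To make the shapes concrete, here is a framework-independent sketch of both treatments in plain JAX; the Bernoulli draw merely stands in for PoissonEncoder, whose exact behavior is an assumption here:

import jax
import jax.numpy as jnp

def expand_and_encode(x, t=100, key=jax.random.PRNGKey(0)):
  rates = jnp.tile(x.flatten(), (t, 1))  # (t, 784): the same image fed at every time step
  # One spike pattern per step; a fixed key is fine for a shape-level sketch.
  return jax.random.bernoulli(key, rates).astype(jnp.float32)

batch = jnp.ones((32, 28, 28)) / 2.           # dummy batch of normalized images
encoded = jax.vmap(expand_and_encode)(batch)  # vmap parallelizes over the batch axis
print(encoded.shape)                          # (32, 100, 784)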

Defining the Training Function

from neurai.grads.autograd import grad

@jax.jit
def train(param, x_data, y_data):
  param_p = param['param']
  param_f = param["frozen_param"]
  loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
  grads_value, loss_value, model_output = loss_grad(param_p, param_f, x_data, y_data)
  return grads_value, loss_value, model_output

Training the Model

grads_value, loss_value, model_output = train(param, x_train, y_train)
updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, param["param"])
param["param"] = apply_updates(param["param"], updates)
model_acc = net.accuracy(model_output, y_train)

Each training step updates the parameters and reports the loss and accuracy.

Testing the Model

param_p = param['param']
param_f = param["frozen_param"]
loss_value, model_output = model_run(param_p, param_f, x_test, y_test)
model_acc = net.accuracy(model_output, y_test)

Complete Example Code

from neurai.config import set_platform

set_platform(platform='gpu')
from neurai import nn
from neurai.initializer import UniformIniter
import jax.numpy as jnp
from neurai.util import serialization
from neurai import datasets
from neurai.opt import adam, apply_updates
from neurai.datasets.transforms import PoissonEncoder
from neurai.nn.loss import softmax_cross_entropy
from neurai.grads.autograd import grad
import jax
from tqdm import tqdm


class SimNet(nn.SNet):

  def setup(self):
    self.inputsgen = nn.InputTransmitter(size=784, batch_first=False)
    self.pop0 = nn.ExpLIF(size=784, V_rest=0., V_th=0.4, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
    self.pop1 = nn.ExpLIF(size=10, V_rest=0., V_th=3, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)

    self.pop0_conn = nn.StaticSynapse(self.inputsgen, self.pop0, conn=nn.One2One(), weight=UniformIniter())
    self.pop_conn = nn.StaticSynapse(self.pop0, self.pop1, conn=nn.FixedProb(1), weight=UniformIniter())


class Network(nn.Module):

  def setup(self):
    self.net = nn.SNetLayer(SimNet)

  def __call__(self, input=None, t=10):
    output, _ = self.net(input, t, output={"ExpLIF_1": ["spike"]})
    output = jnp.sum(output["ExpLIF_1"]["spike"], axis=0)
    return output

  def accuracy(self, predict, target):
    return jnp.mean(jnp.argmax(predict, axis=1) == target)


train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)

net = Network()
param = net.init(input=jnp.ones((784,), jnp.float32))
hybrid_opt = adam(0.01)
opt_ann_state = hybrid_opt.init(param['param'])

encoder = PoissonEncoder()


def model_run(param_p, param_f, x_data, y_data):
  param = {"param": param_p, "frozen_param": param_f}
  x_buffer = jax.vmap(lambda x: net.run(param, input=encoder(jnp.tile(jnp.array(x.flatten(), jnp.float32), (100, 1)))))(
    x_data / 255.)
  model_output = jnp.squeeze(jnp.asarray(x_buffer))
  model_loss = softmax_cross_entropy(model_output, jax.nn.one_hot(y_data, num_classes=10))
  return model_loss, model_output


@jax.jit
def train(param, x_data, y_data):
  param_p = param['param']
  param_f = param["frozen_param"]
  loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
  grads_value, loss_value, model_output = loss_grad(param_p, param_f, x_data, y_data)
  return grads_value, loss_value, model_output


Epochs = 100
batch_size = 32

for epoch in range(Epochs):
  print(f"{epoch=}")
  train_loop = tqdm(train_loader, desc="Train")
  train_loop.set_description("weight update")
  sum_model_acc = 0
  sum_loss_value = 0

  for batch_id, (x_train, y_train) in enumerate(train_loop):
    grads_value, loss_value, model_output = train(param, x_train, y_train)
    updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, param["param"])
    param["param"] = apply_updates(param["param"], updates)
    model_acc = net.accuracy(model_output, y_train)
    sum_model_acc += model_acc
    sum_loss_value += loss_value
  print(f"training: loss_value = {(sum_loss_value/(batch_id+1))}, model_acc = {(sum_model_acc/(batch_id+1))}")
  serialization.save(path="./model" + str(epoch), param=param["param"], overwrite=True)

  test_loop = tqdm(test_loader, desc="Test")
  sum_model_acc = 0
  sum_loss_value = 0
  for batch_id, (x_test, y_test) in enumerate(test_loop):
    param_p = param['param']
    param_f = param["frozen_param"]
    loss_value, model_output = model_run(param_p, param_f, x_test, y_test)
    model_acc = net.accuracy(model_output, y_test)
    sum_model_acc += model_acc
    sum_loss_value += loss_value
  print(f"testing: loss_value = {(sum_loss_value/(batch_id+1))}, model_acc = {(sum_model_acc/(batch_id+1))}")