Trainable Brain Simulation Networks
Author: MamieZhu
This article explains how to build a brain simulation network and train it on a classification task. Here, a brain simulation network means a network built from static synaptic connections.
Defining the Network
First, set the platform with neurai.config.set_platform. The available platforms are cpu, gpu, and apu, but the current version does not support brain simulation learning on apu. Then define a network by subclassing neurai.nn.snet.SNet:
from neurai.config import set_platform
set_platform(platform='gpu')

from neurai import nn
from neurai.initializer import UniformIniter
import jax.numpy as jnp


class SimNet(nn.SNet):
    def setup(self):
        # Receives dataset samples and relays them into the simulated network.
        self.inputsgen = nn.InputTransmitter(size=784, batch_first=False)
        # Two populations of exponential LIF neurons: 784 hidden, 10 output.
        self.pop0 = nn.ExpLIF(size=784, V_rest=0., V_th=0.4, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
        self.pop1 = nn.ExpLIF(size=10, V_rest=0., V_th=3, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
        # Static synapses: one-to-one from the transmitter into pop0, and
        # all-to-all (FixedProb(1)) from pop0 to pop1.
        self.pop0_conn = nn.StaticSynapse(self.inputsgen, self.pop0, conn=nn.One2One(), weight=UniformIniter())
        self.pop_conn = nn.StaticSynapse(self.pop0, self.pop1, conn=nn.FixedProb(1), weight=UniformIniter())


class Network(nn.Module):
    def setup(self):
        self.net = nn.SNetLayer(SimNet)

    def __call__(self, input=None, t=10):
        # Simulate for duration t and collect the spikes of population ExpLIF_1.
        output, _ = self.net(input, t, output={"ExpLIF_1": ["spike"]})
        # Sum spikes over time to get one spike count per output neuron.
        output = jnp.sum(output["ExpLIF_1"]["spike"], axis=0)
        return output

    def accuracy(self, predict, target):
        return jnp.mean(jnp.argmax(predict, axis=1) == target)
In the network above, the neurons and synapses are defined in the setup function. The NeurAI framework supports a variety of neuron models (see Neuron Models) as well as a variety of synapse models (see Synapse Models).

When building a trainable brain simulation network, you need to define an InputTransmitter neuron to receive the samples from the dataset; the data is relayed through the InputTransmitter to the other neuron populations.

In addition, a brain simulation model must be given an output argument specifying which data to return. output is a dictionary. Each key is the name of a neuron population: if a custom name was passed when the population was defined, use that name; otherwise the default name is the class name followed by "_index". Each value is a list selecting spikes and/or membrane voltages as output; the allowed entries are spike and v.

The brain simulation returns two values. The first is the output result: a dictionary containing, for each entry requested via output, the spikes or voltages over the simulation duration t. The second is the value of any monitors; if the model monitors intermediate variables, they are returned here, and otherwise this value can be ignored.
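For illustration, a variant of the call in Network.__call__ above could request both spikes and voltages using the default population names (ExpLIF_0 for pop0, ExpLIF_1 for pop1). This is a sketch following the same call convention as above; the comments describe expected contents, not guaranteed shapes:

# Sketch: request voltages of pop0 and spikes of pop1 by their default names.
outs, _ = self.net(input, t, output={"ExpLIF_0": ["v"], "ExpLIF_1": ["spike"]})
spikes = outs["ExpLIF_1"]["spike"]  # per-step spikes of the 10 output neurons
voltages = outs["ExpLIF_0"]["v"]    # per-step membrane voltages of pop0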
Downloading the Dataset
The training and test datasets can be downloaded via neurai.datasets.mnist.MNIST to the path given by the root parameter.
from neurai import datasets
train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
Loading the Dataset
By passing a few arguments (such as dataset, batch_size, and shuffle), neurai.datasets.dataloader.DataLoader provides an iterator over train_data and test_data.
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)
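The loader can then be iterated batch by batch. A minimal sketch; the shapes in the comment are assumptions based on MNIST images and batch_size=32:

# Inspect one batch; for MNIST with batch_size=32 we would expect
# images of shape (32, 28, 28) and labels of shape (32,).
for x_batch, y_batch in train_loader:
    print(x_batch.shape, y_batch.shape)
    break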
Initializing the Network
Initialize the model with the init function.
from neurai.opt import adam, apply_updates

net = Network()
param = net.init(input=jnp.ones((784,), jnp.float32))
# Optimize only the trainable parameters stored under the 'param' key.
hybrid_opt = adam(0.01)
opt_ann_state = hybrid_opt.init(param['param'])
When training a brain simulation network, if the simulation duration is not fixed in the model definition, pass the same t argument to both model initialization and the run function.
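For example, assuming as a sketch that init and run forward keyword arguments such as t to __call__ (as the calls above suggest), the same duration would be passed in both places; encoded_input below is a hypothetical placeholder for the encoded input built in the next section:

# Sketch: keep the simulation duration consistent between init and run.
param = net.init(input=jnp.ones((784,), jnp.float32), t=10)
output = net.run(param, input=encoded_input, t=10)  # encoded_input: placeholder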
Defining the Forward Function
from neurai.datasets.transforms import PoissonEncoder
from neurai.nn.loss import softmax_cross_entropy
import jax

encoder = PoissonEncoder()


def model_run(param_p, param_f, x_data, y_data):
    param = {"param": param_p, "frozen_param": param_f}
    # Normalize pixels to [0, 1], tile each flattened image across the 100
    # simulation steps, Poisson-encode it, and vmap the run over the batch.
    x_buffer = jax.vmap(lambda x: net.run(param, input=encoder(jnp.tile(jnp.array(x.flatten(), jnp.float32), (100, 1)))))(x_data / 255.)
    model_output = jnp.squeeze(jnp.asarray(x_buffer))
    model_loss = softmax_cross_entropy(model_output, jax.nn.one_hot(y_data, num_classes=10))
    return model_loss, model_output
Two details deserve attention here. First, the input data is expanded so that every step within the simulation duration t receives an input, and it is then passed through the encoder; this preprocessing is optional and depends on your needs. Second, the model run is wrapped in vmap, which vectorizes over the dataset batch and improves performance.
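PoissonEncoder is provided by neurai; conceptually, it converts pixel intensities into spike trains by drawing, at every time step, a Bernoulli spike with probability proportional to the intensity. A minimal sketch of that idea in JAX (an illustration of rate coding, not the library's implementation; poisson_encode and num_steps are hypothetical names):

import jax
import jax.numpy as jnp

def poisson_encode(key, x, num_steps=100):
    # x: flattened image with intensities in [0, 1]; one Bernoulli draw
    # per time step, so brighter pixels spike more often.
    probs = jnp.tile(x, (num_steps, 1))
    return jax.random.bernoulli(key, probs).astype(jnp.float32)

key = jax.random.PRNGKey(0)
spike_train = poisson_encode(key, jnp.ones(784) * 0.5)  # shape (100, 784)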
Defining the Training Function
from neurai.grads.autograd import grad

@jax.jit
def train(param, x_data, y_data):
    param_p = param['param']
    param_f = param["frozen_param"]
    # Build a gradient function for model_run; with return_fun_value=True and
    # has_aux=True it returns (gradients, loss value, auxiliary model output).
    loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
    grads_value, loss_value, model_output = loss_grad(param_p, param_f, x_data, y_data)
    return grads_value, loss_value, model_output
Training the Model
# One optimization step: compute gradients, update the optimizer state,
# and apply the updates to the trainable parameters.
grads_value, loss_value, model_output = train(param, x_train, y_train)
updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, param["param"])
param["param"] = apply_updates(param["param"], updates)
model_acc = net.accuracy(model_output, y_train)
Each training step updates the parameters and reports the loss and accuracy.
Testing the Model
# Evaluation: forward pass only, no gradient computation or parameter update.
param_p = param['param']
param_f = param["frozen_param"]
loss_value, model_output = model_run(param_p, param_f, x_test, y_test)
model_acc = net.accuracy(model_output, y_test)
Complete Example
from neurai.config import set_platform
set_platform(platform='gpu')

from neurai import nn
from neurai.initializer import UniformIniter
import jax.numpy as jnp
from neurai.util import serialization
from neurai import datasets
from neurai.opt import adam, apply_updates
from neurai.datasets.transforms import PoissonEncoder
from neurai.nn.loss import softmax_cross_entropy
from neurai.grads.autograd import grad
import jax
from tqdm import tqdm


class SimNet(nn.SNet):
    def setup(self):
        self.inputsgen = nn.InputTransmitter(size=784, batch_first=False)
        self.pop0 = nn.ExpLIF(size=784, V_rest=0., V_th=0.4, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
        self.pop1 = nn.ExpLIF(size=10, V_rest=0., V_th=3, V_reset=-0., tau=0.1, I_e=0., v_init=0., R=0.08)
        self.pop0_conn = nn.StaticSynapse(self.inputsgen, self.pop0, conn=nn.One2One(), weight=UniformIniter())
        self.pop_conn = nn.StaticSynapse(self.pop0, self.pop1, conn=nn.FixedProb(1), weight=UniformIniter())


class Network(nn.Module):
    def setup(self):
        self.net = nn.SNetLayer(SimNet)

    def __call__(self, input=None, t=10):
        output, _ = self.net(input, t, output={"ExpLIF_1": ["spike"]})
        output = jnp.sum(output["ExpLIF_1"]["spike"], axis=0)
        return output

    def accuracy(self, predict, target):
        return jnp.mean(jnp.argmax(predict, axis=1) == target)


train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)

net = Network()
param = net.init(input=jnp.ones((784,), jnp.float32))
hybrid_opt = adam(0.01)
opt_ann_state = hybrid_opt.init(param['param'])

encoder = PoissonEncoder()


def model_run(param_p, param_f, x_data, y_data):
    param = {"param": param_p, "frozen_param": param_f}
    x_buffer = jax.vmap(lambda x: net.run(param, input=encoder(jnp.tile(jnp.array(x.flatten(), jnp.float32), (100, 1)))))(
        x_data / 255.)
    model_output = jnp.squeeze(jnp.asarray(x_buffer))
    model_loss = softmax_cross_entropy(model_output, jax.nn.one_hot(y_data, num_classes=10))
    return model_loss, model_output


@jax.jit
def train(param, x_data, y_data):
    param_p = param['param']
    param_f = param["frozen_param"]
    loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
    grads_value, loss_value, model_output = loss_grad(param_p, param_f, x_data, y_data)
    return grads_value, loss_value, model_output


Epochs = 100
batch_size = 32
for epoch in range(Epochs):
    print(f"{epoch=}")
    train_loop = tqdm(train_loader, desc="Train")
    train_loop.set_description("weight update")
    sum_model_acc = 0
    sum_loss_value = 0
    for batch_id, (x_train, y_train) in enumerate(train_loop):
        grads_value, loss_value, model_output = train(param, x_train, y_train)
        updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, param["param"])
        param["param"] = apply_updates(param["param"], updates)
        model_acc = net.accuracy(model_output, y_train)
        sum_model_acc += model_acc
        sum_loss_value += loss_value
    print(f"training: loss_value = {sum_loss_value/(batch_id+1)}, model_acc = {sum_model_acc/(batch_id+1)}")
    serialization.save(path="./model" + str(epoch), param=param["param"], overwrite=True)

    test_loop = tqdm(test_loader, desc="Test")
    sum_model_acc = 0
    sum_loss_value = 0
    for batch_id, (x_test, y_test) in enumerate(test_loop):
        param_p = param['param']
        param_f = param["frozen_param"]
        loss_value, model_output = model_run(param_p, param_f, x_test, y_test)
        model_acc = net.accuracy(model_output, y_test)
        sum_model_acc += model_acc
        sum_loss_value += loss_value
    print(f"testing: loss_value = {sum_loss_value/(batch_id+1)}, model_acc = {sum_model_acc/(batch_id+1)}")