How to Build a Hybrid Network#

Author: MamieZhu

This chapter presents a simple example of a hybrid network to explore its potential value in terms of application scenarios and domains.

Selecting the Hardware Platform#

First, set the platform with neurai.config.set_platform. The available platforms are cpu, gpu, and apu; note that apu does not yet support training brain simulation models or deploying hybrid networks.

from neurai.config import set_platform

set_platform(platform="gpu")

Downloading the Dataset#

The training and test datasets can be downloaded via neurai.datasets.mnist.MNIST into the path given by the root parameter.

from neurai import datasets

train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)

Loading the Dataset#

By passing a few parameters (such as dataset, batch_size, and shuffle), neurai.datasets.dataloader.DataLoader provides an iterator over train_data and test_data.

train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)
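
The loader can then be consumed as a plain Python iterator. A minimal sketch (each batch yields an image array and a label array, matching the training loop in the complete example below):

for batch_id, (x_batch, y_batch) in enumerate(train_loader):
  # x_batch holds one batch of 32 MNIST images, y_batch the matching integer labels.
  pass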

Defining the Network Structure#

Defining the ANN#

To create a network, inherit from the neurai.nn.module.Module class, build the network structure from layer APIs such as neurai.nn.layer.linear.Linear and neurai.nn.layer.activate.Relu, and finally define the forward computation in __call__.

1. Define a network named ‘MNISTModelPre’

Define an ANN containing three Linear layers, with a Relu activation after the first two, to extract features from the dataset. Example code:

from neurai import nn

class MNISTModelPre(nn.Module):

  def setup(self):
    self.fc1 = nn.Linear(50)
    self.fc2 = nn.Linear(60)
    self.fc3 = nn.Linear(50)
    self.relu = nn.Relu()

  def __call__(self, input):
    fc1_out = self.relu(self.fc1(input=input.reshape(-1, 28 * 28 * 1)))
    fc2_out = self.relu(self.fc2(fc1_out))
    fc3_out = self.fc3(fc2_out)
    return fc3_out

2. Define a network named ‘MNISTModelPost’

Define a second ANN containing three Linear layers, again with a Relu activation after the first two. Example code:

class MNISTModelPost(nn.Module):

  def setup(self):
    self.fc1 = nn.Linear(50)
    self.fc2 = nn.Linear(60)
    self.fc3 = nn.Linear(10)
    self.relu = nn.Relu()

  def __call__(self, input):
    fc1_out = self.relu(self.fc1(input))
    fc2_out = self.relu(self.fc2(fc1_out))
    fc3_out = self.fc3(fc2_out)
    return fc3_out

Defining the Brain Simulation Network (BSN)#

To create a brain simulation network, inherit from the neurai.nn.snet.SNet class, build the network structure from APIs such as neurai.nn.neuron.ExpLIF and neurai.nn.synapse.static_synapse.StaticSynapse, and define the connections between neurons via neurai.nn.conn.connrule.

Note that in a hybrid network, a dedicated InputTransmitter is required to receive the output of the artificial neural network (ANN) as the input to the brain simulation network. Through StaticSynapse connections, users can route the ANN output to the desired neuron populations.

Here we define a brain simulation model containing two neuron populations, with the ANN output connected as input to both populations. Example code:

from neurai.nn.snet import SNet
from neurai.nn.neuron import ExpLIF, InputTransmitter
from neurai.nn.conn.connrule import One2One
from neurai.nn.synapse.static_synapse import StaticSynapse

class SNetSimple(SNet):

  def setup(self):
    self.inputsgen = InputTransmitter(size=50, batch_first=False)
    self.pre_pop = ExpLIF(size=50, V_rest=0., V_th=1.0, V_reset=-60., tau=20., I_e=0., v_init=0.)
    self.post_pop = ExpLIF(size=50, V_rest=0., V_th=1.0, V_reset=-60., tau=20., I_e=0., v_init=0.)
    self.pre_conn = StaticSynapse(self.inputsgen, self.pre_pop, conn=One2One(), weight=1)
    self.post_conn = StaticSynapse(self.inputsgen, self.post_pop, conn=One2One(), weight=1)
    self.pop_conn = StaticSynapse(self.pre_pop, self.post_pop, conn=One2One(), weight=0.98)
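
In this topology, inputsgen projects one-to-one onto both pre_pop and post_pop with weight 1, and pre_pop additionally projects one-to-one onto post_pop with a fixed weight of 0.98; since only the ANN parameters are trained in the hybrid model below, these synaptic weights stay fixed.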

Building the Hybrid Network#

Based on the network components defined above, we can build an ANN-BSN-ANN hybrid network, a structure that combines the characteristics of artificial neural networks (ANNs) and brain simulation networks (BSNs). In this hybrid network, the individual sub-networks are connected through the forward computation defined in the __call__ method.

Note

  • The ANN output is a two-dimensional array, while the BSN input must be one-dimensional. This means the ANN output has to be converted before it is passed to the BSN.

  • The BSN output is a dictionary. To correctly obtain the desired output from the BSN, the expected output format must be specified via the output parameter.

The output parameter is a dictionary: its keys name the output neuron populations, and its values are lists of the output data names requested for each population; spike and v can be specified.
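
For example, assuming a population registered under the name 'ExpLIF_1' (the name used in the hybrid network below), requesting both its spike trains and membrane potentials would look like the following sketch:

# Keys name output neuron populations; values list the requested quantities.
output = {'ExpLIF_1': ['spike', 'v']}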

To convert data between the ANN and the BSN, we designed a component named InputTransmitter, which converts the ANN output into the input format required by the BSN. Since the BSN output is a dictionary, it can be indexed, sliced, or passed through functions as needed to convert it back into the format the ANN requires. The following example code implements this data conversion and connects the networks:

import jax.numpy as jnp

from neurai import nn
from neurai.nn.snet import SNetLayer


class Network(nn.Module):
  def setup(self):
    self.pre_net = MNISTModelPre()
    self.post_net = MNISTModelPost()
    self.sim_net = SNetLayer(SNetSimple)

  def __call__(self, input, t=0.5):
    input_data = input.reshape(-1, 28 * 28 * 1)
    ann_predict_1 = self.pre_net(input_data)
    sim_output, _ = self.sim_net(input=ann_predict_1[0], t=t, output={'ExpLIF_1':['spike']})
    ann_predict_2 = self.post_net(input=jnp.sum(sim_output['ExpLIF_1']['spike'], axis=0).reshape(1, 50))
    return ann_predict_2

  def accuracy(self, predict, target):
    return jnp.mean(jnp.argmax(predict, axis=1) == target)

Optimizer and Gradient Descent#

Create the optimizer via neurai.opt.

from neurai.opt import adam, apply_updates

hybrid_opt = adam(0.001)

Create the loss function via neurai.nn.loss.

from neurai.nn.loss import softmax_cross_entropy

loss_value = softmax_cross_entropy(out, n_label)
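
Note that softmax_cross_entropy is applied to the network logits and one-hot targets; in the complete example below the integer labels are converted with jax.nn.one_hot first. A minimal self-contained sketch with hypothetical logits:

import jax.numpy as jnp
import jax.nn as jnn
from neurai.nn.loss import softmax_cross_entropy

out = jnp.zeros((2, 10))  # hypothetical logits for a batch of 2 samples
n_label = jnn.one_hot(jnp.asarray([3, 7]), num_classes=10)  # one-hot targets
loss_value = softmax_cross_entropy(out, n_label)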

Initializing the Model#

Before training and testing, the model created above needs to be instantiated and then initialized.

hybrid_model = Network()
hybrid_params = hybrid_model.init(input=jnp.ones([1, 28 * 28 * 1]))
opt_ann_state = hybrid_opt.init(hybrid_params["param"])
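
The dictionary returned by init holds the trainable ANN weights under the "param" key and the frozen brain simulation parameters under "frozen_param"; the training procedure below updates only the former.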

Defining the Model Computation Function#

Since the dataset carries a batch dimension while brain simulation models are currently trained in single-batch form, vmap is used here to map the model computation over the batch.

@jit
def model_run(ann_params, sim_params, data_batch):
  x, y = data_batch
  hybrid_parameters = {"param": ann_params, "frozen_param": sim_params}
  x_buffer = jax.vmap(lambda img: hybrid_model.run(hybrid_parameters, input=img))(x)
  model_output = jnp.squeeze(jnp.asarray(x_buffer))
  model_loss = softmax_cross_entropy(model_output, jnn.one_hot(y, num_classes=10))
  return model_loss, model_output

Defining the Training Procedure#

loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
ann_params = hybrid_params["param"]
sim_params = hybrid_params["frozen_param"]
grads_value, loss_value, model_output = loss_grad(ann_params, sim_params, (jnp.asarray(x_train), jnp.asarray(y_train)))
updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, hybrid_params["param"])
hybrid_params["param"] = apply_updates(hybrid_params["param"], updates)

Defining the Testing Procedure#

ann_params = hybrid_params["param"]
sim_params = hybrid_params["frozen_param"]
loss_value, model_output = model_run(ann_params, sim_params, (jnp.asarray(x_test), jnp.asarray(y_test)))

Complete Example Code#

import sys
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from neurai.config import set_platform
set_platform(platform="gpu")
import jax.numpy as jnp
import jax.nn as jnn
import jax
from tqdm import tqdm
from neurai import nn
from neurai.opt import adam, apply_updates
from neurai.util import jit
from neurai.nn.neuron import ExpLIF, InputTransmitter
from neurai.grads.autograd import grad
from neurai.nn.snet import SNet, SNetLayer
from neurai.datasets import MNIST, DataLoader
from neurai.nn.layer.activate import Relu
from neurai.nn.loss import softmax_cross_entropy
from neurai.nn.synapse.static_synapse import StaticSynapse
from neurai.nn.conn.connrule import One2One
from neurai.util import serialization
from neurai.setting import DATASETS_DIR

batch_size = 32
train_data = MNIST(DATASETS_DIR, download=True, train=True)
test_data = MNIST(DATASETS_DIR, download=True, train=False)

train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, drop_last=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, drop_last=True)


# ANN model.
class MNISTModelPre(nn.Module):

  def setup(self):
    self.fc1 = nn.Linear(50)
    self.fc2 = nn.Linear(60)
    self.fc3 = nn.Linear(50)
    self.relu = Relu()

  def __call__(self, input):
    fc1_out = self.relu(self.fc1(input=input.reshape(-1, 28 * 28 * 1)))
    fc2_out = self.relu(self.fc2(fc1_out))
    fc3_out = self.fc3(fc2_out)
    return fc3_out


# ANN model.
class MNISTModelPost(nn.Module):

  def setup(self):
    self.fc1 = nn.Linear(50)
    self.fc2 = nn.Linear(60)
    self.fc3 = nn.Linear(10)
    self.relu = Relu()

  def __call__(self, input):
    fc1_out = self.relu(self.fc1(input))
    fc2_out = self.relu(self.fc2(fc1_out))
    fc3_out = self.fc3(fc2_out)
    return fc3_out


# Brain simulation.
class SNetSimple(SNet):

  def setup(self):
    self.inputsgen = InputTransmitter(size=50, batch_first=False)
    self.pre_pop = ExpLIF(size=50, V_rest=0., V_th=1.0, V_reset=-60., tau=20., I_e=0., v_init=0.)
    self.post_pop = ExpLIF(size=50, V_rest=0., V_th=1.0, V_reset=-60., tau=20., I_e=0., v_init=0.)
    self.pre_conn = StaticSynapse(self.inputsgen, self.pre_pop, conn=One2One(), weight=1)
    self.post_conn = StaticSynapse(self.inputsgen, self.post_pop, conn=One2One(), weight=1)
    self.pop_conn = StaticSynapse(self.pre_pop, self.post_pop, conn=One2One(), weight=0.98)


# hybrid model.
class Network(nn.Module):

  def setup(self):
    self.pre_net = MNISTModelPre()
    self.post_net = MNISTModelPost()
    self.sim_net = SNetLayer(SNetSimple)

  def __call__(self, input, t=0.5):
    input_data = input.reshape(-1, 28 * 28 * 1)
    ann_predict_1 = self.pre_net(input_data)
    sim_output, _ = self.sim_net(input=ann_predict_1[0], t=t, output={'ExpLIF_1':['spike']})
    ann_predict_2 = self.post_net(input=jnp.sum(sim_output['ExpLIF_1']['spike'], axis=0).reshape(1, 50))
    return ann_predict_2

  def accuracy(self, predict, target):
    return jnp.mean(jnp.argmax(predict, axis=1) == target)


def copy_fn(x):
  return jnp.asarray(x)


def create_folder(folder_path):
  if not os.path.exists(folder_path):
    os.makedirs(folder_path)


@jit
def model_run(ann_params, sim_params, data_batch):
  x, y = data_batch
  hybrid_parameters = {"param": ann_params, "frozen_param": sim_params}
  x_buffer = jax.vmap(lambda img: hybrid_model.run(hybrid_parameters, input=img))(x)
  model_output = jnp.squeeze(jnp.asarray(x_buffer))
  model_loss = softmax_cross_entropy(model_output, jnn.one_hot(y, num_classes=10))
  return model_loss, model_output


if __name__ == "__main__":
  hybrid_opt = adam(0.001)
  hybrid_model = Network()
  hybrid_params = hybrid_model.init(input=jnp.ones([1, 28 * 28 * 1]))
  opt_ann_state = hybrid_opt.init(hybrid_params["param"])

  Epochs = 10
  train_ann_accdata = []
  train_ann_lossdata = []
  train_model_accdata = []
  train_model_lossdata = []
  test_ann_accdata = []
  test_ann_lossdata = []
  test_model_accdata = []
  test_model_lossdata = []
  best_score = 0.

  for epoch in range(Epochs):
    # train part.
    train_model_acc = 0.0
    train_model_ls = 0.0

    test_model_acc = 0.0
    test_model_ls = 0.0

    train_loop = tqdm(train_dataloader, desc="Train")
    train_loop.set_description("weight update")

    for batch_id, (x_train, y_train) in enumerate(train_loop):
      loss_grad = grad(model_run, has_aux=True, return_fun_value=True, allow_int=True)
      ann_params = hybrid_params["param"]
      sim_params = hybrid_params["frozen_param"]
      grads_value, loss_value, model_output = loss_grad(ann_params, sim_params, (jnp.asarray(x_train), jnp.asarray(y_train)))
      updates, opt_ann_state = hybrid_opt.update(grads_value, opt_ann_state, hybrid_params["param"])
      hybrid_params["param"] = apply_updates(hybrid_params["param"], updates)
      model_acc = hybrid_model.accuracy(model_output, y_train)
      train_model_acc += model_acc
      train_model_ls += loss_value

      if model_acc > best_score:
        best_score = model_acc
        serialization.save(path="./model", param=hybrid_params, overwrite=True)
    print("Train acc:{}".format(train_model_acc/batch_id), "loss:{}".format(train_model_ls/batch_id))

    test_loop = tqdm(test_dataloader, desc="Test")
    for batch_id, (x_test, y_test) in enumerate(test_loop):
      ann_params = hybrid_params["param"]
      sim_params = hybrid_params["frozen_param"]
      loss_value, model_output = model_run(ann_params, sim_params, (jnp.asarray(x_test), jnp.asarray(y_test)))
      model_acc = hybrid_model.accuracy(model_output, y_test)
      test_model_acc += model_acc
      test_model_ls += loss_value

      if model_acc > best_score:
        best_score = model_acc
    print("Test acc:{}".format(test_model_acc/batch_id), "loss:{}".format(test_model_ls/batch_id))