局部学习#

作者: LuckyHFC & HaiboChen

基于能量的模型是一类具有生物合理性和鲁棒性的神经网络模型,预测编码(Predictive Coding, PC)和均衡传播(Equilibrium Propagation, EP)就是两种代表性的学习方法,通过构建能量函数描述神经网络行为,在解释生物神经系统的信息处理机制方面等方面具有天然优势,神经科学和机器学习领域具有广阔的研究和应用前景。

然而,目前还没任何一款框架能够友好的支持该类研究。研究者需要在网络构建、状态与参数更新等方面需要花费大量的时间。为了能够有效支持与推动该方面的研究,本文将从以下几个方面展开局部学习的介绍。

局部学习介绍#

反向传播算法是深度学习成功的基础,但它需要连续的向后更新和非局部计算,这使得大规模并行化任务具有挑战性,且学习规则与大脑的学习方式不同,不具备生物合理性。以PC和EP为代表的具有神经科学启发的学习算法-局部学习,相对于反向传播算法在异步更新、鲁棒性和联想记忆方面具有巨大优势。

局部学习是一种利用赫布(Hebbian)规则训练神经网络的具有生物合理性的学习算法,其根植于神经科学,被广泛用于特征抽象,促进了许多高级脑启发学习和记忆机制的实现。局部学习算法调整突触权重仅依赖于局部变量,学习规则如下:

\[\triangle w_i = F(v_{i}, v_{i-1}, w_i)\]

考虑具有 \(L+1\) 前馈结构 \(N_0,N_1,…,N_L\) 的深度学习架构 ( 图1. 深度局部学习网络示意图 ),其中 \(N_0\) 是网络输入层, \(N_L\) 是网络输出层。设 \(v_i\) 表示第 \(N_i\) 层的激活值,\(v_i=g(v_{i-1}, w_i)\) ,其中激活函数可以是任意的。我们考虑有监督学习框架,训练集的输入输出表示为向量对的形式 \((v_0, v_L)\) ,最小化目标函数为 \(F\)

Deep local learning

图1. 深度局部学习网络示意图#

图1. 深度局部学习网络示意图 所示,为深度局部学习流程,对于每个网络层都使用局部学习规则进行学习,对于隐藏层而言,学习规则都是无监督的,隐藏层学习规则形式化表示为 \(\triangle w_i=F(v_i, v_{i-1}, w_i)\) ,对于输出层,学习规则是有监督的,因为数据标签被认为是局部变量,输出层学习规则形式化表示为 \(\triangle w_L=F(v_L, v_{L-1}, w_L)\)

基于能量的局部学习#

局部学习算法有许多,如Spike-timing Dependent Plasticity(STDP),Hopfield network,Predictive Coding(PC),Equilibrium Propagation(EP),Contrastive Hebbian Learning (CHL)等。其中基于能量的局部学习算法(Energy Based Local Learning, EBLL)是最具完备理论和性能最突出的一个分支,它包括PC,EP,CHL等。该理论的核心是任意系统都由能量所维持,其中高能量意味着更混乱的系统,低能量意味着更稳定和更有秩序的系统。将系统的能量E作为目标函数,系统通过感知来最小化能量,将一个混乱的系统转换为一个稳定的系统,以完成外部感知的学习。EBLL的能量由模型神经元状态和外部损失函数构成,并且通过顺序执行推理和学习两个阶段来训练神经网络。

推理阶段:

对于任意可微的能量形式,通过梯度下降来使能量最小化。在能量最小时,获得模型最稳定的状态,形式化表示为:

\[dv_i = -\frac{{\partial E}}{{\partial v_i}}\]

学习阶段:

在能量最小时,它准确指向监督损失梯度方向,提取此时能量的瞬时方向用于参数更新,参数更新形式化表示为:

\[d\theta_i = -\frac{{\partial E}}{{\partial \theta_i}}\]

预测编码#

预测编码是一种用于深度神经网络的学习框架,它是基于自由能的局部学习的代表方法,其灵感来源于认知神经科学中的观察和理论。它旨在通过模拟大脑中神经元之间的信息传递和预测机制,实现对输入数据的建模和推理。其核心思想是通过生成预测数据(网络状态值)来解释输入数据,并根据输入数据与预测数据之间的差异(预测误差)不断进行修正和学习,使网络能够根据输入数据的统计特性和模式进行自适应学习,PC的能量模型由所有层的预测误差组成,通过最小化预测误差达到稳定的状态。其能量模型形式化表示为:

\[E = \sum_{i=1}^{L} {\epsilon}_{i}^2 + C\]

其中 \(\epsilon_i = v_i - g_i(v_{i-1}; \theta_i)\) 代表第 \(i^{th}\) 层的真实状态和的从第 \(i-1^{th}\) 层感知到的预测状态之间的差异, \(C\) 为输出层的损失函数。

预测编码的工作流程如下:

  • 步骤1 初始化网络状态:

    基于输入数据进行前向传播,初始化网络状态值,其中前向过程形式化表示为:

    \[v_i = g_i(v_{i-1}, w_i)\]

    其中 \(g_i(,)\) 表示第 \({i}\) 层网络的前向计算过程, \(v_{i-1}\) 为当前层的输入,\(w_i\) 为当前层的模型参数。

  • 步骤2 计算预测误差:

    根据先前的网络状态值与实际输入计算预测误差,预测误差形式化表示为:

    \[\epsilon_i = v_i - g_i(v_{i-1}, w_i)\]

    其中 \(v_i\) 为先前网络状态值, \(g_i(v_{i-1}, w_i)\) 为第 \(i\) 层的网络状态值。

  • 步骤3 更新网络状态:

    通过反向传播算法,我们可以不断地利用预测误差来修正网络的先前状态值。具体来说,网络状态值的更新过程可以表示为:

    \[v_i = v_i - \eta_vdv_i\]

    其中 \(\eta_v\) 表示状态值更新率,\(dv_i=-\frac{{\partial E}}{{\partial v_i}}=\epsilon_i-\epsilon_{i+1}\frac{{\partial g_{i+1}(v_i, w_{i+1})}}{\partial v_i}\) 表示状态值的变化量。

  • 步骤4 最小化能量和更新模型参数:

    不断迭代步骤 步骤1步骤2步骤3 使得能量值达到最小,利用最终的网络状态值及预测误差更新模型参数,模型参数更新形式化表示为:

    \[w_i = w_i - \eta_wdw_i\]

    其中 \(\eta_w\) 表示模型参数的更新率, \(dw_i=-\frac{{\partial E}}{{\partial w_i}}=\epsilon_i-\epsilon_{i+1}\frac{{\partial g_{i+1}(v_i, w_{i+1})}}{\partial w_{i+1}}\) 表示模型参数的梯度。

均衡传播#

均衡传播是一种基于Hopfield能量模型的学习框架,它的灵感来源于大脑中的平衡态理论,旨在模拟大脑中的信息传递和动态平衡的过程,它通过迭代调整和平衡神经元状态,逐步使得能量函数的动力学系统由最开始的不稳定状态逐渐向稳定状态变化,最终达到稳定状态的过程。其能量模型形式化表示为:

\[F = E + \beta C\]

其中 \(F=\frac{1}{2}\sum_{i} \Vert v_i \Vert_2 - \frac{1}{2}\sum_{ij,i \neq j} w_{ij}\rho(v_i)\rho(v_j) - \sum_{i} b_i\rho(v_i)\) 是内部能量,\(v_i\) 为神经元状态值,\(w_{ij}\)\(b_i\) 分别为神经元 \(i\) 和神经元 \(j\) 之间的模型权重和偏置, \(\beta\) 是控制外部能量(损失函数) \(C\) 的驱动强度。

均衡传播两阶段工作流程如下:

  • 步骤1 初始化网络状态:

    基于输入数据进行前向传播,初始化网络状态值,其中前向过程形式化表示为:

    \[v_i = g_i(v_{i-1}, w_i)\]

    其中 \(g_i(,)\) 表示第 \({i}\) 层网络的前向计算过程, \(v_{i-1}\) 为当前层的输入,\(w_i\) 为当前层的模型参数。

  • 步骤2 更新模型状态并达到第一阶段稳态:

    \(\beta=0\) ,模型只考虑内部能量,忽视外部能量的影响,不断更新网络状态,使得模型收敛到平衡点,得到稳定状态 \(s_0^*=[v_0^0, v_1^0, …, v_L^0]\) ,其中 \(L\) 为网络的层数,网络状态更新形式化表示为:

    \[v_i = v_i - \eta_v * dv_i\]

    其中 \(\eta_v\) 表示状态值的更新率,\(dv_i=-\frac{{\partial F}}{{\partial v_i}}=\rho^{'}(v_i)(\sum_(i \neq j)w_{ij} \rho(v_j) + b_i) - v_i - \beta C\) 表示状态值的变化量。

  • 步骤3 计算第一阶段模型梯度:

    利用第一阶段稳定状态 \(s_0^*\) 来计算模型参数梯度,其形式化表示为:

    \[dw_{i0}^{*} = -\frac{{\partial F}}{{\partial w_i}} = \rho(v_i^0)\rho(v_j^0)\]

    其中 \(\rho(,)\) 为激活函数, \(v_i\)\(v_j\) 为相邻的两个状态值。

  • 步骤4 更新模型状态达到第二阶段稳态:

    \(s_0^*\) 作为第二阶段的初始化状态,设 \(\beta \neq 0\) ,再次执行 步骤2步骤3 步骤得到新的稳定状态 \(s_{\beta}^* = [v_0^*, v_1^*, v_L^*]\) 和模型参数梯度 \(dw_{i\beta}^*\)

  • 步骤5 计算最终的参数梯度:

    根据两个阶段得到的参数梯度 \(dw_{i\beta}^*\)\(dw_{i0}^*\) 来计算最终的参数梯度, \(dw_i\) 形式化表示为:

    \[dw_i = \frac{1}{\beta}(dw_{i\beta}^* - dw_{i0}^*) = \frac{1}{\beta}(\rho(v_i^*)\rho(v_j^*) - \rho(v_i^0)\rho(v_j^0)\]

设计细节#

图2. 网络架构图 (a) 的所示,为传统的深度学习网络结构图,其中 \(I\) 为网络的输入, \(T\) 为网络的输出, \(f_i\) 为网络的第 \(i\) 层。然而,以这种形式创建网络在实现局部学习时存在几个挑战,如复杂的误差和梯度计算,极低的运行效率和有限的可扩展性。

Network Structure

图2. 网络架构图#

为了降低用户实现难度,有效提高运行效率,增强局部学习模块的可扩展行性。本模块做了以下改进:

  • (1): 在网络创建阶段

    图2. 网络架构图 (b) 所示,该模块将网络创建抽象为边( \(f_i\) )和节点( \(v_i\) )的形式,其中 \(f\) 表示网络的前向, \(v\) 就是这些层的输出和网络的输入 \(v_0\) 构成的集合。因此,在深度学习算子的基础上,增加了 adj_mat 属性,以存储与该边相连的点,然后将网络结构以邻接矩阵的形式组织。如 图2. 网络架构图 (b) 的跳连网络所示,邻接矩阵如下:

    \[\begin{split}\left[\matrix{ {f_0} & [0, 1] \\ {f_1} & [[1, 2], [1, 3]] \\ {f_2} & [2, 3] \\ {f_3} & [3, 4]}\right]\end{split}\]

    对于 图2. 网络架构图 中的单向、双向网络,具体网络搭建过程如下:

    from neurai import nn
    
    class MLP(nn.RModule):
    
      def setup(self):
        self.fc1 = nn.RLinear(256, adj_mat=[0, 1])
        self.fc2 = nn.RLinear(128, adj_mat=[1, 2])
        self.fc3 = nn.RLinear(128, adj_mat=[2, 3])
        self.fc4 = nn.RLinear(10, adj_mat=[3, 4])
    
      def __call__(self, input):
        y = self.fc1(input)
        y = self.fc2(y)
        y = self.fc3(y)
        y = self.fc4(y)
        return y
    

    对于 图2. 网络架构图 中的跳连网络,具体搭建过程如下:

    from neurai import nn
    
    class CorMLP(nn.RModule):
    
      def setup(self):
        self.fc1 = nn.RLinear(256, adj_mat=[0, 1])
        self.fc2 = nn.RLinear(128, adj_mat=[[1, 2], [1, 3]])
        self.fc3 = nn.RLinear(128, adj_mat=[2, 3])
        self.fc4 = nn.RLinear(10, adj_mat=[3, 4])
    
      def __call__(self, input):
        y0 = self.fc1(input)
        y1 = self.fc2(y0)
        y2 = self.fc3(y1)
        y3 = self.fc4(y2 + y1)
        return y3
    
  • (2): 在网络学习阶段

    为了简化用户操作及针对不同的应用场景,我们提供了 neurai.grads.autograd.PCneurai.grads.autograd.BiPCneurai.grads.autograd.EP 三个梯度求解类,用户只需要调用对应的 get_grads 方法就能完成梯度计算任务。

    此外,为了提高模块扩展性,用户可以按照自己的应用场景,选择所需继承的梯度求解类,并重写能量函数 _energy ,然后调用 get_grads 方法。

    from neurai.grads import EP
    
    class CustomEP(EP):
    
      def _energy(self, *args, **kwargs):
        # add your code.
        pass
    

创建与训练局部学习模型#

下载与加载数据集#

用户可以调用 neurai.datasets 包中相应数据集类(如 neurai.datasets.cifar.CIFAR10 ),将数据集下载到指定目录 DATASETS_DIR

from neurai.datasets import CIFAR10, DataLoader, Compose, Normalize

def get_train_loader(dir_path=DATASETS_DIR, train=True, batch_size=128, shuffle=True, download=True):
  train_data = CIFAR10(
    dir_path,
    train=train,
    download=download,
    transform=Compose([Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]))
  return DataLoader(train_data, batch_size, shuffle, drop_last=True)

def get_test_loader(dir_path=DATASETS_DIR, train=False, batch_size=128, download=True):
  test_data = CIFAR10(
    dir_path,
    train=train,
    download=download,
    transform=Compose([Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]))
  return DataLoader(test_data, batch_size, drop_last=True)

定义网络模型#

定义网络时,需要继承 neurai.nn.rlayer.rmodule.RModule 类,通过 neurai.nn.rlayer.rlinear.RLinearneurai.nn.rlayer.ractivate.RRelu 等其它的算子实现网络构建。在 __call__ 方法中定义网络的前向过程。

from neurai import nn

class ConvDemo(nn.RModule):

 def setup(self):
   self.seq1 = nn.RSequential([
     nn.RConv2d(features=10, kernel_size=3, padding=0, parent=None),
     nn.RTanh(),
     nn.RMaxPool((2, 2), (1, 1), parent=None),
   ], adj_mat=[0, 1])
   self.seq2 = nn.RSequential([
     nn.RConv2d(features=5, kernel_size=3, padding=0, parent=None),
     nn.RTanh(), nn.RFlatten(parent=None)
   ], adj_mat=[1, 2])
   self.seq3 = nn.RSequential([nn.RLinear(50, parent=None), nn.RTanh()], adj_mat=[2, 3])
   self.seq4 = nn.RSequential([nn.RLinear(30, parent=None), nn.RTanh()], adj_mat=[3, 4])
   self.seq5 = nn.RSequential([nn.RLinear(10, parent=None)], adj_mat=[4, 5])

 def __call__(self, input):
   y = self.seq1(input)
   y = self.seq2(y)
   y = self.seq3(y)
   y = self.seq4(y)
   return self.seq5(y)

创建优化器与求导器#

通过 neurai.opt 创建优化器。

from neurai.opt import adam, apply_updates
optim = adam(1e-4)

通过 neurai.grads.autograd.PCneurai.grads.autograd.BiPCneurai.grads.autograd.EPneurai.const.train.PCModeneurai.const.train.EPMode 来创建梯度求解方法。

from neurai.grads.autograd import PC
from neurai.const import PCMode
from neurai.nn import loss
from neurai.util.trans import jit


def accuracy(predict, target):
  return jnp.mean(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))

grad_op = PC(loss_f=loss.softmax_cross_entropy, acc_f=accuracy, mode=PCMode.STRICT)

模型训练与测试#

# create and initialize model.
net = ConvDemo()
param = net.init(input=jnp.ones([128, 32, 32, 3]))
out = net.run
model = net.bind(param)

# initialize opt_state and get test_loader, train_loader.
opt_state = optim.init(param)
test_loader = get_test_loader()
train_loader = get_train_loader()

# apply jit to grad_op.get_grads
get_grads = jit(grad_op.get_grads, static_argnums=(0,))

# training and testing
for i in range(5):
  # training
  for _, (data, label) in enumerate(train_loader):
    grads, train_loss, train_acc = get_grads(model, param,
    batch_data=(jnp.asarray(data), jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)))
    updates, opt_state = optim.update(grads, opt_state, param)
    param = apply_updates(param, updates)
    print("{}/{}: train_acc:{:.4f} \t train_loss:{:.4f}".format(i, _, train_acc, train_loss))

  # testing
  test_loss_list, test_acc_list = [], []
  for _, (data, label) in enumerate(test_loader):
    label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
    predict = out(param, data)
    test_acc, test_loss = grad_op.acc_f(predict, label), grad_op.loss_f(predict, label)
    test_loss_list.append(test_loss)
    test_acc_list.append(test_acc)
  print("{}: test_acc:{:.4f} \t test_loss:{:.4f}".format(i, jnp.mean(test_acc), jnp.mean(test_loss)))

完整例子如下:

from neurai import nn
from neurai.nn import loss
from neurai.datasets import CIFAR10, DataLoader, Compose, Normalize
from neurai.opt import adam, apply_updates
from neurai.grads.autograd import PC
from neurai.const import PCMode
from neurai.util.trans import jit
import jax.numpy as jnp

DATASETS_DIR = '/datasets/'


class ConvDemo(nn.RModule):

  def setup(self):
    self.seq1 = nn.RSequential([
      nn.RConv2d(features=10, kernel_size=3, padding=0, parent=None),
      nn.RTanh(),
      nn.RMaxPool((2, 2), (1, 1), parent=None),
    ], adj_mat=[0, 1])
    self.seq2 = nn.RSequential([
      nn.RConv2d(features=5, kernel_size=3, padding=0, parent=None),
      nn.RTanh(), nn.RFlatten(parent=None)
    ], adj_mat=[1, 2])
    self.seq3 = nn.RSequential([nn.RLinear(50, parent=None), nn.RTanh()], adj_mat=[2, 3])
    self.seq4 = nn.RSequential([nn.RLinear(30, parent=None), nn.RTanh()], adj_mat=[3, 4])
    self.seq5 = nn.RSequential([nn.RLinear(10, parent=None)], adj_mat=[4, 5])

  def __call__(self, input):
    y = self.seq1(input)
    y = self.seq2(y)
    y = self.seq3(y)
    y = self.seq4(y)
    return self.seq5(y)

def get_train_loader(dir_path=DATASETS_DIR, train=True, batch_size=128, shuffle=True, download=True):
  train_data = CIFAR10(
    dir_path,
    train=train,
    download=download,
    transform=Compose([Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]))
  return DataLoader(train_data, batch_size, shuffle, drop_last=True)


def get_test_loader(dir_path=DATASETS_DIR, train=False, batch_size=128, download=True):
  test_data = CIFAR10(
    dir_path,
    train=train,
    download=download,
    transform=Compose([Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]))
  return DataLoader(test_data, batch_size, drop_last=True)


def accuracy(predict, target):
  return jnp.mean(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))


if __name__ == "__main__":
  # create and initialize model.
  net = ConvDemo()
  param = net.init(input=jnp.ones([128, 32, 32, 3]))
  out = net.run
  model = net.bind(param)

  # initialize opt_state and get test_loader, train_loader.
  optim = adam(1e-4)
  opt_state = optim.init(param)
  test_loader = get_test_loader()
  train_loader = get_train_loader()

  # create the PC class add apply jit to get_grads
  grad_op = PC(loss_f=loss.softmax_cross_entropy, acc_f=accuracy, mode=PCMode.STRICT)
  get_grads = jit(grad_op.get_grads, static_argnums=(0,))

  for i in range(5):
    # training
    for _, (data, label) in enumerate(train_loader):
      grads, train_loss, train_acc = get_grads(model, param,
      batch_data=(jnp.asarray(data), jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)))
      updates, opt_state = optim.update(grads, opt_state, param)
      param = apply_updates(param, updates)
      print("{}/{}: train_acc:{:.4f} \t train_loss:{:.4f}".format(i, _, train_acc, train_loss))

    # testing
    test_loss_list, test_acc_list = [], []
    for _, (data, label) in enumerate(test_loader):
      label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
      predict = out(param, data)
      test_acc, test_loss = grad_op.acc_f(predict, label), grad_op.loss_f(predict, label)
      test_loss_list.append(test_loss)
      test_acc_list.append(test_acc)
    print("{}: test_acc:{:.4f} \t test_loss:{:.4f}".format(i, jnp.mean(test_acc), jnp.mean(test_loss)))