模型构建#

作者: MinYao Ni

本文档介绍了如何使用 `NeurAI` 构建深度神经网络。

neurai.nn介绍#

为了定制化深度神经网络,可以使用 neurai.nn 下的API搭建模型,这个目录下定义了丰富的网络层和API,如卷积层相关的 Conv1dConv2d 等,池化层相关的 AvgPoolMaxPool 等,具体见API文档。
neurai 建议以继承类的方式构建网络,并提供了 neurai.nn.Module 作为网络层基类;除此之外,针对一些结构较为简单的网络,也提供了 neurai.nn.Sequential 接口用于快速构建网络。
  • 使用 neurai.nn.Sequential :构建简单的线性(如无跳跃连接的网络)网络结构时,可以选择这个方式,更加简单、代码量更少。

  • 使用 neurai.nn.Module :构建较为复杂的(如跳跃连接)网络结构时,可以选择这个方式。可以更加灵活的构建各类复杂网络,自定义增加除网络层以外的计算逻辑,也可以将Sequential构建的网络作为网络层加入。

neurai.nn.Sequential构建网络#

使用 neurai.nn.Sequential 构建网络时,需要按照模型结构,将网络层按顺序添加到一个 Sequence 中,将这个 Sequence 放到 neurai.nn.Sequential 中即可。
使用 neurai.nn.Sequential 构建LeNet模型的示例代码如下:
import neurai.nn as nn

lenet = nn.Sequential(
  [
    nn.Conv2d(6, 3, padding=(1, 1)),
    nn.Relu(),
    nn.MaxPool((2,2), (2,2)),
    nn.Conv2d(16, 5, padding=(0, 0)),
    nn.Relu(),
    nn.MaxPool((2,2), (2,2)),
    nn.Flatten(),
    nn.Linear(120),
    nn.Linear(84),
    nn.Linear(10),
  ]
)

这种方式在推理时,会按照网络层堆叠顺序完成网络的前向计算过程,因此只能完成简单的线性模型,更复杂的模型建议使用 neurai.nn.Module 形式构建网络结构。

neurai.nn.Module构建网络#

构建较为复杂的网络结构时,可以选择本方案,主要包括三个步骤:
1.创建一个继承自 neurai.nn.Module 的类
2.在 setup 函数中定义需要的网络层
3.在 __call__ 函数中使用定义好的网络层执行前向计算
依旧以LeNet为例,构建网络代码如下:
import neurai.nn as nn

class LeNet(nn.Module):
  num_class:int = 10

  def setup(self):
    self.features = nn.Sequential(
                      [
                        nn.Conv2d(6, 3, padding=(1, 1)),
                        nn.Relu(),
                        nn.MaxPool((2,2), (2,2)),
                        nn.Conv2d(16, 5, padding=(0, 0)),
                        nn.Relu(),
                        nn.MaxPool((2,2), (2,2)),
                        nn.Flatten(),
                        nn.Linear(120),
                        nn.Linear(84)
                      ]
                    )
    if self.num_class > 0:
      self.linear = nn.Linear(self.num_class)

  def __call__(self, input):
    x = self.features(input)
    if self.num_class>0:
      x = self.linear(x)
    return x

lenet=LeNet()

在上面的代码中,将LeNet分为了features和linear两个部分,features用于提取深层特征,linear用于分类。