搭建简单的ANN模型#

作者: LuckyHFC

首先,需要导入 `NeurAI` 包。 用户可以根据需求,使用 neurai.config.set_platform 选择设备, 可以选择 'cpu' 或 'gpu'。

from neurai.config import set_platform

# Select the compute backend for this session; "cpu" is also accepted (see prose above).
set_platform(platform="gpu")

下载数据集#

通过 neurai.datasets.mnist.MNIST 类创建和下载训练和测试数据集,数据集将下载到目录 “./” 。

from neurai import datasets

# Download the MNIST train/test splits into "./" (download=True fetches only if absent).
train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)

加载数据集#

通过传入一些参数(如 dataset、batch_size 和 shuffle 等)去实例化 neurai.datasets.dataloader.DataLoader 类,为 train_data 和 test_data 创建迭代器。

# Mini-batch iterators; the training set is reshuffled each epoch, the test set is not.
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)

定义网络#

继承 neurai.nn.module.Module 类来创建神经网络, 然后通过 neurai.nn.layer.linear.Linear, neurai.nn.layer.activate.Relu 等网络层API来构建网络结构, 在 __call__ 中定义前向计算过程。

from neurai import nn


class MLP(nn.Module):
  """Three-layer fully connected classifier producing 10 output logits."""

  def setup(self):
    # Layer widths: 256 -> 128 -> 10 (one logit per MNIST class).
    self.fc1 = nn.Linear(256)
    self.fc2 = nn.Linear(128)
    self.fc3 = nn.Linear(10)
    self.relu = nn.Relu()

  def __call__(self, inputs):
    # Flatten everything after the batch dimension, then run the layer stack.
    flat = inputs.reshape(inputs.shape[0], -1)
    hidden = self.relu(self.fc1(flat))
    hidden = self.relu(self.fc2(hidden))
    return self.fc3(hidden)

优化器和梯度下降#

通过 neurai.opt 来创建优化器。

from neurai.opt import sgd, apply_updates

# Plain stochastic gradient descent with learning rate 1e-3.
optim = sgd(0.001)

使用 neurai.nn.loss 创建损失函数。

# Cross-entropy on softmax outputs; `out` are model logits and `n_label` one-hot
# labels (both produced inside the training step below). The result is a loss
# *value*, so it is named loss_val — consistent with its use in train_step/test_step.
loss_val = loss.softmax_cross_entropy(out, n_label)

定义一个正确率函数来计算正确率。

def accuracy(predict, target):
  """Fraction of rows whose predicted class matches the true class.

  Both arguments are (batch, num_classes) arrays; the class of a row is the
  argmax along axis 1 (target is expected to be one-hot encoded).
  """
  predicted_classes = jnp.argmax(predict, axis=1)
  true_classes = jnp.argmax(target, axis=1)
  return jnp.mean(predicted_classes == true_classes)

初始化模型#

在训练和验证之前,需要实例化创建的模型,然后初始化这个模型以获得初始权重。

# Instantiate the model and initialize weights from a dummy batch.
# MNIST images are 28x28 single-channel, so the dummy input is (1, 28, 28, 1) —
# the same shape the complete example at the end of this tutorial uses.
# (The original (1, 32, 32, 3) shape was for a 3-channel 32x32 input, not MNIST.)
net = MLP()
params = net.init(jnp.ones((1, 28, 28, 1)))
opt_state = optim.init(params['param'])

训练和验证#

@jax.jit
def train_step(params, batch, opt_state):
  # One jit-compiled optimization step: forward pass, loss, gradients, SGD update.
  # params: parameter tree with a 'param' entry; batch: (data, label) from the
  # DataLoader; opt_state: optimizer state. Returns the updated
  # (params, opt_state, loss_value, accuracy). NOTE: reads module-level globals
  # `net` and `optim`.
  data, label = batch

  def loss(params, data):
    # Forward pass; return_param=True also returns the (possibly updated) params.
    out, params = net.run(params, data, return_param=True)
    # One-hot encode the integer labels over the 10 MNIST classes.
    n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
    acc = accuracy(out, n_label)
    loss_val = softmax_cross_entropy(out, n_label)
    # (acc, params) ride along as auxiliary outputs for value_and_grad below.
    return loss_val, (acc, params)

  # has_aux=True: loss returns (value, aux); gradients are taken w.r.t. `params`.
  (loss_val, (acc_val, params)), grads = jax.value_and_grad(loss, has_aux=True)(params, data)
  # Turn gradients into updates, then apply them to the 'param' subtree.
  updates, opt_state = optim.update(grads['param'], opt_state, params['param'])
  newparam = apply_updates(params['param'], updates)
  params['param'] = newparam
  return params, opt_state, loss_val, acc_val

@jax.jit
def test_step(params, batch):
  # Evaluate one batch without updating parameters; returns (loss, accuracy).
  # train=False runs the network in inference mode. Reads module-level `net`.
  data, label = batch
  out = net.run(params, data, train=False)
  # One-hot encode the integer labels over the 10 MNIST classes.
  n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
  acc_val = accuracy(out, n_label)
  loss_val = softmax_cross_entropy(out, n_label)
  return loss_val, acc_val

整体示例代码如下:

import tqdm
from neurai.config import set_platform
set_platform(platform='gpu')

from neurai.nn import Module, Linear, Relu
from neurai import datasets, opt
from neurai.nn.loss import softmax_cross_entropy
import jax.numpy as jnp
import neurai
import jax


class MLP(Module):
  """MLP classifier: two Relu hidden layers (256, 128) and a 10-way output."""

  def setup(self):
    self.fc1 = Linear(256)
    self.fc2 = Linear(128)
    self.fc3 = Linear(10)
    self.relu = Relu()

  def __call__(self, input):
    # NOTE: `input` shadows the builtin, but the name is part of the public
    # interface (net.init is called with input=...), so it is kept.
    flat = input.reshape(input.shape[0], -1)
    h = self.relu(self.fc1(flat))
    h = self.relu(self.fc2(h))
    return self.fc3(h)


def accuracy(predict, target):
  """Classification accuracy: argmax of predictions vs argmax of one-hot labels."""
  hits = jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1)
  return jnp.mean(hits)


@jax.jit
def train_step(params, batch, opt_state):
  # One jit-compiled optimization step: forward pass, loss, gradients, SGD update.
  # params: parameter tree with a 'param' entry; batch: (data, label) from the
  # DataLoader; opt_state: optimizer state. Returns the updated
  # (params, opt_state, loss_value, accuracy). NOTE: reads module-level globals
  # `net` and `optim`.
  data, label = batch

  def loss(params, data):
    # Forward pass; return_param=True also returns the (possibly updated) params.
    out, params = net.run(params, data, return_param=True)
    # One-hot encode the integer labels over the 10 MNIST classes.
    n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
    acc = accuracy(out, n_label)
    loss_val = softmax_cross_entropy(out, n_label)
    # (acc, params) ride along as auxiliary outputs for value_and_grad below.
    return loss_val, (acc, params)

  # has_aux=True: loss returns (value, aux); gradients are taken w.r.t. `params`.
  (loss_val, (acc_val, params)), grads = jax.value_and_grad(loss, has_aux=True)(params, data)
  # Turn gradients into updates, then apply them to the 'param' subtree.
  updates, opt_state = optim.update(grads['param'], opt_state, params['param'])
  newparam = opt.apply_updates(params['param'], updates)
  params['param'] = newparam
  return params, opt_state, loss_val, acc_val


@jax.jit
def test_step(params, batch):
  """Evaluate one batch without updating parameters; returns (loss, accuracy)."""
  data, label = batch
  # train=False runs the network in inference mode — consistent with the
  # test_step snippet earlier in this tutorial, which passes it explicitly.
  out = net.run(params, data, train=False)
  # One-hot encode the integer labels over the 10 MNIST classes.
  n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
  acc_val = accuracy(out, n_label)
  loss_val = softmax_cross_entropy(out, n_label)
  return loss_val, acc_val


# Data pipeline: download MNIST and wrap both splits in mini-batch loaders.
train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)
# Optimizer, model instance, and parameter/optimizer-state initialization.
optim = opt.sgd(0.001)
net = MLP()
# MNIST samples are 28x28 grayscale, so a (1, 28, 28, 1) dummy batch fixes shapes.
params = net.init(input=jnp.ones((1, 28, 28, 1)))
opt_state = optim.init(params['param'])
num_epochs = 10
# Training loop: one tqdm progress bar per epoch; train_step is jit-compiled.
for epoch in range(num_epochs):
  with tqdm.tqdm(train_loader) as tepoch:
    tepoch.set_description(f"Training/epoch {epoch}")
    for batch in tepoch:
      params, opt_state, loss_val, acc_val = train_step(params, batch, opt_state)
      tepoch.set_postfix(loss=loss_val, acc=acc_val)

# Evaluation: accumulate per-batch loss/accuracy over the test set.
total_loss, total_acc = 0.0, 0.0
with tqdm.tqdm(test_loader) as tepoch:
  tepoch.set_description("Testing")
  for batch in tepoch:
    loss_val, acc_val = test_step(params, batch)
    total_acc += acc_val
    total_loss += loss_val

# len(test_loader) is the number of batches, so these are per-batch averages.
total_acc /= len(test_loader)
total_loss /= len(test_loader)
print(f"Test {total_loss=}, {total_acc=}")