Building a Simple ANN Model#
Author: LuckyHFC
First, import the `NeurAI` package.
Use neurai.config.set_platform to select the device as needed; the available options are 'cpu' and 'gpu'.
from neurai.config import set_platform
set_platform(platform="gpu")
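To confirm which accelerators are actually visible after setting the platform, you can query JAX directly (this is plain JAX, not a NeurAI-specific API):

import jax
print(jax.devices())  # e.g. [CudaDevice(id=0)] once the 'gpu' platform is active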
Downloading the Dataset#
Create and download the training and test datasets with the neurai.datasets.mnist.MNIST class; the data will be downloaded to the directory "./".
from neurai import datasets
train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
Loading the Dataset#
Instantiate the neurai.datasets.dataloader.DataLoader class with a few arguments (such as dataset, batch_size, and shuffle) to create iterators over train_data and test_data.
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)
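As a quick sanity check, you can pull one batch from the loader; a minimal sketch, assuming each batch is a (data, label) pair as it is unpacked in the training step below:

data, label = next(iter(train_loader))
print(data.shape, label.shape)  # expected: (32, 28, 28, 1) and (32,) for MNIST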
Defining the Network#
Create the neural network by subclassing neurai.nn.module.Module, build the network structure with layer APIs such as neurai.nn.layer.linear.Linear and neurai.nn.layer.activate.Relu, and define the forward computation in __call__.
from neurai import nn

class MLP(nn.Module):
    def setup(self):
        self.fc1 = nn.Linear(256)
        self.fc2 = nn.Linear(128)
        self.fc3 = nn.Linear(10)
        self.relu = nn.Relu()

    def __call__(self, inputs):
        # Flatten (N, H, W, C) images to (N, H*W*C) before the dense layers.
        inputs = inputs.reshape(inputs.shape[0], -1)
        y = self.relu(self.fc1(inputs))
        y = self.relu(self.fc2(y))
        y = self.fc3(y)
        return y
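Note that each Linear layer above is given only its output width; the input width appears to be inferred from the dummy input passed to net.init in the initialization step below. For a batch of MNIST images (28*28*1 = 784 flattened inputs), the shapes flow as follows:

# (N, 28, 28, 1) --reshape--> (N, 784) --fc1--> (N, 256)
#                 --fc2--> (N, 128) --fc3--> (N, 10) logits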
Optimizer and Gradient Descent#
Create the optimizer via neurai.opt.
from neurai.opt import sgd, apply_updates
optim = sgd(0.001)
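The optimizer follows the init/update/apply cycle used in the training step below; schematically, with the same names as the training code:

# opt_state = optim.init(params['param'])                                       # once, before training
# updates, opt_state = optim.update(grads['param'], opt_state, params['param']) # each step
# params['param'] = apply_updates(params['param'], updates)                     # write back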
Create the loss function with neurai.nn.loss; it is applied to the network output and the one-hot labels inside the training step below.
from neurai.nn.loss import softmax_cross_entropy
loss_val = softmax_cross_entropy(out, n_label)
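As a quick check of the call signature (shapes follow the usage in train_step): with uniform logits the softmax assigns probability 1/10 to every class, so the cross entropy should come out near ln(10) ≈ 2.303 under the standard definition. A minimal sketch:

import jax.numpy as jnp
out = jnp.zeros((4, 10))     # a batch of 4 rows of uniform logits
n_label = jnp.eye(10)[:4]    # one-hot targets for classes 0..3
print(softmax_cross_entropy(out, n_label))  # expected near ln(10) ≈ 2.303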
Define an accuracy function to compute the classification accuracy.
import jax.numpy as jnp

def accuracy(predict, target):
    # Both predict and target are (N, 10); compare the argmax class indices.
    return jnp.mean(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))
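For example, with two predictions of which only the first matches its one-hot target, the function returns 0.5:

pred = jnp.array([[0.1, 0.9], [0.8, 0.2]])
target = jnp.array([[0., 1.], [0., 1.]])
print(accuracy(pred, target))  # 0.5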
Initializing the Model#
Before training and validation, instantiate the model and initialize it to obtain the initial weights.
net = MLP()
# MNIST images are 28x28 with a single channel, so initialize with a matching dummy input.
params = net.init(input=jnp.ones((1, 28, 28, 1)))
opt_state = optim.init(params['param'])
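To see what was created, you can inspect the shapes in the parameter tree with a standard JAX utility (the 'param' key follows the usage above):

import jax
print(jax.tree_util.tree_map(lambda p: p.shape, params['param']))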
Training and Validation#
@jax.jit
def train_step(params, batch, opt_state):
    data, label = batch

    def loss(params, data):
        out, params = net.run(params, data, return_param=True)
        # Convert integer labels to one-hot vectors via a broadcast comparison.
        n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
        acc = accuracy(out, n_label)
        loss_val = softmax_cross_entropy(out, n_label)
        return loss_val, (acc, params)

    (loss_val, (acc_val, params)), grads = jax.value_and_grad(loss, has_aux=True)(params, data)
    updates, opt_state = optim.update(grads['param'], opt_state, params['param'])
    newparam = apply_updates(params['param'], updates)
    params['param'] = newparam
    return params, opt_state, loss_val, acc_val
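Note that has_aux=True makes jax.value_and_grad differentiate only the first element of the tuple returned by loss, passing the auxiliary outputs (the accuracy and the updated parameters) through unchanged. A minimal pure-JAX illustration:

import jax, jax.numpy as jnp

def f(x):
    return (x ** 2).sum(), "aux"  # (value to differentiate, auxiliary output)

(val, aux), grad = jax.value_and_grad(f, has_aux=True)(jnp.array([1.0, 2.0]))
print(val, aux, grad)  # 5.0 aux [2. 4.]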
@jax.jit
def test_step(params, batch):
    data, label = batch
    out = net.run(params, data, train=False)
    n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
    acc_val = accuracy(out, n_label)
    loss_val = softmax_cross_entropy(out, n_label)
    return loss_val, acc_val
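Both steps rely on the same broadcast trick to one-hot encode the integer labels: label[:, None] has shape (N, 1), jnp.arange(10) has shape (10,), and comparing them yields an (N, 10) boolean matrix. For example:

label = jnp.array([2, 0])
n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
# n_label[0] -> [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
# n_label[1] -> [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]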
The complete example code is as follows:
import tqdm
from neurai.config import set_platform
set_platform(platform='gpu')
from neurai.nn import Module, Linear, Relu
from neurai import datasets, opt
from neurai.nn.loss import softmax_cross_entropy
import jax.numpy as jnp
import neurai
import jax


class MLP(Module):
    def setup(self):
        self.fc1 = Linear(256)
        self.fc2 = Linear(128)
        self.fc3 = Linear(10)
        self.relu = Relu()

    def __call__(self, input):
        y1 = self.relu(self.fc1(input=input.reshape(input.shape[0], -1)))
        y2 = self.relu(self.fc2(y1))
        y3 = self.fc3(y2)
        return y3


def accuracy(predict, target):
    return jnp.mean(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))


@jax.jit
def train_step(params, batch, opt_state):
    data, label = batch

    def loss(params, data):
        out, params = net.run(params, data, return_param=True)
        n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
        acc = accuracy(out, n_label)
        loss_val = softmax_cross_entropy(out, n_label)
        return loss_val, (acc, params)

    (loss_val, (acc_val, params)), grads = jax.value_and_grad(loss, has_aux=True)(params, data)
    updates, opt_state = optim.update(grads['param'], opt_state, params['param'])
    newparam = opt.apply_updates(params['param'], updates)
    params['param'] = newparam
    return params, opt_state, loss_val, acc_val


@jax.jit
def test_step(params, batch):
    data, label = batch
    out = net.run(params, data, train=False)
    n_label = jnp.asarray(label[:, None] == jnp.arange(10), jnp.float32)
    acc_val = accuracy(out, n_label)
    loss_val = softmax_cross_entropy(out, n_label)
    return loss_val, acc_val


train_data = datasets.MNIST(root="./", train=True, download=True)
test_data = datasets.MNIST(root="./", train=False, download=True)
train_loader = datasets.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = datasets.DataLoader(dataset=test_data, batch_size=32)

optim = opt.sgd(0.001)
net = MLP()
params = net.init(input=jnp.ones((1, 28, 28, 1)))
opt_state = optim.init(params['param'])

num_epochs = 10
for epoch in range(num_epochs):
    with tqdm.tqdm(train_loader) as tepoch:
        tepoch.set_description(f"Training/epoch {epoch}")
        for batch in tepoch:
            params, opt_state, loss_val, acc_val = train_step(params, batch, opt_state)
            tepoch.set_postfix(loss=loss_val, acc=acc_val)

    total_loss, total_acc = 0.0, 0.0
    with tqdm.tqdm(test_loader) as tepoch:
        tepoch.set_description("Testing")
        for batch in tepoch:
            loss_val, acc_val = test_step(params, batch)
            total_acc += acc_val
            total_loss += loss_val
    total_acc /= len(test_loader)
    total_loss /= len(test_loader)
    print(f"Test {total_loss=}, {total_acc=}")