模型保存与加载#

作者: MinYao Ni

模型训练后,模型参数一般保存在内存中,需要将模型参数以文件形式保存到磁盘中,才能持久保存模型参数,并在后续的训练微调或推理部署时加载到内存或指定设备上运行。这里介绍本框架下模型保存与加载的方法
模型在训练推理等流程中,需要使用保存与加载功能的场景有:
  • 模型在训练中定期保存模型,方便后续对不同训练阶段的模型继续训练或者进行效果研究

  • 模型训练结束,保存模型方便评估和后续推理使用

  • 加载预训练的模型,对模型进行微调或测试评估

  • 加载训练好的模型进行服务端部署,提供相应服务

针对这些场景,目前本框架提供了保存和加载模型参数的API: neurai.util.saveneurai.util.restore

使用 neurai.util.save 保存模型参数,还是以前几章节中训练好的模型为例,代码如下:

# 设置模型保存权重的位置与文件名,这里保存至 "./lenet.bin"
# 将模型权重model_params保存
neurai.util.save("./lenet.bin", model_params)

使用 neurai.util.restore 加载保存的模型参数,并用于推理,代码如下:

param = neurai.util.restore("lenet.bin")
test_loader = datasets.DataLoader(test_dataset, batch_size=1, shuffle=True, drop_last=True)
for data in test_loader:
  img, label = data
  img = img.astype(jnp.float32)
  predict = lenet.run(param, img)
  predict_label = jnp.argmax(predict[0])
  print("true label: {}, predict label: {}".format(label, predict_label))
  break

输出结果为:

true label: [6], predict label: 6

Note

neurai.util.saveneurai.util.restore 这两个API实际是基于Python msgpack库实现的。

msgpack的数据格式与JSON类似,可以跨平台、跨操作系统、支持多种语言,在多种语言之间使用,高效压缩。 但msgpack比JSON更快更小。