模型保存与加载#
作者: MinYao Ni
模型训练后,模型参数一般保存在内存中,需要将模型参数以文件形式保存到磁盘中,才能持久保存模型参数,并在后续的训练微调或推理部署时加载到内存或指定设备上运行。这里介绍本框架下模型保存与加载的方法
模型在训练推理等流程中,需要使用保存与加载功能的场景有:
模型在训练中定期保存模型,方便后续对不同训练阶段的模型继续训练或者进行效果研究
模型训练结束,保存模型方便评估和后续推理使用
加载预训练的模型,对模型进行微调或测试评估
加载训练好的模型进行服务端部署,提供相应服务
针对这些场景,目前本框架提供了保存和加载模型参数的API: neurai.util.save
和 neurai.util.restore
使用 neurai.util.save
保存模型参数,还是以前几章节中训练好的模型为例,代码如下:
# 设置模型保存权重的位置与文件名,这里保存至 "./lenet.bin"
# 将模型权重model_params保存
neurai.util.save("./lenet.bin", model_params)
使用 neurai.util.restore
加载保存的模型参数,并用于推理,代码如下:
param = neurai.util.restore("lenet.bin")
test_loader = datasets.DataLoader(test_dataset, batch_size=1, shuffle=True, drop_last=True)
for data in test_loader:
img, label = data
img = img.astype(jnp.float32)
predict = lenet.run(param, img)
predict_label = jnp.argmax(predict[0])
print("true label: {}, predict label: {}".format(label, predict_label))
break
输出结果为:
true label: [6], predict label: 6
Note
neurai.util.save
和 neurai.util.restore
这两个API实际是基于Python msgpack库实现的。
msgpack的数据格式与JSON类似,可以跨平台、跨操作系统、支持多种语言,在多种语言之间使用,高效压缩。 但msgpack比JSON更快更小。