参数冻结#

作者: LuckyHFC & joyeewang

参数冻结是训练时常用的功能,主要有以下应用场景:

  • 迁移学习

    在迁移学习任务中,通过冻结预训练的模型参数,可以将学习到的知识转移到新的任务中, 避免了从头开始训练,加快了迁移任务的过程。

  • 加速模型训练

    在模型训练过程中,冻结的模型参数不需要更新和优化,并且减少了可训练参数的数量, 从而降低了模型的整体复杂性,有效地加快了模型训练速度,减少了资源开销。

  • 减少过拟合

    虽然冻结部分模型参数会降低模型对训练数据的灵活性,但有助于提高模型对新数据的泛化能力。通过冻结部分参数, 可以更好地提高模型的泛化能力,增强模型的鲁棒性。

    此外,在小样本学习任务中,容易导致模型与训练数据过拟合,学习到的特征对新数据的泛化能力有限。 通过冻结部分模型参数,可以限制模型的复杂性,降低过拟合的风险。

因此,本节将介绍如何在 NeurAI 框架中实现参数冻结任务。在介绍参数冻结之前,用户首先要了解 NeurAI 中的参数结构,详见 快速入门 NeurAI。 冻结参数作用于可训练模型变量参数。NeurAI中的冻结参数可分为以下两部分:

  • 网络定义/初始化阶段

    在某些场景下,用户需要冻结自定义算子或模型中的一些变量参数,如果这个算子继承自某个父类算子,如conv1d继承自conv,可能需要子类能够冻结父类的参数。 在此类场景下,用户可以使用 frozen_params={"key": value} 来表示实例化时需要冻结哪些参数,其中 key 是需要冻结的参数名称, value 表示参数是否被冻结。

    from neurai import nn
    
    
    class Model(nn.Module):
      def setup(self):
        self.fc1 = nn.Linear(784, frozen_params={"weight": True})
        self.fc2 = nn.Linear(256)
        self.fc3 = nn.Linear(10)
    
      def __call__(self, x):
        return self.fc3(self.fc2(self.fc1(x)))
    
  • 模型训练阶段

    对于在模型定义时设置了 frozen_params={"key": value} 的模型参数,梯度更新时会自动忽略这些参数。

    但还有一种常见场景, 用户可能在迭代的某个阶段需要动态地冻结参数,而不是在初始化时指定,要实现此场景的冻结功能, 用户需在训练时根据自己的需求,将对应参数的梯度值设置为None。具体示例如下:

    ...
    grads, loss, acc = get_grads(model.run, param, (jnp.reshape(data, (128, -1)), one_hot(label, 10)))
    grads["param"]["Linear_0"]["weight"]=None
    updates, opt_state = optim.update(grads["param"], opt_state)
    param["param"] = apply_updates(param["param"], updates)
    ...
    

Note

为了更好地支持JIT, NeurAI将网络模型与其参数完全解耦,这也使得用户难以寻找它们的对应关系。目前提供以下解决方案:

  • 用户可以在创建网络时主动指定layer名称,这样在创建参数字典时会使用用户指定的名称。

  • 在调试模式下,网络初始化后,预先查看实例化模型对象与参数字典的对应关系。

当然,解决方案可能缺乏用户友好性,将来可能会出现更好的解决方案。