模型微调#

微调是指在已有预训练模型的基础上,继续训练模型,以适应特定任务。微调可以提升模型的性能,并减少训练时间。

本文将介绍如何在 NeurAI 框架中使用微调功能。以目前较为火热的大语言模型为例,提供LoRA和Prefix-tuning两种微调方法的调用示例。

LoRA#

LoRA(Low-Rank Adaptation of Large Language Models),大语言模型的低阶适应,这是为了解决大语言模型微调而开发的一项技术。
LoRA会冻结预训练好的模型权重参数,对目标层注入可训练的层,由于不需要对模型的权重参数重新计算梯度,所以大大减少了需要训练的计算量。
NeurAI 框架中,LoRA的调用方法步骤如下:
  • 设置微调参数,实例化微调模型

from neurai.finetune import get_tuning_config, get_tuning_model

args = {
  "tuning_tpye": "lora",
  "r": 8,
  "target_modules": ["c_attn"],
}

# 获取微调设置
tuning_config = get_tuning_config(args)
# 构建微调模型
tuning_model = get_tuning_model(tuning_config)
这里使用了 get_tuning_config 函数来设置微调参数,构建了微调的参数实例。使用 get_tuning_model 函数实例化微调模型。
  • 初始化微调模型

from neurai.util.serialization import restore
from qwen.qwen_modeling import QWenLMHeadModule
from qwen.qwen_config import QWenConfig

# 创建优化器
n_steps = math.ceil(len(dataloader) / n_accumulation_steps)
schedule = exponential_decay(lr, n_steps * n_epochs, decay_rate=0.99)
optimizer = adamw(learning_rate=schedule)
optimizer = MultiSteps(optimizer, n_accumulation_steps)

# 加载模型预训练好的权重
params = restore("path/to/pretrained_model.bin")

# 加载模型参数
with open("path/to/generation_config.json", "r") as f:
  generation_params = json.loads(f.read())
with open("path/to/config.json", "r") as f:
  params = json.loads(f.read())

params.update(generation_params)
qwen_config = QWenConfig(**params)

# 初始化微调模型
qwen_model, qwen_model_run, params, opt_state = tuning_model.init_tuning(QWenLMHeadModule, qwen_config, params, optimizer)
在这一步,创建了优化器,加载了预训练好的权重和模型的配置参数。将这些参数以及待训练的网络传入 init_tuning 函数中,初始化微调模型。
init_tuning 函数中,会冻结原始模型的权重参数,并在目标层增加LoRA节点,同时根据更新后权重状态初始化优化器状态。
  • 微调模型训练

# 定义训练步骤
@jax.value_and_grad
def train_forward(params, data_batch: TrainData):
  seq, seq_mask, labels, labels_mask = data_batch
  _, input_len = seq.shape
  position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape)
  outputs = qwen_model_run(
    params,
    seq,
    seq_mask,
    position_ids,
    labels=labels,
    input_len=input_len,
    return_dict=False,
    use_cache=False,
  )
  losses = outputs[0]
  losses = jnp.mean(losses, where=labels_mask[..., 1:], axis=1)
  loss = jnp.mean(losses)
  return loss


@jax.jit
def train_step(params, opt_state, total_loss: jnp.array, data_batch: TrainData):
  loss, grads = train_forward(params, data_batch)
  total_loss += loss
  params, opt_state = tuning_model.update_params(params, grads, opt_state)
  return params, opt_state, total_loss, loss

# 准备数据集
config_kwargs = {
  "trust_remote_code": True,
  "cache_dir": None,
  "revision": 'main',
  "use_auth_token": None,
}
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=False, padding_side="right", **config_kwargs)

if tokenizer.eos_token is None:
  tokenizer.eos_token = "<|endoftext|>"
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token
data_dir = args.data_dir
dataset = SFTDataset(data_dir, "alpaca_gpt4_data_en.json", tokenizer, int(max_len / 2), int(max_len / 2))
collate_fn = partial(sft_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

# 训练
for epoch in range(n_epochs):
  total_loss = jnp.zeros(())
  with tqdm(dataloader) as tepoch:
    tepoch.set_description(f"Training/epoch {epoch}")
    for batch in tepoch:
      params, opt_state, total_loss, loss = train_step(params, opt_state, total_loss, batch)
      tepoch.set_postfix(loss=loss)
在这个步骤中,定义了训练步骤,准备了数据集,并训练了模型。
训练步骤中, train_forward 函数会进行模型推理,并计算损失。利用 jax.value_and_grad 装饰器,计算损失函数的梯度,并返回损失值和梯度。
train_step 函数中,将计算得到的梯度传入 tuning_model.update_params 函数中,更新LoRA节点的参数。
  • 保存LoRA权重

tuning_model.save_finetune_weight(params)
在这一步,通过调用 tuning_model.save_finetune_weight 函数,保存了LoRA节点的权重。
保存下来的权重占用的磁盘容量相比一个完整的模型权重小很多。方便针对不同任务保存不同的权重,即插即用。
  • 加载LoRA权重

lora_params = restore("path/to/finetuned_model.bin")
params = restore("path/to/finetuned_model.bin")
# 合成带LoRA节点的权重
all_params = tuning_model.restore_tuning_weight(lora_params, params)
# 合并权重
params = tuning_model.merge_params(all_params)
在这一步,加载了LoRA节点的权重,并使用 restore_tuning_weight 函数生成带LoRA节点的权重。
另外也可以使用 merge_params 函数合并LoRA节点和原始权重。经过这个函数后,LoRA节点的权重会被完全融入到模型中,无法再单独保存,原模型权重也不再存在。
合并后权重和带LoRA节点的权重都可以直接用于模型推理。

Prefix-tuning#

Prefix-tuning是一种通过在模型输入中添加特定的前缀,来引导模型生成特定类型的输出。
prefix-tuning在模型输入前添加一个连续的且任务特定的向量序列称之为prefix,固定预训练语言模型(PLM)的所有参数,只更新优化特定任务的prefix。这种微调方案仅适用于PLM。
NeurAI 框架中,Prefix-tuning的调用流程大部分和LoRA一致。
需要在初始化微调模型时,传入 tuning_type 参数改为 prefix_tuning, 并设定 num_virtual_tokens 参数,这表示需要添加的虚拟token数量。如下
  args = {
  "tuning_tpye": "prefix_tuning",
  "num_virtual_tokens": 30,
}

# 获取微调设置
tuning_config = get_tuning_config(args)
# 构建微调模型
tuning_model = get_tuning_model(tuning_config)
训练步骤与LoRA一致。
保存权重和加载权重的步骤也与LoRA一致。
注意:Prefix-tuning和LoRA的区别在于,Prefix-tuning无 merge_params 函数,并且只能使用 tuning_model.init_tuning 初始化的模型进行推理。推理流程如下:
from eval.eval_sft import eval_sft

# 定义推理步骤
def test(qwen_model, params, config, tokenizer):
  # 实例化模型
  mh_model = QWenLMHeadModel(config, _do_init=False)
  # 替换网络模型
  mh_model.module = qwen_model
  # 实例化语言生成器
  model_gen = QWen(
    params=params, model=mh_model, tokenizer=tokenizer, generation_config=GenerationConfig.from_model_config(config))
  model_gen.generation_config.do_sample = False
  # 测试
  acc_val = eval_sft(model_gen, tokenizer)

  return acc_val

# 设置推理模式
qwen_model.tuning_config.inference_mode = True
new_acc = test(qwen_model, params['param'], qwen_config, tokenizer)
在推理阶段,需要将模型的 tuning_config.inference_mode 设置为 True,以便使用推理模式进行推理,不需要更新Prefix层参数。

Note

本框架提供了以Qwen模型为例的LoRA和Prefix-tuning的微调示例。

具体示例代码见examples/ann/LLM/qwen/finetune_qwen.py。