0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看威廉希尔官方网站 视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

RLHF实践中的框架使用与一些坑 (TRL, LMFlow)

深度学习自然语言处理 来源:Hugging Face 2023-06-20 14:36 次阅读

1 前言

之前看见文章总结了常见的一些 RLHF 框架的经验, 但是似乎没看见 Hugging Face 自己维护的 TRL 库的相关文章, 正好最近调 TRL 比较多, 就想写一个文章分享一下使用过程中踩到的坑,另外也介绍一下我们的全流程框架 LMFlow 。

29d5dc40-0f2b-11ee-962d-dac502259ad0.png

LMFlow 框架示意图。

我们主要用一个具体的例子展示如何在两个框架下做RLHF,并且记录下训练过程中我们踩到的主要的坑。这个例子包括完整的SFT,奖励建模和 RLHF, 其中RLHF包括通过 RAFT 算法(Reward rAnked FineTuning)或者TRL-PPO 对齐模型两个部分。为了方便用户,我们已经在 Hugging Face repo 中提供了一个基于 GPT-Neo-2.7B 的奖励模型,因此也可以先跳过奖励建模。

这个例子是基于仅适用于非商业用途的许可的 LLaMA 构建的, 为了使用LLaMA-7B 模型, 大家需要填写前面的 request form。测试的环境是 8 X A100 (40G)。

1.1 环境准备

LMFlow 的安装包中也包含了 TRL, 所以我们只需要按照官方的示例安装 LMFlow 即可。

git clone https://github.com/OptimalScale/LMFlow.git
cd LMFlow
conda create -n lmflow python=3.9 -y
conda activate lmflow
conda install mpi4py
pip install -e .

以上安装自动会把依赖的 PyTorch 等包也一起安装, 除此之外, 我们额外手动安装一下 matplotlib 这个包

1.2 数据集描述

我们使用Dahoas/full-hh-rlhf数据集作为例子,其中每个数据集样本包括一个提示和来自助手的两个回应。特别地,标记为 "chosen" 的回应相对于标记为 "rejected" 的回应更被人类所喜欢。数据集包括 112K 个训练样本和 12.5K 个测试样本。以下是数据集的一个示例样本:

" Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

为了便于训练,我们在字符开头添加 ``###'' 来重新构建提示,以便模型知道要回复。新样本的格式将是:

"###Human: What kind of noises did dinosaurs make? ###Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be ###Human: yes they did ###Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. ###Human: you cant read ###Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

我们在目录 ./data/hh_rlhf 中准备了所有需要使用的所有数据集,需要通过在 LMFlow目录下中运行以下命令获得:

cd data && ./download.sh hh_rlhf && cd -

2 RLHF 之前的准备

在这一节, 我们需要先完成SFT模型以及 reward model 的训练, 这部分我们使用LMFlow 完成。

2.1 SFT

这是数据集/home/usrname/LMFlow/data/hh_rlhf/sft/hh_rlhf_sft.json的一个示例。我们只使用首选回应,因此我们得到 112K 个训练样本。

{"type": "text_only", "instances": [{"text": "###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant: You’re welcome."}, {"text": "###Human: I'm trying to learn about the salam witch trials###Assistant: If you’re looking for books about witchcraft trials, I can recommend some books for you. But you probably mean you’re looking for more historical information about the Salem witch trials in 1692, and specifically about the salam witch trials in 1692?###Human: What are some good books to learn about the salam witch trials###Assistant: What would you like to learn?  If you’re interested in history, one of the best books is The Witch Persecutions in Colonial America: A History.  If you’re interested in witchcraft as a cultural phenomenon, you might enjoy two excellent books: Religion and the Decline of Magic: Studies in Popular Beliefs in Sixteenth- and Seventeenth-Century England by Keith Thomas and Magic, Witchcraft, and the Otherworld: An Anthropology of Superstition by Jack Goody.  If you’re interested in history specifically as it relates to religion, you might enjoy The Popish Plot, or Prelates' Plot: A History of the Popish Plot in England, by K. J. Everett."}]}

你可以编辑/scripts/run_finetune.sh并修改参数。我们在这里用 GPT-Neo-2.7B 作为一个例子, 你应当把它换成你获得的 llama-7b 模型的地址。

  • --model_name_or_path: EleutherAI/gpt-neo-2.7B

  • --dataset_path: ${project_dir}/data/hh_rlhf/sft

  • --output_dir: the path you want to store the sft model

  • --num_train_epochs: 1

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: 根据你的GPU资源调整。

  • exp_id: hh_rlhf_llama_sft

你可以编辑/scripts/run_finetune.sh并修改参数。我们在这里用 GPT-Neo-2.7B 作为一个例子。

然后,我们可以运行以下命令来执行 SFT。

./scripts/run_finetune.sh

你还可以通过以下命令使用 lora 训练,但还需要通过编辑run_finetune_with_lora.sh设置 model_name_or_path 和 dataset。

./scripts/run_finetune_with_lora.sh

下面这个损失图像示例中我们设了 epoch 为4, 但是提前停止并使用一个epoch结束的模型作为SFT模型, 此外我们的logging step 设置为了20, 所以整体看起来会比较平滑

2a010802-0f2b-11ee-962d-dac502259ad0.png

SFT 模型训练曲线, 这个例子截取了1.6个epoch 的训练曲线。

在我的例子中, 得到的SFT模型存储在/home/usrname/LMFlow/output_models/hh_rlhf_llama_sft/checkpoint-1271

2.2 Reward Modeling

我们首先按照 InstructGPT 论文的过程:https://arxiv.org/abs/2203.02155使用 HH-RLHF 数据集训练一个奖励模型,其中包括:

  • 监督微调 (SFT);

  • 通过比较数据集进行奖励建模。

由于PPO有较大的内存压力, 后续实验证明在这个例子的设置里, TRL 的实现无法同时载入7B的RM与7B的训练模型, 因此我们选择使用GPT-Neo-2.7B作为我们的RM。其中监督微调与2.1节是完全一致的, 我们就略去了。

奖励建模涉及根据给定提示对回应进行排序,通常需要人类标注。幸运的是,在这个例子中,我们已经为每个提示提供了两个回应,其中一个优于另一个。为了创建一个新数据集,我们修改原始数据集,包括两个key: "positive" 和 "negative",其中前者是首选回应。

这是数据集LMFlow/data/hh_rlhf/rm/hh_rlhf_rm_training.json的一个示例,其中包含两个样本, 完整的数据集包含了Dahoas/full-hh-rlhf数据集112K 的训练集和12.5K的测试集。

{"instances": [
    {"positive": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: Glad to be helpful.  I’m always happy to assist.", "negative": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: You’re welcome. I’m glad to help. Is there anything else you want to know?"}, 
    {"positive": "###Human: how do i teach my kids to make forts at home?###Assistant: Do you mean you want to teach them to build a simple structure they can hide in? Or something more elaborate?  I don’t know exactly what you mean by a “fort”, but maybe this gives you some ideas?", "negative": "###Human: how do i teach my kids to make forts at home?###Assistant: There are many ways to make forts!  The simplest might be to just pile up all the furniture in one room.  This works if you have a single room in your house, although it might still be a bit challenging to build a taller and sturdier fort this way.  The best way to build forts is to start with something easy, and then have fun exploring how you can improve it."}
]}

你可以编辑./scripts/run_reward_modeling.sh脚本以从上一步的 sft 模型开始训练,并更新 "dataset_path" 以改变使用所需的数据集。我们默认使用LoRA进行训练, 因为我们在多个模型上进行了对比, LoRA由于优化上的优势能节省大量的时间, 并且full training 得到的模型并不能提供明显更有优势的RM, 我们用到的参数如下

  • --model_name_or_path: /home/usrname/LMFlow/output_models/hh_rlhf_rm_sft_gptneo_2_7B/checkpoint-1659

  • --dataset_path: ${project_dir}/data/hh_rlhf/rm/hh_rlhf_rm_training.json

  • --output_dir: the path you want to store the reward model

  • --num_train_epochs: 1

  • --learning_rate: 3e-5

  • --per_device_train_batch_size: adjust according to your GPU memory source.

  • --eval_steps: 400

  • --validation_split_percentage: 10

其中我们会自动使用数据集最后的百分之十样本对RM测试, 注意这里使用的数据集是原数据集中的training set + test set, 所以最后的一部分数据集并没有被模型见到过。在这个例子里, validation_split_percentage不应设大于15, 否则会有一部分SFT中用到的样本被使用进测试集 这些数据集的处理都实现在/examples/run_reward_modeling.py中, 如果你想使用你自己的数据集进行训练RM, 可以在这里根据你的需求进行修改。最后, 我们使用下面的代码进行训练

./scripts/run_reward_modeling.sh

下面是GPT-Neo-2.7B 与 LLaMA-7B 模型训练过程中的 evaluation loss 与 evaluation accuracy 图。

2a3423ea-0f2b-11ee-962d-dac502259ad0.png

奖励模型训练中的evaluation曲线。

我们得到的一些RM 示例

Model Eval Accuracy Remarks
LLaMA-7B 79.52% -
LLaMA-7B 71.64% RM from LLaMA without SFT
GPT-NEO-2.7B 69.24% -
GPT-NEO-1.3B 65.58% Only trained on 10000 samples

可以看到一般来说, 更大的模型的准确率也要更高, 但是因为TRL-PPO会爆OOM的问题 (根据一个同学的反馈, 7B+7B 训练 trlx 的实现也一样是会爆OOM), 我们选择使用2.7B的模型。值得注意的是, 即使是LLaMA-7B模型的准确率也只能达到80%左右, 并且得到的RM很可能无法检测到一些我们所不希望有的pattern (例如重复)并仍然给一个比较高的reward。总而言之, 现在这种做分类得到的奖励模型, 仍然是有很大缺陷的。

最后, 因为我们得到的模型是low-rank 的 LoRA adapter, 我们需要使用*./examples/merge_lora.py* 来获得最终的RM模型。

3 RAFT Alignment

原始论文:RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

3.1 Algorithms Overview

RAFT想法的起源如下, 之前有很多研究都发现了如果训练RM的数据集直接做SFT, 效果不如先去训练RM, 再用RL进行reward learning。一个解释是后者能够有更多的数据进行训练, 但我们注意到前向产生数据本身并不仅仅是PPO专属的。此外, 当时我们花了很多的时间去调PPO, 发现PPO进行训练有容易OOM, 不稳定, 模型效果不确定的一些问题 (我们会在下一节记录中间踩的各种坑), 另外就是我们很多实验发现在垂直领域SFT可以稳定地给模型带来很大的性能提升, 一个自然的想法就是, reward learning 是否可以使用SFT。

具体而言, 我们每轮希望最终获取 b 个新样本进行训练,

  • 为此我们从prompt集合中选取 b x k 个prompt 并输入给当前的模型获得对应的输出;

  • 之后我们给b x k 个样本计算奖励;

  • 我们选取奖励最高的比例为1/k的样本进行SFT训练;

    • ''top'': 第一种方法是全部样本排序选取;

    • ''local'': 第二种方法是每个prompt 重复k 次, 并从这k个样本中选取最高奖励的样本;

    • 第一种会高效一些, 但是在一些场景 (例如这个例子里的实验) 下跨prompt的对比没有意义, 局部的排序会更加合理一些。

  • 新的一轮开始。

这里我们只使用了模型输出的一小部分数据进行训练, 这对forward 运算是坏的, 而对backward 运算是好的。我们观察到, 在我们基于deepspeed的实现下, forward 的batch size 可以开到 backward 的五倍左右, 所以我们认为一次推理的代价应该相对会小一些。

3.2 例子

我们使用之前得到的LLaMA-7B-SFT模型进行训练来作为一个例子, 我们希望记录一个具体的实验过程来说明其中的一些坑, 所以下面会有很多冗余和失败的尝试。

数据准备

我们的训练prompt集合就是Dahoas/full-hh-rlhf训练集中的112K样本去掉回复, 例如:

 "###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant:"

我们额外从测试集里抽出2K用以测试。然而当我们使用这个prompt 集合进行 TRL-PPO的训练的时候 (所以后面为了fair comparison我们重做了实验, 泪目), 我们发现代码能够跑得起来, 但是在第二个epoch总是会爆OOM。Debug 良久之后发现原因是有一些prompt长度很长, 加上我们生成文本也比较长, TRL-PPO需要的memory和路径长度正相关, 因此我们只使用 token 数 < 256 的prompt, 最终得到82147个prompts。

测试LLaMA-7B-SFT

我们首先测试了SFT模型, 发现模型针对一个对话历史会回复多轮的自问自答, 为此我们将生成的回复用``###Human'' 进行截断:

def _clean_text(self, text):
    split_text = [x for x in text.split("###Human") if x]
    return split_text[0].strip().strip("#")

在LMFlow中, 使用的RM在*/LMFlow/examples/raft_align.py* 被指定, 如果你使用的奖励模型是按第二节的方法训练出, 你只给定它所在的本地地址或者 Hugging Face repo id:

reward_model_or_path: Optional[str] = field(
    default="weqweasdas/hh_rlhf_rm",
    metadata={
        "help": (
            "reward model name (huggingface) or its path"
        ),
    },
)

但是如果你的RM是一般性的, 例如 Hugging Face 上的一些分类器, 你可能还需要略微修改``get_reward_function'' 函数。

3.2.1 第一次训练

我们在LMFlow目录下, 使用如下的命令和参数进行训练:

./scripts/run_raft_align.sh
  • --model_name_or_path: /home/usrname/output_models/hh_rlhf_llama-sft (the model get from sft step, adjusted according your setup)

  • --dataset_path:${project_dir}/data/hh_rlhf/rlhf/rlhf_prompt

  • --output_dir: /home/usrname/output_models/hh_rlhf_raft_align

  • --num_train_epochs: 4

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: adjust according to your GPU memory source.

  • --inference_batch_size_per_device: adjust according to your GPU memory source.

  • --num_raft_iteration 20

  • --top_reward_percentage 0.125; (也就是1/8)

  • --raft_batch_size 1024 (每轮最终有1024个样本用来训练)

  • --output_min_length 126

实验运行地很顺利,训练奖励从约2.7提高到3.4,在我们的训练中, 我们监测了模型输出的一些多样性指标,我们注意到部分指标(例如distinct-2)在训练中显著下降,从0.39降至0.22。虽然有一些研究说明alignment tax 导致RLHF 模型的指标往往会变差 (作为human preference 上变好的代价), 但是这样大幅度的下降仍然是不同寻常的。为此, 我们检查了每个迭代时我们生成的样本,并发现如同SFT的测试, 在第一次迭代中,初始检查点的响应中偶尔会包含# (3%左右的样本),而我们的奖励函数无法检测到随机的#,这意味着包含#的响应也可能具有很高的奖励并被选入训练集。随后,情况变得越来越糟糕,最终有一半的响应包含嘈杂的#符号。

3.2.2 第二次训练

为了解决上述问题, 我们修改了代码并检测每个样本的回复是否含有冗余的#, 如果是, 则手动修改为一个低奖励。同时, 在当前的实现中, 我们会输出每一轮用以SFT的数据集用以监测整个训练过程。修改代码之后, 我们得到了如下的奖励曲线 (注意我们在测试的时候会使用比较低的temperature, 所以测试的奖励要高一些):

2a70e24e-0f2b-11ee-962d-dac502259ad0.png

RAFT的训练奖励曲线图, 横坐标表示一次 1) 数据生成 + 2) reward计算与样本排序 + 3) 一轮SFT。

其中横坐标代表的是一个raft的迭代, 包括 1) 数据生成 2) 数据排序 3) 以及在选出的数据集上进行一轮SFT。在我们的例子中, 每一轮会生成8192个样本, 并有1024个样本被使用去SFT。我们可以看到在训练的开始, 用以训练的数据集中的样本 (黄线)比我们模型自身的奖励要高得多, 而在这个小数据集上SFT之后, 模型的奖励开始上升 (绿线和蓝线), 而这反过来也改善了收集到的训练数据 (黄线也在上升)。在 8 x A100 (40G) 上进行如上训练大约需要三个小时。

最终获得的模型在奖励和多样性度量方面都表现良好,我们建议有兴趣的读者参考原始论文了解详细信息。然而,这更像是我们旅程的起点, 我们在最后一部分的讨论里对结果进行进一步的讨论, 在此之前, 我们先记录一下如何使用TRL-PPO进行实验。

4 TRL-PPO Alignment

LMFlow 安装过程中也会把TRL安装所以我们可以直接开始实验,在三个月之前想跑起来TRL需要手动修复几个小bug, 这几天拉了最新版本试验了一下似乎都已经修复了。

数据准备

我们首先修改 TRL-PPO 提供的script里的数据集准备, 注意我们将 TRL-PPO 的script 放在 LMFlow/examples中, 否则你需要稍微修改一下下面数据集的位置:

def build_dataset(config, tokenizer, dataset_name="./data/hh_rlhf/rlhf/rlhf_prompt/prompt.json"):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """

    ds = load_dataset("json", data_files=dataset_name, split="train")['instances'][0]
    texts = [sample['text'] for sample in ds]
    from datasets import Dataset
    ds = Dataset.from_dict({
        "text":texts,
    })
    
    
    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["text"])[:]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds = ds.filter(lambda x: len(x["input_ids"]) <= 256)
    ds.set_format(type="torch")
    print(len(ds))
    return ds

注意这里我们筛选了prompt 数据集, 只保留长度为256个token以内的, 否则过长的文本会导致OOM的错误。

超参数调整

PPO比较依赖于超参数, 不过我几个实验调下来的感觉是TRL默认的参数效果已经很不错了, 即使仔细调整学习率等等也很难获得很大的提升, 需要改的超参数包括:

  • batch_size: 1024/n_gpu, 在我们的设置下为128;

  • mini_batch_size: 一个有意思的发现是PPO的更新batch size 通常要比SFT小不少, 导致它会慢得多, 但不太确定是因为代码实现问题还是PPO本身需要的中间变量比较多的原因;

  • gradient_accumulation_steps: 1

除此之外, 比较关键的在于KL的权重的设置, 我最开始的想法就是简单的去搜, 结果从0.1, 0.05, 0.01 跑了好几轮都不能收敛 (reward 上升一阵后突然垮掉, 或者没有明显的reward 上升)。最后我的选择是先将KL的系数设为0, 然后去修改TRL的ppo_trainer 中的compute_rewards 函数, 打印出这个情况下的KL估计:

    def compute_rewards(
        self,
        scores: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        ref_logprobs: torch.FloatTensor,
        masks: torch.LongTensor,
    ):
        """
        Compute per token rewards from scores and KL-penalty.

        Args:
            scores (`torch.FloatTensor`):
                Scores from the reward model, shape (`batch_size`)
            logprobs (`torch.FloatTensor`):
                Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)
        """
        cnt = 0
        rewards, non_score_rewards = [], []
        for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
            # compute KL penalty (from difference in logprobs)
            kl = logprob - ref_logprob
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            last_non_masked_index = mask.nonzero()[-1]

            # reward is preference model score + KL penalty
            reward[last_non_masked_index] += score
            rewards.append(reward)
            if cnt < 20:
                print(torch.sum(kl))
                cnt += 1
        return torch.stack(rewards), torch.stack(non_score_rewards)

最终发现在reward曲线的后期, KL偏移最高能达到五六百之多, 最后决定设一个比较小的KL=0.001 (和paper [1] 一致)。在一些实验里我们有发现一个比较小的学习率在perplexity指标上会明显好一些。而值得注意的是[1]中设置的学习率要小得多, 文章中汇报的最大KL偏移也只有一两百左右, 我有尝试过5-e6的学习率, 结论是训练变得缓慢了很多 (需要一天多的时间进行训练), 但是并没有对KL偏移有明显改善,由于时间所限, 没有尝试更低的学习率了, 暂时不确定是超参数的设置问题还是TRL-PPO和 [1] 中实现的差异。我建议始终采样一些样本查看它们的KL估计以监测训练是否正常。

此外, 模型有时候回复会过短, 在ppo_trainer中有如下检查会报错, 一个办法是直接注释掉这个报错, 一个办法是对样本进行检测, 丢弃掉回复太短的样本, 两个方法我都试过似乎效果差不多。

def batched_forward_pass(
    ......
    
    if len(logprobs[j, start:end]) < 2:
    raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
    
    ......

需要指出的是, 由于我们需要估计KL, 在TRL-PPO中, 我们不能随意调整生成的设置, 否则将很可能影响KL的估计:

generation_kwargs = {
    # "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,
}

例如, 为了解决上面的回复太短的问题, 我们有尝试设置最短输出长度来强制模型输出更长的回复, 但是设置之后, 我们发现接近一半的KL估计都变为了负数。

训练

在PPO的训练中也会有模型自问自答生成多轮回复的问题, 并且在这个情况下是训不出来的, 所以我们也相应的去截断整个输出, 需要注意的是我们需要对应截断返回来的response_tensors:

output_min_length = 64
output_max_length = 128
output_length_sampler = LengthSampler(output_min_length, output_max_length)
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 1}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors, 
            batch_size=1, ## adjust according to your memory source 
            return_prompt=False, 
            length_sampler=output_length_sampler, 
            **generation_kwargs)

    full_responses = tokenizer.batch_decode(response_tensors)
    clean_texts = [clean_text(tmp_text) for tmp_text in full_responses]
    clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
    lengths = [len(clean_tensor) for clean_tensor in clean_response_tensors]

    response_tensors = [response_tensors[i][:np.max([lengths[i]-2, 1])] for i in range(len(response_tensors))]

    batch["response"] = clean_texts

    texts_for_rewards = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts_for_rewards, **sent_kwargs)
    rewards = [output[0]["score"] for output in pipe_outputs]
    

在进行多番调参之后, 得到的PPO模型有一些奇怪的pattern, 首先PPO模型也会在输出里掺入大量随机的#, 因此需要和RAFT的训练一样加入一个检测来丢弃掉这些样本或者手动给予一个比较负面的奖励, 加入之后, PPO模型输出随机#的现象得到了缓解, 结果PPO开始复读 ``:) '' 这样一个颜表情了, 我试着再次惩罚这样一种在回复中加入大量 :) 的行为, 于是PPO开始复读 ;) 了。。。好在后面两个问题不算太严重,比例比较低,还能接受,由于DRL本身是比较黑箱的方法, 我们不太能直接得知模型倾向于生成这些颜表情的原因, 但我们猜测可能是RM对这类颜表情比较喜好, 使得PPO 利用了这种RM的缺陷。

TRL-PPO默认会使用一个随机的生成长度, 我们尝试了固定128输出长度和随机从[64, 128] 中抽取输出长度两种方式, 发现在其他设置合适的情况下都能学到比较好的reward, 但是后者似乎对于避免输出重复有一定帮助,最终得到的模型输出观感要更好一些。

PPO主要在调参上需要花费比较多的时间, 当参数合适时, 一次训练大概需要8~12个小时。

5 讨论

我们在下面展示一些随机抽样的例子,可以看到不管是 PPO 和 RAFT 都明显改变了模型回复的风格。整体而言, RAFT-aligned 模型通常倾向于用更多的细节回复,PPO 模型会更加礼貌而积极一些, 而 SFT 模型似乎不够 helpful, 很多时候没有按照指示给予建议。同时, 我们也观察到 PPO 会偶尔输出一些无意义的符号, RAFT 的回复有时候冗余的词有一些多。

我们认为这是因为奖励模型无法完全刻画一个回复的质量, 而 PPO 和 RAFT 都在某种程度上利用了奖励模型的这种不完美来获得高奖励。显然, 这只是 RLHF 探索的起始点, 我们还有许多改进的空间。为了进一步提高模型性能,例如, 我们可以改进奖励模型(例如使用 LLaMA-7B-RM), 我们也可以尝试一些更先进的生成策略来提升生成文本的质量 (例如 contrastive search, 见https://zhuanlan.zhihu.com/p/629920420)。同时,请查看我们的 LMFlow 框架,以获取更多 LLMs 的乐趣:

OptimalScale/LMFlow: An Extensible Toolkit for Finetuning and Inference of Large Foundation Models. Large Model for All. (github.com)
https://github.com/OptimalScale/LMFlow

(以下图片由表格转换而来,为了显示方便,Prompt 中的 ###替换成了换行,并以粗体呈现)

2aa5ee58-0f2b-11ee-962d-dac502259ad0.png

2b130808-0f2b-11ee-962d-dac502259ad0.png

2b51d826-0f2b-11ee-962d-dac502259ad0.png

2ba70fbc-0f2b-11ee-962d-dac502259ad0.png

2bfb7ba6-0f2b-11ee-962d-dac502259ad0.png

[1] Training a helpful and harmless 326 assistant with reinforcement learning from human feedback


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 框架
    +关注

    关注

    0

    文章

    403

    浏览量

    17510
  • 模型
    +关注

    关注

    1

    文章

    3261

    浏览量

    48912
  • 数据集
    +关注

    关注

    4

    文章

    1208

    浏览量

    24736

原文标题:RLHF 实践中的框架使用与一些坑 (TRL, LMFlow)

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    使用MDK5时出现过的一些error踩过的分享

    使用MDK5时出现过的一些error踩过的分享
    发表于 12-17 07:49

    分享一些嵌入式系统编程内存操作相关的避指南

    在嵌入式系统的编程,内存操作是我们常用到的,但往往也是易错的地方,怎么避免呢,今天给大家分享一些相关的避指南。数据指针...
    发表于 12-17 07:18

    介绍STM32入一些问题及资料

    介绍STM32入一些问题及资料
    发表于 01-19 06:11

    mpu6050和stm32的一些

    文章目录引言简述下mpu6050和stm32的一些吧MPU6050 I2C读写MPU6050 初始化读取内部温度传感器读取陀螺仪DMP的移植读取三轴角度引言最近玩了IMU模块,看了很多的博客
    发表于 02-10 07:35

    总结一些在编写单片机程序及其他相关实践中学到的C语言技巧

    文章内容  该文章主要是总结一些在编写单片机程序及其他相关实践中学到的C语言技巧,面向读者应具有C语言基础。1. 位运算2. 宏定义3. 字符串4. 数据类型4.1 有符号无符号4.2 布尔类型 变量的类型staticconst...
    发表于 02-24 06:25

    ADμC812 芯片实践中几点注意

    ADμC812 芯片实践中几点注意
    发表于 05-16 14:16 10次下载

    EDA 威廉希尔官方网站 在教学实践中的应用2

    EDA 威廉希尔官方网站 在教学实践中的应用2 摘 要: EDA 在电子威廉希尔官方网站 教学实践中的应用是现代教育发展的种趋势, 本文通过数字式测温仪这综合性设计实例,比较全面地说明
    发表于 12-07 13:50 0次下载

    用实例引起大家在嵌入式做项目时对一些问题的关注

    虽然没有做过产业调查,但从我所见和所招聘人员,从事嵌入式行业的工程师,要么缺乏理论知识,要么缺乏实践经验。很少两者兼备的。究其原因,还是中国的大学教育的问题。这里不探讨这个问题,避免口水战。我想列出我实践中的几个例子。引起大家在嵌入式
    的头像 发表于 01-05 11:31 3913次阅读
    用实例引起大家在嵌入式<b class='flag-5'>中</b>做项目时对<b class='flag-5'>一些</b>问题的关注

    无人机航拍在电视新闻实践中的应用与影响

    民用航拍无人机威廉希尔官方网站 的成熟,让无人机航拍在新闻传播领域的应用成为了可能。航拍无人机作为种全新的新闻采访设备,在电视新闻实践中表现出了定竞争力,囿于行业管理、威廉希尔官方网站 缺陷和无人机飞手等因素,无人机航拍在电视新闻
    发表于 12-22 08:01 4926次阅读

    剖析智能制造关于“轻与重”的实践中的误区

    智能制造无疑是建设制造强国重中之重的核心策略。但还有很多制造企业存在不少关于“轻与重”的认识与实践中的误区!
    的头像 发表于 02-20 10:33 3419次阅读

    光纤涂覆机在科研及工程实践中详细应用步骤(图文)

    光纤涂覆机在科研及工程实践中详细应用步骤(图文)国产光纤涂覆机
    发表于 02-27 14:30 752次阅读

    埋点实践过程遇到的一些问题

    埋点本身现在已经有太多的集成解决方案,神策、诸葛IO、GIO,但是在实践的过程仍然还是会碰都很多问题,这些问题都是躺过的。 01 梳理当前业务,未来业务发展问题,目的是给埋点预留空间 ① 业务
    的头像 发表于 10-22 16:33 1735次阅读

    关于蓝桥杯单片机开发板矩阵键盘的一些

    关于蓝桥杯单片机开发板矩阵键盘的一些
    发表于 11-23 17:36 2次下载
    关于蓝桥杯单片机开发板矩阵键盘的<b class='flag-5'>一些</b><b class='flag-5'>坑</b>

    深度学习框架pytorch入门与实践

    深度学习框架pytorch入门与实践 深度学习是机器学习个分支,它使用多层神经网络对大量数据进行学习,以实现人工智能的目标。在实现深度学习的过程
    的头像 发表于 08-17 16:03 1617次阅读

    科研及工程实践中光纤涂覆机详细操作步骤(图文)

    电子发烧友网站提供《科研及工程实践中光纤涂覆机详细操作步骤(图文).pdf》资料免费下载
    发表于 11-02 15:07 0次下载