0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看威廉希尔官方网站 视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

project复现过程踩到坑对应的解决方案

深度学习自然语言处理 来源:深度学习自然语言处理 作者:深度学习自然语言 2022-08-19 11:09 次阅读

最近做的一个 project 需要复现 EMNLP 2020 Findings 的 TinyBERT,本文是对复现过程对踩到坑,以及对应的解决方案和实现加速的一个记录。

1. Overview of TinyBERT

BERT 效果虽好,但其较大的内存消耗和较长的推理延时会对其上线部署造成一定挑战

内存消耗方面,一系列知识蒸馏的工作,例如 DistilBERT[2]、BERT-PKD[3] 和 TinyBERT 被提出来用以降低模型的参数(主要是层数)以及相应地减少时间;

推理加速方面,也有 DeeBERT[4]、FastBERT[5] 及 CascadeBERT[6] 等方案提出,它们动态地根据样本难度进行模型的执行从而提升推理效率。其中较具备代表性的是 TinyBERT,其核心框架如下:

ca3400ec-1ea7-11ed-ba43-dac502259ad0.png

分为两个阶段:

General Distillation:在通用的语料,例如 BookCorpus, EnglishWiki 上进行知识蒸馏;目标函数包括 Transformer Layer Attention 矩阵以及 Layer Hidden States 的对齐;

Task Distillation:在具体的任务数据集上进行蒸馏,进一步分成两个步骤:

Task Transformer Disitllation: 在任务数据集上对齐 Student 和已经 fine-tuned Teacher model 的 attention map 和 hidden states;

Task Prediction Distillation:在任务数据集上对 student model 和 teacher model 的 output distritbuion 利用 KL loss / MSE loss 进行对齐。

TinyBERT 提供了经过 General Distillation 阶段的 checkpoint,可以认为是一个小的 BERT,包括了 6L786H 版本以及 4L312H 版本。而我们后续的复现就是基于 4L312H v2 版本的。

值得注意的是,TinyBERT 对任务数据集进行了数据增强操作:通过基于 Glove 的 Embedding Distance 的相近词替换以及 BERT MLM 预测替换,会将原本的数据集扩增到 20 倍。而我们遇到的第一个 bug 就是在数据增强阶段。

2. Bug in Data Augmentation

我们可以按照官方给出的代码对数据进行增强操作,但是在 QNLI 上会报错:

ca6174dc-1ea7-11ed-ba43-dac502259ad0.png

造成数据增强到一半程序就崩溃了,为什么呢?

很简单,因为数据增强代码 BERT MLM 换词模块对于超长(> 512)的句子没有特殊处理,造成下标越界,具体可以参考 #Issue50:error occured when apply data_augmentation on QNLI and QQP dataset[7]。

在对应的函数中进行边界的判断即可:

ca73213c-1ea7-11ed-ba43-dac502259ad0.png

3. Acceleration of Data Parallel

当我们费劲愉快地完成数据增强之后,下一步就是要进行 Task Specific 蒸馏里的 Step 1,General Distillation 了。

对于一些小数据集像 MRPC,增广 20 倍之后的数据量依旧是 80k 不到,因此训练速度还是很快的,20 轮单卡大概半天也能跑完。但是对于像 MNLI 这样 GLUE 中最大的数据集(390k),20 倍增广后的数据集(增广就花费了大约 2 天时间),如果用单卡训练个 10 轮那可能得跑上半个月了,到时候怕不是黄花菜都凉咯。

3.1 多卡训练初步尝试

遂打算用多卡训练,一看,官方的实现就通过 nn.DataParal lel 支持了多卡。好嘛,直接 CUDA_VISIBLE_DEVICES="0,1,2,3" 来上 4 块卡。不跑不知道,一跑吓一跳:

加载数据(tokenize, padding )花费 1小时;

好不容易跑起来了,一开 nvidia-smi 发现 GPU 的利用率都在 50% 左右;

再一看预估时间,大约 21h 一轮,10 epoch 那四舍五入就是一个半礼拜。

好家伙,这我还做不做实验了?

3.2 DDP 替换 DP

这时候就去翻看 PyTorch 文档,发现 PyTorch 现在都不再推荐使用 nn.DataParallel 了,为什么呢?主要原因在于:

DataParallel 的实现是单进程的,每次都是有一块主卡读入数据再发给其他卡,这一部分不仅带来了额外的计算开销,而且会造成主卡的 GPU 显存占用会显著高于其他卡,进而造成潜在的 batch size 限制;

此外,这种模式下,其他 GPU 算完之后要传回主卡进行同步,这一步又会受限于 Python 的线程之间的 GIL(global interpreter lock),进一步降低了效率。

此外,还有多机以及模型切片等 DataParallel 不支持,但是另一个 DistributedDataParallel 模块支持的功能。

所以得把原先 TinyBERT DP(DataParallel)改成 DDP(DistributedDataParallel)。把 DP 改成 DDP 可以参考知乎-当代研究生需要掌握的并行训练技巧[8]。核心的代码就是做一下初始化,以及用 DDP 替换掉 DP

cabdeab4-1ea7-11ed-ba43-dac502259ad0.png

然后,大功告成,一键启动:

cafeb27e-1ea7-11ed-ba43-dac502259ad0.png

启动成功了吗?模型又开始处理数据….

One hours later,机器突然卡住,程序的 log 也停了,打开 htop 一看:好家伙,256G 的内存都满了,程序都是 D 状态,这是咋回事?

4. Acceleration of Data Loading

我先试了少量数据,降采样到 10k,程序运行没问题, DDP 速度很快;我再尝试了单卡加载,虽然又 load 了一个小时,但是 ok,程序还是能跑起来,那么,问题是如何发生的呢?

单卡的时候我看了一眼加载全量数据完毕之后的内存占用,大约在 60G 左右,考虑到 DDP 是多进程的,因此,每个进程都要独立地加载数据,4 块卡 4个进程,大约就是 250 G 的内存,因此内存爆炸,到后面数据的 io 就卡住了(没法从磁盘 load 到内存),所以造成了程序 D 状态。

看了下组里的机器,最大的也就是 250 G 内存,也就是说,如果我只用 3 块卡,那么是能够跑的,但是万一有别的同学上来开程序吃了一部分内存,那么就很可能爆内存,然后就是大家的程序都同归于尽的局面,不太妙。

一种不太优雅的解决方案就是,把数据切块,然后读完一小块训练完,再读下一块,再训练,再读。咨询了一下组里资深的师兄,还有一种办法就是实现一种把数据存在磁盘上,每次要用的时候才 load 到内存的数据读取方案,这样就能够避免爆内存的问题。行吧,那就干吧,但是总不能从头造轮子吧?

脸折师兄提到 huggingface(yyds) 的 datasets[9] 能够支持这个功能,check 了一下文档,发现他是基于 pyarrow 的实现了一个 memory map 的数据读取[10],以我的 huggingface transformers 的经验,似乎是能够实现这个功能的,所以摩拳擦掌,准备动手。

首先,要把增广的数据 load 进来,datasets 提供的 load_dataset 函数最接近的就是 load_dataset('csv', data_file),然后我们就可以逐个 column 的拿到数据并且进行预处理了。

写了一会,发现总是报读取一部分数据后 columns 数目不对的错误,猜测可能原始 MNLI 数据集就不太能保证每个列都是在的,检查了一下 MnliProcessor 里处理的代码,发现其写死了 line[8] 和 line[9] 作为 sentence_a 和 sentence_b。无奈之下,只能采取最粗暴地方式,用 text mode 读进来,每一行是一个数据,再 split:

cb1adf4e-1ea7-11ed-ba43-dac502259ad0.png

写完这个 preprocess_func ,我觉得胜利在望,但还有几个小坑需要解决s:

map 完之后,返回的还是一个 DatasetDict,得手动取一下 train set;

对于原先存在的列,map 函数并不会去除掉,所以如果不用的列,需要手动 .remove_columns()

在配合 DDP 使用的时候,因为 DistributedSample 取数据的维度是在第一维取的,所以取到的数据可能是个 seq_len 长的列表,里面的 tensor 是 [bsz] 形状的,需要在交给 model 之前 stack 一下:

cb45577e-1ea7-11ed-ba43-dac502259ad0.png

至此,只要把之前代码的 train_data 都换成现在的版本即可。

此外,为了进一步加速,我还把混合精度也整合了进来,现在 Pytorch 以及自带对混合精度的支持,代码量也很少,但是有个坑就是loss 的计算必须被 auto() 包裹住,同时,所有模型的输出都要参与到 loss 的计算,这对于只做 prediction 或者是 hidden state 对齐的 loss 很不友好,所以只能手动再额外计算一项为系数为 0 的 loss 项(这样他参与到训练但是不会影响梯度)。

总结

最后,改版过的代码在我的 GitHubfork[11]版本中,我不要脸地起名为fast_td。实际上,改版后的有点有一下几个:

数据加载方面:第一次加载/处理 780w 大约耗时 50m,但是不会多卡都消耗内存,实际占用不到 2G;同时,得益于 datasets 的支持,后续加载不会重复处理数据而是直接读取之前的 cache;

模型训练方面:得益于 DDP 和 混合精度,在 MNLI 上训增强数据 10 轮,3 块卡花费的时间大约在 20h 左右,提速了 10 倍。

审核编辑:彭静
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 模型
    +关注

    关注

    1

    文章

    3255

    浏览量

    48907
  • project
    +关注

    关注

    0

    文章

    35

    浏览量

    13300
  • 数据集
    +关注

    关注

    4

    文章

    1208

    浏览量

    24730

原文标题:4. Acceleration of Data Loading

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    立体智慧仓储解决方案.#云计算

    解决方案智能设备
    学习电子知识
    发布于 :2022年10月06日 19:45:47

    IAP功能实现过程遇到的

    花了四天时间才把IAP功能做好。其中也遇到许多的,这次把这次IAP功能实现过程遇到的把它分享出来。一开始做iap的时候也是先从网上看别人的实现方法,其中就下载了一套别人的程序,不过主控芯片
    发表于 08-05 07:51

    Linux学习过程踩过的与如何解决踩

    Linux踩记录记录Linux学习过程踩过的与如何解决踩1解决方法:F10进入BIOS使能虚拟化威廉希尔官方网站
    发表于 11-04 08:44

    mongoose开发中遇到的解决方案

    1. 本文不对mongoose的功能作陈述,只记录下自己开发中遇到的,及解决方案。嵌入了mongoose的代码编译通过,在调试运行(gdb)时候,却发生了段错误(Segmentation fault),如下所示:...
    发表于 12-16 06:56

    STC8H8K64U芯片学习过程中遇到的问题及对应解决方案

    STC8H8K64U芯片该怎样进行封装呢?STC8H8K64U芯片学习过程中遇到的问题及对应解决方案
    发表于 12-21 06:59

    分享基于STM32 4x4键盘扫描尝试过程踩到的雷

    解决吗??有,当然有了,那就是矩阵键盘扫描,在查阅许多大神博客、资料后有了点眉目便开始尝试,历经千辛万苦终于弄出来了!那喜悦!那开心!下面给大家分享尝试过程踩到的雷。矩阵键盘扫描原理浏览过多篇文章后决定尝试翻转法来进行矩阵键盘扫描,丢出键盘原理图:四行四列共八个IO口,
    发表于 01-05 07:56

    在RT-Thread开发过程中引入watchdog踩到

    今天在RT-Thread完整版开发过程中引入watchdog,踩到一个,系统一直重启,喂狗一直失败,搞了一天才解决,总结一下。我的RT-Thread完整版系统是最新版4.0.3(截止2020年12
    发表于 02-17 06:05

    记录一个在使用BlackBox中parameter踩到

    踩到在很早之前,曾写过如何在SpinalHDL中例化之前用Verilog/SystemVerilog所写的代码,可参照文章《[SpinalHDL——集成你的RTL代码]》一文。在
    发表于 08-31 14:58

    记录BL808 BSP添加GPIO驱动时踩到的一些解决方案

    该文主要记录为 BL808 BSP 添加 GPIO 驱动时踩到的一些解决方案。这是我第一次对接 RT-Thread BSP 的驱动,整理出本文避免之后踩到同样的
    发表于 02-03 14:36

    光端机在使用过程中遇到的常见问题及对应解决方案

    光端机,就是光信号传输的终端设备,我们在使用的过程中难免会碰到一些问题,接下来杭州飞畅的小编为大家详细列举了光端机在使用过程中遇到的一些常见问题以及对应解决方案,感兴趣的朋友就一起来
    的头像 发表于 09-08 15:35 3661次阅读

    使用Redis时可能遇到哪些「」?

    这篇文章,我想和你聊一聊在使用 Redis 时,可能会踩到的「」。 如果你在使用 Redis 时,也遇到过以下这些「诡异」的场景,那很大概率是踩到」了: 明明一个 key 设置了
    的头像 发表于 04-09 11:19 2317次阅读
    使用Redis时可能遇到哪些「<b class='flag-5'>坑</b>」?

    模型调优和复现算法遇到的一些

    的数据增强方式与代码的实现不一样等。(这些可能发生在开源复现者没有“一比一”复现论文的情况,也可能发生在论文作者自己没有实现的情况)
    的头像 发表于 05-18 15:03 1246次阅读

    RLHF实践中的框架使用与一些 (TRL, LMFlow)

    我们主要用一个具体的例子展示如何在两个框架下做RLHF,并且记录下训练过程中我们踩到的主要的。这个例子包括完整的SFT,奖励建模和 RLHF, 其中RLHF包括通过 RAFT 算法(Reward rAnked FineTuni
    的头像 发表于 06-20 14:36 1952次阅读
    RLHF实践中的框架使用与一些<b class='flag-5'>坑</b> (TRL, LMFlow)

    记录为BL808添加GPIO驱动

    该文主要记录为 BL808 BSP 添加 GPIO 驱动时踩到的一些解决方案。这是我第一次对接 RT-Thread BSP 的驱动,整理出本文避免之后踩到同样的
    的头像 发表于 10-13 11:18 641次阅读

    树莓派Pico Flash驱动踩记录

    树莓派 pico 带有 2MB 的 Flash 资源,以下是我基于官方 Pico C/C++ SDK 对接 Flash 驱动时踩到的一些和解决办法。
    的头像 发表于 10-20 11:44 1547次阅读