Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

text.md 11 KB

You have to be logged in to leave a comment. Sign In

Bert & Transformer源码详解

参考论文

https://arxiv.org/abs/1706.03762

https://arxiv.org/abs/1810.04805

在本文中,我将以run_classifier.py以及MRPC数据集为例介绍关于bert以及transformer的源码,另外,本文在最后一个部分详细讲解了如何从0到1来跑自己的第一个bert模型。

章节

Demo传参

首先大家拿到这个模型,管他什么原理,肯定想跑起来看看结果,至于预训练模型以及数据集下载。任何时候应该先看官方教程,官方代表着权威,更容易实现,如果遇到问题可以去issues和stackoverflow看看,再辅以中文教程,一般上手就不难了,这里就不再赘述了。

先从Flags参数讲起,到如何跑通demo。

拿到源码不要慌张,英文注释往往起着最关键的作用,另外阅读源码详细技巧可以看源码技巧

"Required Parameters"意思是必要参数,你等会执行时必须向程序里面传的参数。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/

这是官方给的示例,这个将两个文件夹加入了系统路径,本人Ubuntu18.04加了好像也找不到,所以建议将那些文件路径改为绝对路径

task_name --> 这次任务的名称
do_train --> 是否做fine-tune
do_eval --> 是否交叉验证
do_predict --> 是否做预测
data_dir --> 数据集的位置
vocab_dir --> 词表的位置(一般bert模型下好就能找到) 
bert_config --> bert模型参数设置
init_checkpoint --> 预训练好的模型
max_seq_length --> 一个序列的最大长度
output_dir --> 结果输出文件(包括日志文件)
do_lower_case --> 是否小写处理(针对英文)

其他的字面意思

跑不动?

有些时候发现跑demo的时候会出现各种问题,这里简单汇总一下

1.No such file or directory!

这个意思是没找到,你需要确保你上面模型和数据文件的路径填正确就可解决

2.Memory Limit

因为bert参数量巨大,模型复杂,如果GPU显存不够是带不动的,就会出现上图的情形不断跳出。

解决方法

  • 把batch_size,max_seq_length,num_epochs改小一点
  • 把do_train直接false掉
  • 使用优化bert模型,如Albert,FastTransformer

经过本人实证,把参数适当改小参数,如果还是不行直接不做fine-tune就好,这对迅速跑通demo的人来说最有效。


数据篇

这是很多时候我们自己跑别的任务最为重要的一章,因为很多时候模型并不需要你大改,人家都已经给你训练好了,你在它的基础上进行优化就好了。而数据如何读入以及进行处理,让模型可以训练是至关重要的一步。


数据读入

简单介绍一下我们的数据,第一列为Quality,意思是前后两个句子能不能匹配得起来,如果可以即为1,反之为0。第二,三两列为ID,没什么意义,最后两列分别代表两个句子。

接下来我们看到DataProcessor类,(有些类的作用仅仅是初始化参数,本文不作讲解)。这个类是父类(超类),后面不同任务数据处理类都会继承自它。它里面定义了一个读取tsv文件的方法。

首先会将每一列的内容读取到一个列表里面,然后将每一行的内容作为一个小列表作为元素加到大列表里面。


数据处理

因为我们的数据集为MRPC,我们直接跳到MrpcProcessor类就好,它是继承自DataProcessor。

这里简要介绍一下os.path.join。

我们不是一共有三个数据集,train,dev以及test嘛,data_dir我们给的是它们的父目录,我们如何能读取到它们呢?以train为例,是不是得"path/train.tsv",这个时候,os.path.join就可以把两者拼接起来。

这个意思是任务的标签,我们的任务是二分类,自然为0&1。

examples最终是列表,第一个元素为列表,内容图中已有。


词处理
读取数据之后,接下来我们需要对词进行切分以及简单的编码处理


切分

label_list前面对数据进行处理的类里有get_labels参数,返回的是一个列表,如["0","1"]。

想要切分数据,首先得读取词表吧,代码里面一开始创造一个OrderedDict,这个是为什么呢?

在python 3.5的时候,当你想要遍历键值对的时候它是任意返回的,换句话说它并不关心键值对的储存顺序,而只是跟踪键和值的关联程度,会出现无序情况。而OrderedDict可以解决无序情况,它内部维护着一个根据插入顺序排序的双向链表,另外,对一个已经存在的键的重复复制不会改变键的顺序。

需要注意,OrderedDict的大小为一般字典的两倍,尤其当储存的东西大了起来的时候,需要慎重权衡。

但是到了python 3.6,字典已经就变成有序的了,为什么还用OrderedDict,我就有些疑惑了。如果说OrderedDict排序用得到,可是普通dict也能胜任,为什么非要用OrderedDict呢?

在tokenization.py文件中提供了三种切分,分别是BasicTokenizer,WordpieceTokenizer和FullTokenizer,下面具体介绍一下这三者。

在tokenization.py文件中遍布convert_to_unicode,这是用来转换为unicode编码,一般来说,输入输出不会有变化。

这个方法是用来替换不合法字符以及多余的空格,比如\t,\n会被替换为两个标准空格。

接下来会有一个_tokenize_chinese_chars方法,这个是对中文进行编码,我们首先要判断一下是否是中文字符吧,_is_chinese_char方法会进行一个判断。

如果是中文字符,_tokenize_chinese_chars会将中文字符旁边都加上空格,图中我也有引例注释。

whitespace_tokenize会进行按空格切分。

_run_strip_accents会将变音字符替换掉,如résumé中的é会被替换为e。

接下来进行标点字符切分,前提是判断是否是标点吧,_is_punctuation履行了这个职责,这里不再多说。

以上便是BasicTokenizer的内容了。

接下来是WordpieceTokenizer了,其实这个词切分是针对英文单词的,因为汉字每个字已经是最小的结构,不能进行切分了。而英文还可以进行切分,英文有不同语态,如loved,loves,loving等等,这个时候WordpieceTokenizer就能发挥作用了。

  • 遍历一个英文单词里面的小结构,如果发现在词表里找到,就把这个切掉
  • 对未被切分的部分继续进行步骤一,直至所有都被切分干净,注意除了第一个,其他的前面都要加上"##"

下面有个gif可以直观显示,来源

最后是FullTokenizer,这个是两者的集成版,先进行BasicTokenizer,后进行WordpieceTokenizer。当然了,对于中文,就没必要跑WordpieceTokenizer。

下面简单提一下convert_by_vocab,这里是将具体的内容转换为索引。

以上就是切分了。


词向量编码

刚刚对数据进行了切分,接下来我们跳到函数convert_single_example,进一步进行词向量编码。

这里是初始化一个例子。input_ids 是等会把一个一个词转换为词表的索引;segment_ids代表是前一句话(0)还是后一句话(1),因为这还未实例化,所以is_real_example为false。

此处tokenizer.tokenize是FullTokenizer的方法。

不同的任务可能含有的句子不一样,上面代码的意思就是若b不为空,那么max_length = 总长度 - 3,原因注释已有;若b为空,则就需要减去2即可。

_truncate_seq_pair进行一个截断操作,里面用了pop(),这个是列表方法,把列表最后一个取出来,英文注释也说了为什么没有按照比例截断,若一个序列很短,那按比例截断会流失信息较多,因为比例是长短序列通用的。同时,_truncate_seq_pair还保证了a,b长度一致。若b为空,a则不需要调用这个方法,直接列表方法取就好。

Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...