用 BERT 精简版 DistilBERT+TF.js,提升问答系统 2 倍性能
特邀博文 / 软件工程师 Pierric Cistac;研究员 Victor Sanh;技术主管 Anthony Moi,来自 Hugging Face
Hugging Face (https://huggingface.co/) 是一家 AI 创业公司,旨在通过开发工具来提高社区内的协作效率,并积极参与研究工作,从而为自然语言处理 (NLP) 做出贡献。
NLP 领域充满着困难和挑战,我们认为只有所有参与者彼此分享研究内容和成果,才能攻克难关。于是,我们创建了 Transformers。许多公司的研究人员和工程师都在使用这一领先的 NLP 库,累计下载量超过 200 万。借助此 NLP 库,庞大的国际 NLP 社区将能以多种语言(当然包括英语,同时也包括法语、意大利语、西班牙语、德语、土耳其语、瑞典语、荷兰语、阿拉伯语等等)为不同种类的任务(文本/词条生成、文本分类、问题解答……)快速实验、迭代、创建和发布新模型!目前,Transformers 可提供 300 多种不同的模型。
- Transformers https://github.com/huggingface/transformers
虽然将 Transformers 用于研究场景非常方便,但我们也正在努力将其用在 NLP 的生产方面,寻找及实现可在任意环境中简化采用过程的解决方案。在本文中,我们将展示我们认为可以帮助实现这一目标的一种方法:使用“小型”但性能卓越的模型(例如 DistilBERT),以及针对不同于 Python 的生态系统的框架(例如通过 TensorFlow.js 使用的 Node.js)。
- TensorFlow.js https://tensorflow.google.cn/js
对小型模型的需求:DistilBERT
“低资源”模型是我们较为感兴趣的领域之一,这类模型能够取得与最佳水平 (SOTA) 相近的结果,同时还能保持更小的体量和更快的运行速度。因此,我们创建了 DistilBERT(BERT 的精简版):在参数减少 40%、运行速度提高 60% 的同时,该模型仍能保留 BERT 97% 的性能(据 GLUE 语言理解基准测得)。
不同时期的 NLP 模型及其参数数量
为创建 DistilBERT,我们向 BERT 应用了知识蒸馏技术,因而模型得名 DistilBERT。知识蒸馏是一种压缩技术,由 Hinton 等人提出。该技术通过训练小型模型,来重现较大模型(或模型集合)的行为。
- Hinton 等人 https://arxiv.org/abs/1503.02531
在师生(teacher-student)训练中,我们通过训练学生网络,模仿了教师网络的全部输出分布(即知识)。相较于对硬目标(正确类的独热编码 (one-hot encoding))进行交叉熵训练,我们选择通过对软目标(教师的概率分布)进行交叉熵训练,将知识从教师传递到学生。我们的训练损失因此变为:
其中 t 为来自教师的 logit,s 是学生的 logit
我们的学生网络是 BERT 的小型版本,其中移除了词条类 (Token Type) 嵌入向量和 pooler(用于下一句分类任务)。架构的其余部分则保持不变,同时充分利用学生和教师之间的共有隐藏层的大小,从两层中去除一层以减少层数。我们使用梯度累积,配合动态遮罩对 DistilBERT 进行大批次训练(每批最多 4000 个示例),并移除了下一句预测目标。
这样,我们就可以针对特定的问答任务微调模型。我们利用在 SQuAD 1.1 上微调过的 BERT-cased 模型作为教师,配合知识蒸馏损失便可实现 DistilBERT 的微调。换句话说,问答模型经过蒸馏,便可成为以往使用知识蒸馏预训练完成的语言模型!这样,就会得到很多教师与学生的对应关系:首先由 BERT-cased 教授 DistilBERT-cased,然后由 SQuAD-finetuned BERT-cased 版本“再教一次”,以获得 DistilBERT-cased-finetuned-squad 模型。
- BERT-cased https://github.com/google-research/bert
考虑到网络规模,我们得到的性能结果非常有趣:DistilBERT-cased fine-tuned 模型在开发集上的 F1 得分为 87.1,只比完整的 BERT-cased fine-tuned 模型少 2 分!(F1 得分 88.7)。
如果您想详细了解蒸馏过程,可以参阅我们的专题文章。
- 专题文章 https://medium.com/huggingface/distilbert-8cf3380435b5
独立于语言的格式需求:SavedModel
经过上述处理,我们最终得到的是一个 240MB 的 Keras 文件 (.h5),其中包含 DistilBERT-cased-squad 模型的权重。在这种格式下,模型的架构位于关联的 Python 类中。但是我们的最终目标是尽可能在更多环境中使用此模型(此文中为 Node.js + TensorFlow.js),而 TensorFlow SavedModel 格式非常适合此目标:其本身是一种“序列化”格式,这意味着运行模型所需的所有信息都包含在模型文件中。同时,SavedModel 也是独立于语言的格式,因此我们可以在 Python、JS、C++ 和 Go 中使用。
- Python 类 https://github.com/huggingface/transformers/blob/18eec3a9847da4c879a3af8c5a57e9aaf70adf6d/src/transformers/modeling_tf_distilbert.py#L785
- SavedModel https://tensorflow.google.cn/guide/saved_model
如要将格式转换为 SavedModel,我们首先需要根据模型代码构图。在 Python 中,我们可以使用 tf.function
来达到此目的:
import tensorflow as tf
from transformers import TFDistilBertForQuestionAnswering
distilbert = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad')
callable = tf.function(distilbert.call)v
- tf.function https://tensorflow.google.cn/guide/function
这里,我们将 Keras 模型中调用的函数call
传递给 tf.function
。然后,返回的是一个 callable。借助 get_concrete_function
,我们可以用 callable 跟踪带有特定签名和形状的 call 函数:
concrete_function = callable.get_concrete_function([tf.TensorSpec([None, 384], tf.int32, name="input_ids"), tf.TensorSpec([None, 384], tf.int32, name="attention_mask")])
通过调用 get_concrete_function
,我们将模型的 TensorFlow 算子跟踪编译为由两个形状张量 [None, 384]
(第一个是输入 ID,第二个是注意力遮罩)组成的输入签名。
然后,我们便可将模型保存为 SavedModel 格式:
tf.saved_model.save(distilbert, 'distilbert_cased_savedmodel', signatures=concrete_function)
通过 TensorFlow,只需 4 行代码便能完成格式转换!我们可以使用以下代码来检查生成的 SavedModel 是否包含正确的签名:
saved_model_cli:
$ saved_model_cli show --dir distilbert_cased_savedmodel --tag_set serve --signature_def serving_default
输出:
The given SavedModel SignatureDef contains the following input(s):
inputs['attention_mask'] tensor_info:
dtype: DT_INT32
shape: (-1, 384)
name: serving_default_attention_mask:0
inputs['input_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, 384)
name: serving_default_input_ids:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output_0'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 384)
name: StatefulPartitionedCall:0
outputs['output_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 384)
name: StatefulPartitionedCall:1
Method name is: tensorflow/serving/predict
漂亮!您可以使用此 Colab Notebook,自行体验转换代码的效果。现在,我们可以将 TensorFlow.js 与 SavedModel 配合使用了!
- Colab Notebook https://colab.research.google.com/github/huggingface/node-question-answering/blob/master/DistilBERT_to_SavedModel.ipynb
Node.js 中的 ML :TensorFlow.js
在 Hugging Face,我们坚信,要完全发挥 NLP 的潜力并且让更多人可以轻松使用,必须在生产阶段采用比 Python 使用率更高的其他语言来完成 NLP任务。其 API 要足够简单,让没有机器学习博士学位的软件工程师也可轻松驾驭;Javascript 显然是符合这一条件的语言之一。
利用 TensorFlow.js 提供的 API,与我们之前在 Node.js 中创建的 SavedModel 进行交互将变得非常简单。以下是 NPM 问答包中经过略加简化的 Typescript 代码版本:
const model = await tf.node.loadSavedModel(path); // Load the model located in path
const result = tf.tidy(() => {
// ids and attentionMask are of type number[][]
const inputTensor = tf.tensor(ids, undefined, "int32");
const maskTensor = tf.tensor(attentionMask, undefined, "int32");
// Run model inference
return model.predict({
// “input_ids” and “attention_mask” correspond to the names specified in the signature passed to get_concrete_function during the model conversion
“input_ids”: inputTensor, “attention_mask”: maskTensor
}) as tf.NamedTensorMap;
});
// Extract the start and end logits from the tensors returned by model.predict
const [startLogits, endLogits] = await Promise.all([
result[“output_0"].squeeze().array() as Promise,
result[“output_1”].squeeze().array() as Promise
]);
tf.dispose(result); // Clean up memory used by the result tensor since we don’t need it anymore
- NPM 问答包 https://www.npmjs.com/package/question-answering
请注意,我们用到了 tf.tidy
这个非常有用的 TensorFlow.js 函数,该函数负责在返回模型推断结果时自动清除中间张量,例如 inputTensor
和 maskTensor
。
如何知道自己需要使用 "ouput_0"
和 "output_1"
,以从模型返回的结果中提取开始和结束 logit(回答问题的可能跨度的起点和终点)?只需在导出到 SavedModel 后,运行 saved_model_cli
命令,查看输出的名称即可。
快速易用的分词器:? Tokenizer
构建 Node.js 库时,我们的目标是使 API 尽可能简单。正如上述示例所示,在 TensorFlow.js 帮助下,拥有 SavedModel 可以让模型推理变得非常简单。现在,最困难的部分是将正确格式中的数据传递到输入 ID 和注意力遮罩张量。我们从用户那里收集的数据通常是一个字符串,但是张量需要数字数组,因此我们需要将用户输入的内容词条化。
探索 ? Tokenizer:使用 Rust 编写,是 Hugging Face 正在开发的高性能库。通过该库,您可以非常轻松地使用不同的分词器,例如 BertWordpiece。借助提供的 bindings,您也可以在 Node.js 中使用该库:
const tokenizer = await BertWordPieceTokenizer.fromOptions({
vocabFile: vocabPath, lowercase: false
});
tokenizer.setPadding({ maxLength: 384 }); // 384 matches the shape of the signature input provided while exporting to SavedModel
// Here question and context are in their original string format
const encoding = await tokenizer.encode(question, context);
const { ids, attentionMask } = encoding;
- ? Tokenizer https://github.com/huggingface/tokenizers
- bindings https://www.npmjs.com/package/tokenizers
就这么简单!只需 4 行代码,我们就可以完成对用户输入内容的转换,而转换后的格式可以通过 TensorFlow.js 为模型喂数据。
在 Node.js 中实现强大的问答性能
得益于强大的 SavedModel 格式、用于推理的 TensorFlow.js 以及用于词条化的分词器,我们可以在 NPM 包中提供颇为简单而又功能强大的公共 API,从而实现当初的既定目标:
import { QAClient } from "question-answering"; // If using Typescript or Babel
// const { QAClient } = require("question-answering"); // If using vanilla JS
const text = `
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season.
The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.
As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
`;
const question = "Who won the Super Bowl?";
const qaClient = await QAClient.fromOptions();
const answer = await qaClient.predict(question, text);
console.log(answer); // { text: 'Denver Broncos', score: 0.3 }
性能强大?当然!借助 TensorFlow.js 对 SavedModel 格式的原生支持,我们可以获得非常出色的性能:下方所示的基准是对 Node.js 包和热门 Transformer Python 库的比较,两者运行的是相同的 DistilBERT-cased-squad 模型。正如您所见,速度提升了 2 倍!谁说 JavaScript 运行慢?
短文本指长度在 500 到 1000 个字符之间的文本,长文本指长度在 4000 到 5000 个字符之间的文本。您可以查看 Node.js 基准脚本(Python 版本的脚本与之相同)。基准运行配置:标准 2019 MacBook Pro,系统版本为 macOS 10.15.2
- Node.js 基准脚本 https://github.com/huggingface/node-question-answering/blob/master/scripts/benchmark.js
对于 NLP 来说,现在是一个充满机遇的时刻:一方面,大型模型(例如 GPT2 或 T5)的功能越来越完善;另一方面,相关研究也越来越受到关注,以“缩小”那些性能良好、但笨重又昂贵的模型,蒸馏便是其中一种颇受重视的方法。
此外,利用一些公式工具(例如 Javascript 生态系统中的 TensorFlow.js),让大型开发者社区参与到这场变革中来,NLP 的未来会比以往更激动人心、更便于生产!
如需了解详情,您可以访问我们的 GitHub 代码库 (https://github.com/huggingface)。
- Linux性能分析工具与图形化方法
- MySQL和Oracle中的隐式转换(r6笔记第45天)
- R语言的数据导入与导出(write.table,CAT)
- gqlplus的简单使用(r6笔记第43天)
- Java基础-21(01)总结字符流,IO流编码问题,实用案例必做一遍
- zabbix中配置dg的监控(r6笔记第62天)
- Apache ActiveMQ 远程代码执行漏洞 (CVE-2016-3088)分析
- mysql5.5与mysq 5.6中禁用innodb引擎的方法
- 缓慢的update语句性能分析(r6笔记第61天)
- 一个dg警告发现的硬件问题 (r6笔记第60天)
- mysql几种存储引擎介绍
- Java基础-21(02)总结字符流,IO流编码问题,实用案例必做一遍
- DeDeCMS v5.7 密码修改漏洞分析
- Java基础-20(01)总结,递归,IO流
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Python opencv图像处理基础总结(三) 图像直方图 直方图应用 直方图反向投影
- Python opencv图像处理基础总结(四) 模板匹配 图像二值化
- python pyecharts数据可视化 词云图 仪表盘 水球图
- python jupyter notebook配置 更改默认工作目录 更换皮肤主题 代码字体 大小
- 关于直播卖货系统平台在微信浏览器中音视频播放的问题
- python爬虫 scrapy爬虫框架的基本使用
- Python opencv图像处理基础总结(五) 图像金字塔 图像梯度 Canny算法边缘提取
- python scrapy爬虫练习(1) 爬取豆瓣电影top250信息
- python爬虫 senlenium爬取拉勾网招聘数据
- Python opencv图像处理基础总结(六) 直线检测 圆检测 轮廓发现
- 简单又强大的pandas爬虫 利用pandas库的read_html()方法爬取网页表格型数据
- python pyecharts数据可视化 折线图 箱形图
- Python爬虫 selenium自动化 利用搜狗搜索爬取微信公众号文章信息
- python 办公自动化系列 (1) 从22053条数据中统计断网次数并计算平均断网时间
- Python数据可视化 热力图