【NLP笔记】预训练+微调范式之OpenAI Transformer、ELMo、ULM-FiT、Bert..

文章目录

  • OpenAI Transformer
  • ELMo
  • ULM-FiT
  • Bert
    • 基础结构
    • Embedding
    • 预训练&微调

【原文链接】: BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding

【本文参考链接】

  • The Illustrated BERT, ELMo, and co. (How NLP Cracked Transfer Learning)
  • 【NLP从入门到大模型】5.图解Bert

【其他值得关注的预训练+微调范式模型】

  • 【深度学习】GPT-1
  • GPT系列:GPT-2详解
  • RoBERTa 详解(Bert改进模型)
  • XLNet模型(Bert改进模型)
  • ALBert 详解(Bert改进模型)

2018年前后有很多著名的NLP相关的论文产出,形成了一种预训练语义表征模型+下游任务微调的范式,其中长期内影响力较大的之一Bert。

OpenAI Transformer

OpenAI Transformer(OpenAI的Transformer模型)是在2017年底提出的。它是由OpenAI团队提出的一种基于Transformer架构的语言模型,用于处理自然语言处理任务。OpenAI Transformer是一个堆叠的12层Decoder结构(Decoder-Only),在解码器的注意力模块中,由于没有Encoder,所以就不是Cross-Attention,改为self-attention+mask的方式。
在这里插入图片描述
整个训练过程是无监督的,使用大量的未标记数据集训练,将预测下一个词汇作为模型的预测目标。然后预训练好的Transformer结构(Decoder-Only)后面再加上简单的网络结构进行参数调整,并设定目标场景的预测结果,就实现了通过微调去适应下游任务。
在这里插入图片描述

ELMo

ELMo(Embedding for Language Model)是由斯坦福大学的研究团队在2018年提出的,是一种基于双向 LSTM(长短期记忆)模型的语言表示方法,通过在大规模语料库上训练的双向 LSTM 模型来生成深度上下文相关的词向量。
在这里插入图片描述

大语言模型通过海量数据训练出来的Tokenizer可直接用于文本向量化,开源的大模型Tokenizer使得大家能通过下载已经开源的词嵌入模型来完成NLP的任务。Word2vec是将单词固定为指定长度的向量,和以往的经典词嵌入模型不同的是,ELMo则是在为每个单词分配词向量之前先查看整个句子,然后使用bi-LSTM结合上下文语境来训练它对应的词向量。
在这里插入图片描述

ELMo会训练一个模型,这个模型接受一个句子或者单词的输入,输出最有可能出现在后面的一个单词,采用海量语料进行无监督训练。ELMo通过下图的方式将hidden states的状态组合起来提炼出具有语境意义的词嵌入方式(全连接后加权求和)。
在这里插入图片描述

ULM-FiT

ULM-FiT(Universal Language Model Fine-tuning)是一种用于迁移学习的技术,由 Jeremy Howard 和 Sebastian Ruder 在2018年提出。它基于递归神经网络(RNN)构建的语言模型,并使用了一系列创新的技术来提高迁移学习的效果。

ULM-FiT 的核心思想是首先在大规模的通用语料库上训练一个通用的语言模型,然后通过微调(Fine-tuning)的方式,将这个通用模型应用于特定的任务。具体来说,ULM-FiT 包含以下几个关键步骤:

预训练通用语言模型:首先,在大规模的通用语料库上训练一个通用的语言模型,比如基于AWD-LSTM(ASGD Weight-Dropped LSTM)的语言模型。这个通用模型可以捕捉语言的一般规律和语义信息。

微调:接下来,将预训练的通用模型在特定任务的数据集上进行微调,以适应特定任务的语言模式和特征。微调的过程包括在任务数据上进行进一步的训练,并对模型进行调整以提高任务性能。

多阶段微调策略:ULM-FiT 提出了一种多阶段微调的策略,其中包括逐层解冻和差异学习率调整等技术,以提高微调的效果。这种策略可以使模型逐步适应新任务的特征,从而更好地利用预训练的知识。

ULM-FiT 已经在许多自然语言处理任务中取得了显著的成功,包括文本分类、命名实体识别、情感分析等。它的主要优势在于通过预训练通用模型和微调的方式,能够在少量标注数据的情况下快速有效地解决特定任务。

Bert

BERT 是由Google在2018年10月发布的。它是一种基于Transformer模型的语言表示方法,通过在大规模文本数据上进行预训练来学习文本的语义表示。BERT 通过Masked Language Modeling(MLM)和Next Sentence Prediction(NSP)等任务来训练模型,可以捕捉文本序列中的双向语境信息,取得了在多项自然语言处理任务上的领先性能。Bert的训练模式是半监督预训练(Pre-train)+有监督的微调(fine-tuning),预训练出一个强大的文本表征模型,然后根据具体任务进行微调,实现在特定任务上的应用。Goole开源这个模型,并提供预训练好的模型,这使得所有人都可以通过它来构建一个涉及NLP的算法模型,节约了大量训练语言模型所需的时间、精力、知识和资源。

在这里插入图片描述

基础结构

Bert的基础架构是一个堆叠多层的Encoder架构,参数量更大的Bert就是对应的层数更多(如:Bert-base有12层Encoder,Bert-Large共24层Encoder),Bert就是一种Encoder-Only的架构。
在这里插入图片描述
BERT与Transformer 的编码方式一样。将固定长度的字符串作为输入,数据由下而上传递计算,每一层都用到了self attention,并通过前馈神经网络传递其结果,将其交给下一个编码器。
在这里插入图片描述
每个位置返回的输出都是一个隐藏层大小的向量(Base版本BERT为768),该输出向量可以作为下游微调任务的输入向量,在论文中指出使用单层神经网络作为分类器就可以取得很好的效果。
在这里插入图片描述

Embedding

BERT 模型中的词向量(Word Embedding)、位置编码(Positional Encoding)、以及段落编码(Segment Embedding)是通过一系列的处理步骤得到的:

  1. 词向量(Word Embedding)
    BERT 使用了 WordPiece 或 Byte-Pair Encoding(BPE)等分词算法将输入文本序列分割成词语或子词(subwords)。然后,每个词语或子词都被映射为一个固定长度的向量,这个过程就是词向量。BERT 模型使用了多层的词向量表示,其中包括 WordPiece Embeddings 和 Token Type Embeddings。WordPiece Embeddings 将每个词语或子词映射为一个固定长度的向量,而 Token Type Embeddings 用于区分不同句子之间的关系(同Transformer)。
  2. 位置编码(Positional Encoding)
    在 BERT 模型中,位置编码用于向词向量中添加关于词语或子词在序列中位置的信息。BERT 使用了一种特殊的位置编码方法,通常是通过加法或者叠加来将位置编码添加到词向量中。这种位置编码能够为模型提供关于词语在序列中相对位置的信息,从而帮助模型理解文本的顺序关系(同Transformer)。
  3. 段落编码(Segment Embedding)
    在一些需要处理多个句子的任务中,BERT 还会添加段落编码以区分不同句子之间的关系。段落编码通常是一个固定长度的向量,用于区分不同句子所属的段落或者句子之间的关系。在输入序列中,通过在句子的开头添加特殊的标记(比如 [CLS] 标记)来表示句子的开始,然后添加段落编码来区分不同句子之间的关系(如:第一个句子编码段落token全为0,第二个全为1)。

在这里插入图片描述

预训练&微调

前面有说过Bert的模式也是预训练+下游任务微调,如下图所示:
在这里插入图片描述

  1. 预训练
  • Masked Language Modeling(MLM):通过随机掩盖序列中的一些token,让模型去预测被掩盖的token是什么。如在训练过程中随机mask 15%的token,其中80%的序列我们用[MASK]去取代要掩码的token,10%的序列我们将待掩码token随机替换为其他token,10%的序列保持不变,最终的损失函数只计算被mask掉的那个token。
  • Next Sentence Prediction(NSP):让模型理解两个句子之间的联系。训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,模型预测B是不是A的下一句。预训练的时候可以达到97-98%的准确度。

模型地址:https://huggingface.co/google-bert/bert-base-chinese

# 安装新版本(>=4.0)的transformer,内置bert
from transformers import AutoTokenizer, AutoModelForMaskedLM
# 加载Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
# 加载模型
model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")# MLM任务测试
# BERT 在预训练中引入了 [CLS] 和 [SEP] 标记句子的开头和结尾,准备输入模型的语句
text = '中国的首都是哪里? 北京是 [MASK] 国的首都。'  
inputs = tokenizer(text, return_tensors='pt')
# 获取MASK的位置
mask_position = inputs["input_ids"].tolist()[0].index(tokenizer.mask_token_id)
print("tokenized txt: {}".format(inputs))
# 使用模型进行预测
with torch.no_grad():outputs = bert_model(**inputs).logits# 获取MASK位置的预测分数,并找到预测得分最高的token
predicted_token_idx = outputs[0, mask_position].argmax().item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_token_idx])[0]print(f"The masked token in '{text}' is predicted as '{predicted_token}'")# NSP任务测试
model_path = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForNextSentencePrediction.from_pretrained(model_path)sen_code1 = tokenizer.encode_plus('今天天气怎么样', '今天天气很好')
sen_code2 = tokenizer.encode_plus('小明今年几岁了', '今天天空很蓝')
tokens_tensor = torch.tensor([sen_code1['input_ids'], sen_code2['input_ids']])model.eval()outputs = model(tokens_tensor)
seq_relationship_scores = outputs[0]  # seq_relationship_scores.shape= torch.Size([2, 2])
sample = seq_relationship_scores.detach().numpy()  # sample.shape = (2, 2)pred = np.argmax(sample, axis=1)
print(pred)
# [0 1] 0表示是上下句,1表示不是上下句
  1. 微调
    在这里插入图片描述
    Bert可以实现的几种微调任务:短文本相似 、文本分类、QA机器人、语义标注等微调任务,下面以一个任务为例:
# 下载文本分类数据集
wget http://github.com/skdjfla/toutiao-text-classfication-dataset/raw/master/toutiao_cat_data.txt.zip
import codecs
import jsonimport torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmupclass NewsDataset(Dataset):"""新闻分类数据集"""def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labels# 读取单个样本def __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(int(self.labels[idx]))return itemdef __len__(self):return len(self.labels)class NewsClassifier(object):"""新闻分类器"""@classmethoddef finetuning(cls):"""微调新闻分类器"""model_path = 'bert-base-chinese'# 加载预训练的向量化模型tokenizer = BertTokenizer.from_pretrained(model_path)# 加载预训练的分类模型model = BertForSequenceClassification.from_pretrained(model_path, num_labels=17)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 加载数据集train_dataset, test_dataset = cls.load_dataset(tokenizer)# 单个读取到批量读取train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)# 优化方法optim = torch.optim.AdamW(model.parameters(), lr=2e-5)total_steps = len(train_loader) * 1scheduler = get_linear_schedule_with_warmup(optim,num_warmup_steps=0,  # Default value in run_glue.pynum_training_steps=total_steps)for epoch in range(1):print("------------Epoch: %d ----------------" % epoch)cls.train(model, train_loader, optim, device, scheduler, epoch)cls.validation(model, test_dataloader, device)model.save_pretrained("output/bert-news-clf")tokenizer.save_vocabulary("output/bert-news-clf")@classmethoddef train(cls, model, train_loader, optim, device, scheduler, epoch):"""训练函数"""model.train()total_train_loss = 0iter_num = 0total_iter = len(train_loader)for batch in train_loader:# 正向传播optim.zero_grad()input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs[0]total_train_loss += loss.item()# 反向梯度信息loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)# 参数更新optim.step()scheduler.step()iter_num += 1if iter_num % 100 == 0:print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (epoch, iter_num, loss.item(), iter_num / total_iter * 100))print("Epoch: %d, Average training loss: %.4f" % (epoch, total_train_loss / len(train_loader)))@classmethoddef flat_accuracy(cls, preds, labels):"""计算准确率"""pred_flat = np.argmax(preds, axis=-1).flatten()labels_flat = labels.flatten()return np.sum(pred_flat == labels_flat) / len(labels_flat)@classmethoddef validation(cls, model, test_dataloader, device):"""验证函数"""model.eval()total_eval_accuracy = 0total_eval_loss = 0for batch in test_dataloader:with torch.no_grad():# 正常传播input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs[0]logits = outputs[1]total_eval_loss += loss.item()logits = logits.detach().cpu().numpy()label_ids = labels.to('cpu').numpy()total_eval_accuracy += cls.flat_accuracy(logits, label_ids)avg_val_accuracy = total_eval_accuracy / len(test_dataloader)print("Accuracy: %.4f" % avg_val_accuracy)print("Average testing loss: %.4f" % (total_eval_loss / len(test_dataloader)))print("-------------------------------")@classmethoddef load_dataset(cls, tokenizer):"""加载新闻分类数据集"""data_path = 'dataset/toutiao_cat_data.txt'# 标签news_label = [int(x.split('_!_')[1]) - 100 for x in codecs.open(data_path)]# 存储标签label_text_json = {}for x in codecs.open(data_path):groups = x.split('_!_')label_text_json[int(x.split('_!_')[1]) - 100] = groups[2]with open('output/news_label.json', 'w') as file_obj:json.dump(label_text_json, file_obj, indent=4)# 文本news_text = [x.strip().split('_!_')[-1] if x.strip()[-3:] != '_!_' else x.strip().split('_!_')[-2]for x in codecs.open(data_path)]# 划分为训练集和验证集# stratify 按照标签进行采样,训练集和验证部分同分布x_train, x_test, train_label, test_label = train_test_split(news_text[:50000],news_label[:50000],test_size=0.2,stratify=news_label[:50000])train_encoding = tokenizer(x_train, truncation=True, padding=True, max_length=64)test_encoding = tokenizer(x_test, truncation=True, padding=True, max_length=64)train_dataset = NewsDataset(train_encoding, train_label)test_dataset = NewsDataset(test_encoding, test_label)return train_dataset, test_dataset@classmethoddef predict(cls):"""微调后的模型预测函数"""s = '金价下跌预示着什么?'model_path = 'output/bert-news-clf'tokenizer = BertTokenizer.from_pretrained(model_path)model = BertForSequenceClassification.from_pretrained(model_path, num_labels=17)sen_code = tokenizer.encode_plus(s)tokens_tensor = torch.tensor([sen_code['input_ids']])attention_mask = torch.tensor([sen_code['attention_mask']])model.eval()outputs = model(tokens_tensor, attention_mask)outputs = outputs[0].detach().numpy()print(outputs)outputs = outputs[0].argmax(0)print(outputs)# 4with open('output/news_label.json') as file_obj:news_label_json = json.load(file_obj)print(news_label_json[str(outputs)])# news_financeif __name__ == '__main__':# 微调NewsClassifier.finetuning()# 调用微调的模型进行预测NewsClassifier.predict()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/284532.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

提面 | 面试抽题

学习到更新日期面试抽题-1.2案例分析的思维本质2024-3-23 1提面抽屉论述问题的分类 1.1案例分析占总论 1.2案例分析的思维本质

计算机网络相关

OSI七层模型 各层功能: TCP/IP四层模型 应用层 传输层 网络层 网络接口层 访问一个URL的全过程 在浏览器中输入指定网页的 URL。 浏览器通过 DNS 协议,获取域名对应的 IP 地址。 浏览器根据 IP 地址和端口号,向目标服务器发起一个 TCP…

【进阶五】Python实现SDVRP(需求拆分)常见求解算法——离散粒子群算法(DPSO)

基于python语言,采用经典离散粒子群算法(DPSO)对 需求拆分车辆路径规划问题(SDVRP) 进行求解。 目录 往期优质资源1. 适用场景2. 代码调整3. 求解结果4. 代码片段参考 往期优质资源 经过一年多的创作,目前已…

linux文本三剑客 --- grep、sed、awk

1、grep grep&#xff1a;使用正则表达式搜索文本&#xff0c;将匹配的行打印出来&#xff08;匹配到的标红&#xff09; 命令格式&#xff1a;grep [option] pattern file <1> 命令参数 -A<显示行数>&#xff1a;除了显示符合范本样式的那一列之外&#xff0c;并…

C语言中的联合体和枚举

联合体 联合体的创建 联合体的关键字是union union S {char a;int i; };除了关键字和结构体不一样之外&#xff0c;联合体的创建语法形式和结构体的很相似&#xff0c;如果不熟悉结构体的创建&#xff0c;可以看一下我上一篇的博客关于结构体知识的详解。 联合体的特点 联合…

Personal Website

Personal Website Static Site Generators hexo hugo jekyll Documentation Site Generator gitbook vuepress vitepress docsify docute docusaurus Deployment 1. GitHub Pages 2. GitLab Pages 3. vercel 4. netlify Domain 域名注册 freessl 域名解析域名…

【GUI】自动化办公

目录 一、GUI介绍 二、环境安装 三、鼠标移动操作 四、鼠标点击操作 五、拖动鼠标 六、鼠标滚动操作 七、屏幕快照&图像识别基础 7.1 屏幕快照&#xff08;截图&#xff09; 7.2 图像识别 八、键盘控制 一、GUI介绍 GUI自动化就是写程序直接控制键盘和鼠标。这些…

电脑如何关闭自启动应用?cmd一招解决问题

很多小伙伴说电脑刚开机就卡的和定格动画似的&#xff0c;cmd一招解决问题&#xff1a; CtrlR打开cmd,输入&#xff1a;msconfig 进入到这个界面&#xff1a; 点击启动&#xff1a; 打开任务管理器&#xff0c;禁用不要的自启动应用就ok了

【prometheus-operator】k8s监控集群外redis

1、部署exporter GitHub - oliver006/redis_exporter: Prometheus Exporter for Redis Metrics. Supports Redis 2.x, 3.x, 4.x, 5.x, 6.x, and 7.x redis_exporter-v1.57.0.linux-386.tar.gz # 解压 tar -zxvf redis_exporter-v1.57.0.linux-386.tar.gz # 启动 nohup ./redi…

SpringCloud-记

目录 什么是SpringCloud 什么是微服务 SpringCloud的优缺点 SpringBoot和SpringCloud的区别 RPC 的实现原理 RPC是什么 eureka的自我保护机制 Ribbon feigin优点 Ribbon和Feign的区别 什么是SpringCloud Spring Cloud是一系列框架的有序集合。它利用Spring Boot的开发…

力扣面试150 阶乘后的零 数论 找规律 质因数

Problem: 172. 阶乘后的零 思路 &#x1f468;‍&#x1f3eb; 大佬神解 一个数末尾有多少个 0 &#xff0c;取决于这个数 有多少个因子 10而 10 可以分解出质因子 2 和 5而在阶乘种&#xff0c;2 的倍数会比 5 的倍数多&#xff0c;换而言之&#xff0c;每一个 5 都会找到一…

AI时代Python金融大数据分析实战:ChatGPT让金融大数据分析插上翅膀

❤️作者主页&#xff1a;小虚竹 ❤️作者简介&#xff1a;大家好,我是小虚竹。2022年度博客之星评选TOP 10&#x1f3c6;&#xff0c;Java领域优质创作者&#x1f3c6;&#xff0c;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;掘金年度人气作…

电脑照片分辨率怎么调?这款dpi修改工具好用

许多考试平台在上传证件照片的时候&#xff0c;大多都会对图片分辨率有具体要求&#xff0c;但是如果遇上手上的图片分辨率达不到要求&#xff0c;那么怎么改图片分辨率呢&#xff1f;可以利用专业的dpi修改工具来处理&#xff0c;比如今天分享的就是一个在线修改图片分辨率的方…

唯众物联网安装调试员实训平台物联网一体化教学实训室项目交付山东技师学院

近日&#xff0c;山东技师学院物联网安装调试员实训平台及物联网一体化教学实训室采购项目已顺利完成交付并投入使用&#xff0c;标志着学院在物联网技术教学与实践应用方面迈出了坚实的一步。 山东技师学院作为国内知名的技师培养摇篮&#xff0c;一直以来致力于为社会培养高…

select , poll, epoll思维导图

目录 1. 总的框架结构 2. select 3. poll 4. epoll 1. 总的框架结构 2. select

自动驾驶轨迹规划之时空语义走廊(一)

欢迎大家关注我的B站&#xff1a; 偷吃薯片的Zheng同学的个人空间-偷吃薯片的Zheng同学个人主页-哔哩哔哩视频 (bilibili.com) 目录 1.摘要 2.系统架构 3.MPDM 4.时空语义走廊 ​4.1 种子生成 4.2 具有语义边界的cube inflation ​4.3 立方体松弛 本文解析了丁文超老师…

解决Animate.css动画效果无法在浏览器运行问题

背景 在开发官方网站的时候&#xff0c;临时更换了电脑&#xff0c;发现原本正常的动画效果突然不动了。 经过 chrome、Microsoft Edge都无法运行。 Animate.css | A cross-browser library of CSS animations. 问题排查 通过审查元素后发现类名是注入并且生效的。 验证 然…

视频素材库在哪里找无水印?推荐几个优质高清视频素材网

在探索创作无限可能的旅程中&#xff0c;每位视频创作者都渴望找到那些能够让作品生动起来、讲述更加动人故事的素材。为此&#xff0c;我继续为你介绍一系列的无水印高清素材网站&#xff0c;它们不仅能为你的视频添加视觉与听觉的魅力&#xff0c;还能帮助你以更高效的方式实…

阿里云2核4G服务器支持多少人在线?2C4G多少钱一年?

2核4G服务器支持多少人在线&#xff1f;阿里云服务器网账号下的2核4G服务器支持20人同时在线访问&#xff0c;然而应用不同、类型不同、程序效率不同实际并发数也不同&#xff0c;2核4G服务器的在线访问人数取决于多个变量因素。 阿里云2核4G服务器多少钱一年&#xff1f;2核4…

分布式ID生成方案总结

分布式场景下&#xff0c;需要保证每一个服务拿到的id是唯一的。本文讨论、分析、总结了一些常见的分布式ID生成方案 结论&#xff1a;技术上没有银弹&#xff0c;每种分布式id都有自己的使用场景。uuid适用于业务比较简单&#xff0c;对性能没有太高追求等。 目前主流是 基于数…