深度学习:GPT-1的MindSpore实践

GPT-1简介

GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点:

  • NLP领域的迁移学习:通过最少的任务专项数据,利用预训练模型出色地完成具体的下游任务。
  • 语言建模作为预训练任务:使用无监督学习和大规模的文本语料库来训练模型
  • 为具体任务微调:采用预训练模型来适应监督任务

和BERT类似,GPT-1同样采取pre-train + fine-tune的思路:先基于大量未标注语料数据进行预训练, 后基于少量标注数据进行微调。但GPT-1在预训练任务思路和模型结构上与BERT有所差别。

GPT-1的目标是在预训练的过程中根据现有的所有词元,预测下一个词元。这个任务被称为“自回归语言建模”。

一个简单的例子:

输入序列为:“The sun rises in the”

训练数据的原句子为:“The sun rises in the east”

所以我们的目标输出为:“east”

将输入序列输入GPT模型,GPT根据输入预测下一个词元(“east”)在语料库中的概率分布

正确词元“east”作为一个“伪标签”来帮助模型训练

模型架构

GPT主要使用Transformer Decoder架构,但因为没有Encoder,所以在Transformer Decoder的基础上移除了计算Encoder与Decoder间注意力分数的Multi-Head Attention Layer。

Masked Multi-HeadSelf-Attention

Masked Multi-Head Self-Attention 是Multi-Head Attetion的变种。 最大的不同来自于MMSA的掩码机制,掩码机制防止模型通过观测未来的词元以进行“作弊”。

一个掩码词元<mask>被用于注意力分数矩阵,所以当前词元只能注意到序列中自己和自己之前的词元。未来的次元的注意力分数将被设为0以确保其在Softmax步骤后的实际贡献为0。

为什么掩码机制非常重要?

对于自回归任务,模型必须线性地生成词元,不能基于未来的信息预测下一个词元。

损失函数

GPT使用Cross-Entropy Loss作为损失函数:\mathcal{L} = - \sum_{t=1}^N \log P(w_t | w_1, w_2, \dots, w_{t-1})

交叉熵损失是这项任务的理想选择,因为它通过测量预测的概率分布与真实分布的距离来惩罚不正确的预测。它自然适于处理多类分类任务,其中模型从大量词汇表中选择一个标记。

模型输入

GPT-1的输入同样为句子或句子对,并添加Special Tokens。

  • [BOS]:表示句子的开始,(论文中给出的token表示为[START]),添加到序列最前;
  • [EOS]:表示序列的结束,(论文中的给出的[EXTRACT]),添加到序列最后,在进行分类任务时,会将 该special token对应的输出接入输出层;我们也可以理解为该token可以学习到整个句子的语义信息;
  • [SEP]:用于间隔句子对中的两个句子;
GPT Embedding 同样分为三类:token Embedding、Position Embedding、Segment Embedding

 

GPT-1模型具体参数

模型架构

  • 12个Transformer Decoder Block
  • hidden_size为768(模型输入和输出的向量纬度)
  • 注意力头数为12
  • FFN维度为3072
  • 词表(Vocab)大小为40000
  • 序列长度为512(上下文窗口长度)

训练过程

  • Adam优化器,超参数为:0.9, 0.99
  • 学习率:最大学习率:2.5x10e-4 使用2000步作为热身,随后线性衰退
  • 批大小:64
  • 梯度剪裁:1.0
  • Dropout率:0.1

训练过程

100000步,大约花费8张NVIDIA V100 GPU训练30天,共有117M参数。使用Xavier初始化,权重衰退为0.01。 

下游任务 

GPT按照生成式的逻辑统一了下游任务的应用模板,使用最后一个token([EOS]or[EXTRACT])对应的hidden state,输出到额外的输出层中,进行分类标签预测。
任务包括:文本分类(情感分类、新闻分类)、文本蕴含(根据前提推出假设)、文本语义相似度、多类选择(在多个next token中进行选择)

基于MindSpore微调GPT-1进行情感分类

# #安装mindnlp 0.4.0套件
# !pip install mindnlp
# !pip uninstall soundfile -y
# !pip install download
# !pip install jieba
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simpleimport osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp.engine import Trainer# loading dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']imdb_train.get_dataset_size()import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasetfrom mindnlp.transformers import OpenAIGPTTokenizer
# tokenizer
gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)#为方便体验流程,把原本数据集的十分之一拿出来体验训练和评估,
imdb_train, _ = imdb_train.split([0.1, 0.9], randomize=False)# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)# load GPT sequence classification model and set class=2
from mindnlp.transformers import OpenAIGPTForSequenceClassification  # Import the GPT model for sequence classification
from mindnlp import evaluate  # Import the evaluation module from MindNLP
import numpy as np  # Import NumPy for numerical operations# Set up the GPT model for sequence classification with 2 output labels (binary classification).
model = OpenAIGPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)# Set the padding token ID in the model configuration to match the tokenizer's padding token ID.
model.config.pad_token_id = gpt_tokenizer.pad_token_id# Resize the token embedding layer to account for any added tokens (e.g., special tokens).
model.resize_token_embeddings(model.config.vocab_size + 3)from mindnlp.engine import TrainingArguments  # Import training arguments for model training configuration.# Define training arguments.
training_args = TrainingArguments(output_dir="gpt_imdb_finetune",  # Directory to save model checkpoints and outputs.evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch.save_strategy="epoch",  # Save model checkpoints at the end of each epoch.logging_strategy="epoch",  # Log metrics and progress at the end of each epoch.load_best_model_at_end=True,  # Automatically load the best model (based on evaluation metrics) at the end of training.num_train_epochs=1.0,  # Number of training epochs (default is 1 for quick experimentation).learning_rate=2e-5  # Learning rate for the optimizer.
)# Load the accuracy metric for evaluation.
metric = evaluate.load("accuracy")# Define a function to compute metrics during evaluation.
def compute_metrics(eval_pred):logits, labels = eval_pred  # Unpack predictions (logits) and true labels.predictions = np.argmax(logits, axis=-1)  # Convert logits to class predictions using argmax.return metric.compute(predictions=predictions, references=labels)  # Compute accuracy metric.# Initialize the Trainer class with the model, training arguments, datasets, and metric computation function.
trainer = Trainer(model=model,  # The GPT model to be fine-tuned.args=training_args,  # Training configuration arguments.train_dataset=dataset_train,  # Training dataset (must be preprocessed and tokenized).eval_dataset=dataset_val,  # Validation dataset for evaluation.compute_metrics=compute_metrics  # Metric computation function for evaluation.
)# start training
trainer.train()trainer.evaluate(dataset_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/478726.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

websocket是什么?

一、定义 Websocket是一种在单个TCP连接上进行全双工通信的协议&#xff0c;它允许服务器主动向客户端推送数据&#xff0c;而不需要客户端不断的轮询服务器来获取数据 与http协议不同&#xff0c;http是一种无状态的&#xff0c;请求&#xff0c;响应模式的协议(单向通信)&a…

1-golang_org_x_crypto_bcrypt测试 --go开源库测试

1.实例测试 package mainimport ("fmt""golang.org/x/crypto/bcrypt" )func main() {password : []byte("mysecretpassword")hashedPassword, err : bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)if err ! nil {fmt.Println(err)…

CSS —— 子绝父相

相对定位&#xff1a;占位&#xff1b;不脱标 绝对定位&#xff1a;不占位&#xff1b;脱标 希望子元素相对于父元素定位&#xff0c;又不希望父元素脱标&#xff08;父元素占位&#xff09; 子级是 绝对定位&#xff0c;不会占有位置&#xff0c; 可以放到父盒子里面的任何一…

风尚云网前端学习:一个简易前端新手友好的HTML5页面布局与样式设计

风尚云网前端学习&#xff1a;一个简易前端新手友好的HTML5页面布局与样式设计 简介 在前端开发的世界里&#xff0c;HTML5和CSS3是构建现代网页的基石。本文将通过一个简单的HTML5页面模板&#xff0c;展示如何使用HTML5的结构化元素和CSS3的样式特性&#xff0c;来创建一个…

django authentication 登录注册

文章目录 前言一、django配置二、后端实现1.新建app2.编写view3.配置路由 三、前端编写1、index.html2、register.html3、 login.html 总结 前言 之前&#xff0c;写了django制作简易登录系统&#xff0c;这次利用django内置的authentication功能实现注册、登录 提示&#xff…

ctfshow单身杯2024wp

文章目录 ctfshow单身杯2024wp签到好玩的PHPezzz_sstiez_inject ctfshow单身杯2024wp 签到好玩的PHP 考点&#xff1a;序列化反序列化 <?phperror_reporting(0);highlight_file(__FILE__);class ctfshow {private $d ;private $s ;private $b ;private $ctf ;public …

ISUP协议视频平台EasyCVR萤石设备视频接入平台银行营业网点安全防范系统解决方案

在金融行业&#xff0c;银行营业厅的安全保卫工作至关重要&#xff0c;它不仅关系到客户资金的安全&#xff0c;也关系到整个银行的信誉和运营效率。随着科技的发展&#xff0c;传统的安全防护措施已经无法满足现代银行对于高效、智能化安全管理的需求。 EasyCVR视频汇聚平台以…

【代码pycharm】动手学深度学习v2-08 线性回归 + 基础优化算法

课程链接 线性回归的从零开始实现 import random import torch from d2l import torch as d2l# 人造数据集 def synthetic_data(w,b,num_examples):Xtorch.normal(0,1,(num_examples,len(w)))ytorch.matmul(X,w)bytorch.normal(0,0.01,y.shape) # 加入噪声return X,y.reshape…

zotero安卓测试版下载和使用

2023年年底&#xff0c;Zotero官方就已经推出了安卓版的测试版Zotero for Android (beta),&#xff0c;但名额有限且只能通过Google商店下载。此外&#xff0c;还有一些第三方开发的安卓应用&#xff0c;如Zoo for Zotero、ZotDroid等。 在首次使用Zotero安卓版时&#xff0c;用…

基于FPGA的2FSK调制-串口收发-带tb仿真文件-实际上板验证成功

基于FPGA的2FSK调制 前言一、2FSK储备知识二、代码分析1.模块分析2.波形分析 总结 前言 设计实现连续相位 2FSK 调制器&#xff0c;2FSK 的两个频率为:fI15KHz&#xff0c;f23KHz&#xff0c;波特率为 1500 bps,比特0映射为f 载波&#xff0c;比特1映射为 载波。 1&#xff09…

Matlab 深度学习 PINN测试与学习

PINN 与传统神经网络的区别 与传统神经网络的不同之处在于&#xff0c;PINN 能够以微分方程形式纳入有关问题的先验专业知识。这些附加信息使 PINN 能够在给定的测量数据之外作出更准确的预测。此外&#xff0c;额外的物理知识还能在存在含噪测量数据的情况下对预测解进行正则…

【JavaScript】JavaScript开篇基础(7)

1.❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; Hello, Hello~ 亲爱的朋友们&#x1f44b;&#x1f44b;&#xff0c;这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章&#xff0c;请别吝啬你的点赞❤️❤️和收藏&#x1f4d6;&#x1f4d6;。如果你对我的…

JavaScript的基础数据类型

一、JavaScript中的数组 定义 数组是一种特殊的对象&#xff0c;用于存储多个值。在JavaScript中&#xff0c;数组可以包含不同的数据类型&#xff0c;如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式&#xff1a; 字面量表示法&#xff1a;let fruits [apple…

WebSocket详解、WebSocket入门案例

目录 1.1 WebSocket介绍 http协议&#xff1a; webSocket协议&#xff1a; 1.2WebSocket协议&#xff1a; 1.3客户端&#xff08;浏览器&#xff09;实现 1.3.2 WebSocket对象的相关事宜&#xff1a; 1.3.3 WebSOcket方法 1.4 服务端实现 服务端如何接收客户端发送的请…

周志华深度森林deep forest(deep-forest)最新可安装教程,仅需在pycharm中完成,超简单安装教程

1、打开pycharm 没有pycharm的&#xff0c;在站内搜索安装教程即可。 2、点击“文件”“新建项目” 3、创建项目&#xff0c;Python版本中选择Python39。如果没有该版本&#xff0c;选择下面的Python 3.9下载并安装。 4、打开软件包&#xff0c;搜索“deep-forest”软件包&am…

ES 和Kibana-v2 带用户登录验证

1. 前言 ElasticSearch、可视化操作工具Kibana。如果你是Linux centos系统的话&#xff0c;下面的指令可以一路CV完成服务的部署。 2. 服务搭建 2.1. 部署ElasticSearch 拉取docker镜像 docker pull elasticsearch:7.17.21 创建挂载卷目录 mkdir /**/es-data -p mkdir /**/…

分布式kettle调度平台v6.4.0新功能介绍

介绍 Kettle&#xff08;也称为Pentaho Data Integration&#xff09;是一款开源的ETL&#xff08;Extract, Transform, Load&#xff09;工具&#xff0c;由Pentaho&#xff08;现为Hitachi Vantara&#xff09;开发和维护。它提供了一套强大的数据集成和转换功能&#xff0c…

力扣hot100-->排序

排序 1. 56. 合并区间 中等 以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好覆盖输入中的所有区间 。 示例 1&#xff1a; 输…

.net 8使用hangfire实现库存同步任务

C# 使用HangFire 第一章:.net Framework 4.6 WebAPI 使用Hangfire 第二章:net 8使用hangfire实现库存同步任务 文章目录 C# 使用HangFire前言项目源码一、项目架构二、项目服务介绍HangFire服务结构解析HangfireCollectionExtensions 类ModelHangfireSettingsHttpAuthInfoUs…

滑动窗口最大值(java)

题目描述 给你一个整数数组 nums&#xff0c;有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的最大值 。 示例 1&#xff1a; 输入&#xff1a;nums [1,3,-1,-3,5,3,6,7]…