[GPT-1]论文实现:Improving Language Understanding by Generative Pre-Training

文章目录

    • 一、完整代码
    • 二、论文解读
      • 2.1 GPT架构
      • 2.2 GPT的训练方式
        • Unsupervised pre_training
        • Supervised fine_training
    • 三、过程实现
      • 3.1 导包
      • 3.2 数据处理
      • 3.3 模型构建
      • 3.4 模型配置
    • 四、整体总结

论文:Improving Language Understanding by Generative Pre-Training
作者:Alec Radford,Karthik Narasimhan,Tim Salimans,Ilya Sutskever
时间:2018

一、完整代码

这里我们使用tensorflow代码进行实现

# 完整代码在这里
import tensorflow as tf
import keras_nlp
import jsondef get_merges():with open('./data/GPT_merges.txt') as f:merges = f.read().split('\n')return mergesmerges = get_merges()
vocabulary = json.load(open('./data/GPT_vocab.json'))tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocabulary=vocabulary,merges=merges
)pad = tokenizer.vocabulary_size()
start = tokenizer.vocabulary_size() + 1
end = tokenizer.vocabulary_size() + 2corpus = open('./data/shakespeare.txt').read()
data = tokenizer(corpus)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(63, drop_remainder=True)
def process_data(x):x = tf.concat([tf.constant(start)[tf.newaxis], x, tf.constant(end)[tf.newaxis]], axis=-1)return x[:-1], x[1:]dataset = dataset.map(process_data).batch(16)inputs, outputs = dataset.take(1).get_single_element()class GPT(tf.keras.Model):def __init__(self, vocabulary_size, sequence_length, embedding_dim, num_layers, intermediate_dim, num_heads, dropout=0.1):super().__init__()self.embedding = keras_nlp.layers.TokenAndPositionEmbedding(vocabulary_size=vocabulary_size,sequence_length=sequence_length,embedding_dim=embedding_dim,)self.lst = [keras_nlp.layers.TransformerDecoder(intermediate_dim=intermediate_dim,num_heads=num_heads,dropout=dropout,) for _ in range(num_layers)]self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')def call(self, x):decoder_padding_mask = x!= 0 output = self.embedding(x)for item in self.lst:output = item(output, decoder_padding_mask=decoder_padding_mask)output = self.dense(output)return outputvocabulary_size = tokenizer.vocabulary_size() + 3
sequence_length= 64
embedding_dim=512
num_layers=12
intermediate_dim=1024 
num_heads=8gpt = GPT(vocabulary_size, sequence_length, embedding_dim, num_layers, intermediate_dim, num_heads)gpt(inputs)
gpt.summary()def masked_loss(label, pred):mask = label != padloss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction='none')loss = loss_object(label, pred)mask = tf.cast(mask, dtype=loss.dtype)loss *= maskloss = tf.reduce_sum(loss)/tf.reduce_sum(mask)return lossdef masked_accuracy(label, pred):pred = tf.argmax(pred, axis=2)label = tf.cast(label, pred.dtype)match = label == predmask = label != padmatch = match & maskmatch = tf.cast(match, dtype=tf.float32)mask = tf.cast(mask, dtype=tf.float32)return tf.reduce_sum(match)/tf.reduce_sum(mask)gpt.compile(loss=masked_loss,optimizer='adam',metrics=[masked_accuracy]
)gpt.fit(dataset, epochs=3)

二、论文解读

GPT全称为Generative Pre-Training,即生成式的预训练模型;

2.1 GPT架构

其模型架构非常简单,就是Transformerdecoder修正后的叠加,因为这是文本生成任务,并没有类似于seq2seq翻译模型的对应句子,GPT的处理方式是直接把Transformer中的decoder中的CrossAtention直接删除;

如图所示:蓝色方框部分为Transformerdecoder层,其中红色方框部分为被删除的多头注意力层;

得到的模型如下:

是不是特别简单;

2.2 GPT的训练方式

首先要声明的是GPT采用的是semi-supervised即半监督学习方法,其本质是一个两阶段的训练过程,第一阶段是无监督学习,就是单纯的利用Transformerdecoder来做预测下一个词的任务;第二阶段是有监督学习,利用带标签的语料信息对模型进行训练;

接下来对这两个过程进行详细的分析;

Unsupervised pre_training

原文如图所示:

其根本目的是最大化语言模型的极大似然估计,其本质就是一个链式法则取对数;

L 1 ( u ) = l o g ( P ( u i , u i − 1 , … , u 1 ) ) P ( u i , u i − 1 , … , u 1 ) = P ( u 1 ) ⋅ P ( u 2 ∣ u 1 ) ⋅ P ( u 3 ∣ u 2 , u 1 ) ⋅ ⋅ ⋅ P ( u i ∣ u i − 1 , … , u 1 ) \begin{aligned} & L_1(u) = log(P(u_i,u_{i-1},\dots,u_1)) \\ \\ & P(u_i,u_{i-1},\dots,u_1) = P(u_1)·P(u_2|u_1)·P(u_3|u_2,u_1)···P(u_i|u_{i-1},\dots,u_1) \end{aligned} L1(u)=log(P(ui,ui1,,u1))P(ui,ui1,,u1)=P(u1)P(u2u1)P(u3u2,u1)⋅⋅⋅P(uiui1,,u1)

而下面计算 P P P 的过程,就是利用 mask 的机制来制造类似于RNN的过程;

如果对注意力机制不理解的,可以去看一下Attention Is All You Need这篇论文,我也在其他博客中简单介绍了一下;

Supervised fine_training

原文如图所示:

unsupervised pre_training不同的是,其去掉了最后一层的 W e W_e We换成了一个新的参数 W y W_y Wy,利用新的参数去预测新的标签;这里我的理解是这样的,在unsupervised pre_training中,我们相当于在大炮不停调整弹药量,大炮的对准方向 W e W_e We也在不停的向下一个单词调整;当弹药合理时,方向正确时,我们调整大炮方向去攻打supervised fine_tuning

这里的目标函数进行了一次正则化处理,避免一味的调整方向而忽略了弹药量;

L 3 ( C ) = L 2 ( C ) + λ L 1 ( C ) L_3(C) = L_2(C) + \lambda L_1(C) L3(C)=L2(C)+λL1(C)

至此,模型的训练就结束了;

三、过程实现

3.1 导包

这里使用tensorflowkeras_nlpjson三个包进行过程实现;

import tensorflow as tf
import keras_nlp
import json

3.2 数据处理

第一部分是无监督训练,我们需要导入一段长文本构建数据集进行训练即可,这里我们使用莎士比亚的作品 storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt;
第二部分是有监督训练,我们可以使用CoLA语料进行文本分类,CoLA语料来自GLUE Benchmark中的The Corpus of Linguistic Acceptability

def get_merges():with open('./data/GPT_merges.txt') as f:merges = f.read().split('\n')return mergesmerges = get_merges()
vocabulary = json.load(open('./data/GPT_vocab.json'))tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocabulary=vocabulary,merges=merges
)pad = tokenizer.vocabulary_size()
start = tokenizer.vocabulary_size() + 1
end = tokenizer.vocabulary_size() + 2corpus = open('./data/shakespeare.txt').read()
data = tokenizer(corpus)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(63, drop_remainder=True)
def process_data(x):x = tf.concat([tf.constant(start)[tf.newaxis], x, tf.constant(end)[tf.newaxis]], axis=-1)return x[:-1], x[1:]dataset = dataset.map(process_data).batch(16)inputs, outputs = dataset.take(1).get_single_element()
# inputs
# <tf.Tensor: shape=(16, 64), dtype=int32, numpy=
# array([[50258,  5962,   220, ..., 14813,   220,  1462],
#        [50258,   220, 44769, ...,   220,   732,   220],
#        [50258, 16275,   470, ...,   220,  1616,   220],
#        ...,
#        [50258,   220,  1350, ...,   220, 19205,   198],
#        [50258,   271,   220, ...,    54, 18906,   220],
#        [50258, 10418,   268, ...,    40,  2937,    25]])>

3.3 模型构建

在这里构建模型:

class GPT(tf.keras.Model):def __init__(self, vocabulary_size, sequence_length, embedding_dim, num_layers, intermediate_dim, num_heads, dropout=0.1):super().__init__()self.embedding = keras_nlp.layers.TokenAndPositionEmbedding(vocabulary_size=vocabulary_size,sequence_length=sequence_length,embedding_dim=embedding_dim,)self.lst = [keras_nlp.layers.TransformerDecoder(intermediate_dim=intermediate_dim,num_heads=num_heads,dropout=dropout,) for _ in range(num_layers)]self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')def call(self, x):decoder_padding_mask = x!= 0 output = self.embedding(x)for item in self.lst:output = item(output, decoder_padding_mask=decoder_padding_mask)output = self.dense(output)return outputvocabulary_size = tokenizer.vocabulary_size() + 3
sequence_length= 64
embedding_dim=512
num_layers=12
intermediate_dim=1024 
num_heads=8gpt = GPT(vocabulary_size, sequence_length, embedding_dim, num_layers, intermediate_dim, num_heads)gpt(inputs)
gpt.summary()

构建模型结构如下:

3.4 模型配置

模型配置如下:

def masked_loss(label, pred):mask = label != padloss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction='none')loss = loss_object(label, pred)mask = tf.cast(mask, dtype=loss.dtype)loss *= maskloss = tf.reduce_sum(loss)/tf.reduce_sum(mask)return lossdef masked_accuracy(label, pred):pred = tf.argmax(pred, axis=2)label = tf.cast(label, pred.dtype)match = label == predmask = label != padmatch = match & maskmatch = tf.cast(match, dtype=tf.float32)mask = tf.cast(mask, dtype=tf.float32)return tf.reduce_sum(match)/tf.reduce_sum(mask)gpt.compile(loss=masked_loss,optimizer='adam',metrics=[masked_accuracy]
)gpt.fit(dataset, epochs=3)

训练过程不知道为什么masked_accuracy一直不变,需要分析;

四、整体总结

模型结构很简单,但是在实现过程中出现了和Bert一样的问题;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/212650.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一对一聊天

服务端 package 一对一用户; import java.awt.BorderLayout; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.net.ServerSocket; import java.net.Socket; import java.util…

GateWay网关介绍以及整合knife4j聚合所有服务的接口文档

为什么使用网关&#xff1f; 因为多个微服务的端口不同&#xff0c;前端调用不方便&#xff0c;使用网关可以统一接收处理前端的请求&#xff0c;同时方便接口的集中处理&#xff0c;比如鉴权、聚合接口文档、限流等等.. 这里使用Knife4j文档工具来实现接口文档&#xff1a;K…

ArcMap中构建金字塔详解

1.金字塔 1.1 定义 金字塔可用于改善性能。它们是原始栅格数据集的缩减采样版本&#xff0c;可包含多个缩减采样图层。金字塔的各个连续图层均以 2:1 的比例进行缩减采样。如下图所示。从金字塔的底层开始每四个相邻的像素经过重采样生成一个新的像素&#xff0c;依此重复进行…

C#winform上下班打卡系统Demo

C# winform上下班打卡系统Demo 系统效果如图所示 7个label控件(lblUsername、lblLoggedInEmployeeId、lab_IP、lblCheckOutTime、lblCheckInTime、lab_starttime、lab_endtime)、3个按钮、1个dataGridView控件、2个groupBox控件 C#代码实现 using System; using System.Dat…

【分布式微服务专题】从单体到分布式(二、SpringCloud整合Nacos)

目录 前言阅读对象阅读导航前置知识笔记正文一、下载安装二、项目整合2.1 服务注册与发现2.2 动态配置管理 三、其他实验四、服务之间的调用 学习总结感谢 前言 本篇笔记主要是记录我整合Nacos项目进来的过程。以实现服务注册发现&#xff0c;以及分布式配置管理。关于Nacos&a…

友思特分享 | Neuro-T:零代码自动深度学习训练平台

工业自动化、智能化浪潮涌进&#xff0c;视觉技术在其中扮演了至关重要的角色。在汽车、制造业、医药、芯片、食品等行业&#xff0c;基于视觉技术实现的缺陷检测具有非常大的需求。对于传统检测方法&#xff0c;目视检查方法能够有效检测非标、具有挑战性的缺陷&#xff0c;传…

在线教育小程序如何一键生成App

在线教育行业是指通过互联网平台提供的各种教育和培训服务。这不仅包括传统的课程学习&#xff0c;还涵盖了一系列创新的学习模式。例如&#xff0c;同步在线课程允许学生和教师在同一时间在线&#xff0c;通过实时的视频和聊天工具进行互动。而异步在线课程则为学生提供了更大…

2023年5个自动化EDA库推荐

EDA或探索性数据分析是一项耗时的工作&#xff0c;但是由于EDA是不可避免的&#xff0c;所以Python出现了很多自动化库来减少执行分析所需的时间。EDA的主要目标不是制作花哨的图形或创建彩色的图形&#xff0c;而是获得对数据集的理解&#xff0c;并获得对变量之间的分布和相关…

【Ambari】Python调用Rest API 获取YARN HA状态信息并发送钉钉告警

&#x1f984; 个人主页——&#x1f390;开着拖拉机回家_Linux,大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&am…

Java项目学生管理系统一前后端环境搭建

在现代的软件开发中&#xff0c;学生管理系统是一个常见的应用场景。通过学生管理系统&#xff0c;学校能够方便地管理学生的信息、课程安排和成绩等数据。本文将介绍如何使用Java语言搭建一个学生管理系统的前后端环境&#xff0c;并提供一个简单的示例。 1.环境搭建 学生管…

web:[GXYCTF2019]BabyUpload(文件上传、一句话木马、文件过滤)

题目 页面显示为文件上传 随便上传一个文件看看 上传一个文本文件显示 上传了一个图片显示 上传包含一句话木马的图片 上传了一个包含php一句话木马的文件&#xff0c;显示如上 换一个写法 上传成功 尝试上传.htaccess&#xff0c;上传失败&#xff0c;用抓包修改文件后缀 …

No suitable driver found for jdbc:mysql://localhost:3306(2023/12/7更新)

有两种情况&#xff1a; 压根没安装下载了但没设为库或方法不对 大多数为第一种情况&#xff1a; 一. 下载jdbc 打开网址选择一个版本进行下载 https://nowjava.com/jar/version/mysql/mysql-connector-java.html 二.安装jdbc 在项目里建一个lib文件夹 在把之前下载的jar文…

【Vulnhub 靶场】【Funbox: GaoKao】【简单】【20210606】

1、环境介绍 靶场介绍&#xff1a;https://www.vulnhub.com/entry/funbox-gaokao,707/ 靶场下载&#xff1a;https://download.vulnhub.com/funbox/FunboxGaoKao.ova 靶场难度&#xff1a;简单 发布日期&#xff1a;2021年06月06日 文件大小&#xff1a;1.3 GB 靶场作者&#…

Linux 多进程并发设计-进程对核的亲缘设置

1设计结构 2 设计优点 1 充分利用多核系统的并发处理能力2 负载均衡3 职责明确&#xff0c;管理进程仅负责管理&#xff0c;工作进程仅负责处理业务逻辑 3 演示代码: //main.cpp #define _GNU_SOURCE #include<sys/types.h> #include<sys/wait.h> #include <…

智能优化算法应用:基于金枪鱼群算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于金枪鱼群算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于金枪鱼群算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.金枪鱼群算法4.实验参数设定5.算法结果6.参考…

代码混淆技术探究与工具选择

代码混淆技术探究与工具选择 引言 在软件开发中&#xff0c;保护程序代码的安全性是至关重要的一环。代码混淆&#xff08;Obfuscated code&#xff09;作为一种常见的保护手段&#xff0c;通过将代码转换成难以理解的形式来提升应用被逆向破解的难度。本文将介绍代码混淆的概…

JAVA后端自学技能实操合集

JAVA后端自学技能实操 内容将会持续更新中,有需要添加什么内容可以再评论区留言,大家一起学习FastDFS使用docker安装FastDFS(linux)集成到springboot项目中 内容将会持续更新中,有需要添加什么内容可以再评论区留言,大家一起学习 FastDFS 组名&#xff1a;文件上传后所在的 st…

配置集群免密登录

文章目录 前言配置集群免密登录1. 设置主机名与 IP 地址的映射关系2. 生成 SSH 密钥对3. 将公钥复制到集群节点4. 测试免密登录5. 配置节点之间互相免密登录 总结 前言 本文介绍了如何配置集群之间免密登录&#xff0c;以便在搭建集群环境时方便地进行节点之间的通信。通过设置…

分布式搜索引擎elasticsearch(二)

1.DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1.DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询。常见的查询类型包括: 查询所有:查询出所有数据,一般测试用。例如:match_all 全文检索(full text)查…

编程过程中出现bug如何应对?

编程过程中出现bug如何应对&#xff1f; 1.找错误原因 如果完全不知道出错的原因&#xff0c;或者说存在着很多错误的有原因&#xff0c;----》控制变量法 例如&#xff0c;昨天我在使用torchrun 多卡并行一个程序的时候&#xff0c;出现了大量的bug, 于是我将报错信息放在网…