NLP(16)--生成式任务

前言

仅记录学习过程,有问题欢迎讨论

  输入输出均为不定长序列(seq2seq)

自回归语言模型:

  • x 为 str[start : end ]; y为 [start+1 : end +1] 同时训练多个字,逐字计算交叉熵

encode-decode结构:

  • Encoder将输入转化为向量或矩阵,其中包含了输入中的信息
  • Decoder将Encoder的输出转化为输出

attention机制

  • 输入和输出应该和重点句子强相关,给输入加权(所以维度应该和输入的size一致)
  • 在这里插入图片描述

Teacher forcing

  • 使用真实标签作为下一个输入(自回归语言模型就是使用的teacher forcing)

Transform结构

  • Query来自Decode ,KV来自Encode
  • 在这里插入图片描述

使用Mask Attation 来避免对output做计算时,获取了所有的信息。只使用当前的位置对应的output信息。(自回归模型,先mask,然后在softmax)
在这里插入图片描述

评价指标:

  • BLEU:按照输出的字符计算一系列的数学(惩罚机制,Ngrim)计算来评价相似性

采样:

  • Beam size:
    保留概率最大的n条路径

  • Temperature Sampling
    根据概率分布生成下一个词,通过参数T,T越大,结果越随机,分布更均匀

  • TOP-P/K
    采样先按概率从大到小排序,累加概率不超过P的范围中选
    采样从TOP-K中采样下一个词

代码

使用bert实现自回归训练模型,
添加mask attention 来实现

# coding:utf8import torch
import torch.nn as nn
import numpy as np
import math
import random
import os
import refrom transformers import BertModel, BertTokenizer"""
基于pytorch的LSTM语言模型
"""class LanguageModel(nn.Module):def __init__(self, input_dim, vocab_size):super(LanguageModel, self).__init__()# self.embedding = nn.Embedding(len(vocab), input_dim)# self.layer = nn.LSTM(input_dim, input_dim, num_layers=1, batch_first=True)self.bert = BertModel.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, vocab_size)# self.dropout = nn.Dropout(0.1)self.loss = nn.functional.cross_entropy# 当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# x = self.embedding(x)  # output shape:(batch_size, sen_len, input_dim)# 使用mask来防止提前预知结果if y is not None:# 构建一个下三角的mask# bert的mask attention 为(batch_size, vocab_size, vocab_size) L*Lmask = torch.tril(torch.ones(x.shape[0], x.shape[1], x.shape[1]))print(mask)x, _ = self.bert(x, attention_mask=mask)y_pred = self.classify(x)return self.loss(y_pred.view(-1, y_pred.shape[-1]), y.view(-1))else:x = self.bert(x)[0]y_pred = self.classify(x)return torch.softmax(y_pred, dim=-1)# 加载字表
def build_vocab(vocab_path):vocab = {"<pad>": 0}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):char = line[:-1]  # 去掉结尾换行符vocab[char] = index + 1  # 留出0位给pad tokenreturn vocab# 加载语料
def load_corpus(path):corpus = ""with open(path, encoding="utf8") as f:for line in f:corpus += line.strip()return corpus# 随机生成一个样本
# 从文本中截取随机窗口,前n个字作为输入,最后一个字作为输出
def build_sample(tokenizer, window_size, corpus):start = random.randint(0, len(corpus) - 1 - window_size)end = start + window_sizewindow = corpus[start:end]target = corpus[start + 1:end + 1]  # 输入输出错开一位# print(window, target)# 中文的文本转化为tokenizer的idinput_ids_x = tokenizer.encode(window, add_special_tokens=False, padding='max_length', truncation=True,max_length=10)input_ids_y = tokenizer.encode(target, add_special_tokens=False, padding='max_length', truncation=True,max_length=10)return input_ids_x, input_ids_y# 建立数据集
# sample_length 输入需要的样本数量。需要多少生成多少
# vocab 词表
# window_size 样本长度
# corpus 语料字符串
def build_dataset(sample_length, tokenizer, window_size, corpus):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(tokenizer, window_size, corpus)dataset_x.append(x)dataset_y.append(y)return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)# 建立模型
def build_model(vocab_size, char_dim):model = LanguageModel(char_dim, vocab_size)return model# 文本生成测试代码
def generate_sentence(openings, model, tokenizer, window_size):# reverse_vocab = dict((y, x) for x, y in vocab.items())model.eval()with torch.no_grad():pred_char = ""# 生成文本超过30字终止while len(openings) <= 30:openings += pred_charx = tokenizer.encode(openings, add_special_tokens=False, padding='max_length', truncation=True,max_length=10)x = torch.LongTensor([x])if torch.cuda.is_available():x = x.cuda()# batch_size = 1 最后一个字符的概率y = model(x)[0][-1]index = sampling_strategy(y)# 转化为中文 只有一个字符pred_char = tokenizer.decode(index)return openings# 采样方式
def sampling_strategy(prob_distribution):if random.random() > 0.1:strategy = "greedy"else:strategy = "sampling"if strategy == "greedy":return int(torch.argmax(prob_distribution))elif strategy == "sampling":prob_distribution = prob_distribution.cpu().numpy()return np.random.choice(list(range(len(prob_distribution))), p=prob_distribution)# 计算文本ppl
def calc_perplexity(sentence, model, vocab, window_size):prob = 0model.eval()with torch.no_grad():for i in range(1, len(sentence)):start = max(0, i - window_size)window = sentence[start:i]x = [vocab.get(char, vocab["<UNK>"]) for char in window]x = torch.LongTensor([x])target = sentence[i]target_index = vocab.get(target, vocab["<UNK>"])if torch.cuda.is_available():x = x.cuda()pred_prob_distribute = model(x)[0][-1]target_prob = pred_prob_distribute[target_index]prob += math.log(target_prob, 10)return 2 ** (prob * (-1 / len(sentence)))def train(corpus_path, save_weight=True):epoch_num = 15  # 训练轮数batch_size = 64  # 每次训练样本个数train_sample = 10000  # 每轮训练总共训练的样本总数char_dim = 768  # 每个字的维度window_size = 10  # 样本文本长度# vocab = build_vocab(r"vocab.txt")  # 建立字表tokenizer = BertTokenizer.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese")vocab_size = 21128corpus = load_corpus(corpus_path)  # 加载语料model = build_model(vocab_size, char_dim)  # 建立模型if torch.cuda.is_available():model = model.cuda()optim = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立优化器print("文本词表模型加载完毕,开始训练")for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, tokenizer, window_size, corpus)  # 构建一组训练样本if torch.cuda.is_available():x, y = x.cuda(), y.cuda()optim.zero_grad()  # 梯度归零loss = model(x, y)  # 计算lossloss.backward()  # 计算梯度optim.step()  # 更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))print(generate_sentence("忽然一阵狂风吹过,他直接", model, tokenizer, window_size))print(generate_sentence("天青色等烟雨,而我在", model, tokenizer, window_size))if not save_weight:returnelse:base_name = os.path.basename(corpus_path).replace("txt", "pth")model_path = os.path.join("model", base_name)torch.save(model.state_dict(), model_path)returnif __name__ == "__main__":train("corpus.txt", False)# mask = torch.tril(torch.ones(4, 4)).unsqueeze(0).unsqueeze(0)# print(mask)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/329730.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务远程调用 RestTemplate

Spring给我们提供了一个RestTemplate的API&#xff0c;可以方便的实现Http请求的发送。 同步客户端执行HTTP请求&#xff0c;在底层HTTP客户端库(如JDK HttpURLConnection、Apache HttpComponents等)上公开一个简单的模板方法API。RestTemplate通过HTTP方法为常见场景提供了模…

从ES5迈向ES6:探索 JavaScript 新增声明命令与解构赋值的魅力

个人主页&#xff1a;学习前端的小z 个人专栏&#xff1a;JavaScript 精粹 本专栏旨在分享记录每日学习的前端知识和学习笔记的归纳总结&#xff0c;欢迎大家在评论区交流讨论&#xff01; ES5、ES6介绍 文章目录 &#x1f4af;声明命令 let、const&#x1f35f;1 let声明符&a…

【LeetCode】每日一题 2024_5_24 找出最具竞争力的子序列(栈,模拟,贪心)

文章目录 LeetCode&#xff1f;启动&#xff01;&#xff01;&#xff01;题目&#xff1a;找出最具竞争力的子序列题目描述代码与解题思路 每天进步一点点 LeetCode&#xff1f;启动&#xff01;&#xff01;&#xff01; 题目&#xff1a;找出最具竞争力的子序列 题目链接&a…

【Unity2D:C#Script】实现角色射击功能

一、创建子弹预制体 1. 创建子弹预制体 2. 调整图片大小、层级 二、为子弹添加碰撞体积 1. 添加Box Collider 2D、Rigidbody 2D组件 2. 锁定z轴 三、编辑敌人脚本 注&#xff1a;在以下代码中&#xff0c;只显示本章节新增的代码&#xff0c;省略原有的代码 1. 为敌人添加生…

一阶数字高通滤波器

本文的主要内容包含一阶高通滤波器公式的推导和数字算法的实现以及编程和仿真 1 计算公式推导 1.1.2 算法实现及仿真 利用python实现的代码如下&#xff1a; import numpy as np # from scipy.signal import butter, lfilter, freqz import matplotlib.pyplot as plt #2pifW…

免费分享一套微信小程序旅游推荐(智慧旅游)系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】,帅呆了~~

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序旅游推荐(智慧旅游)系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序旅游推荐(智慧旅游)系统(SpringBoot后端Vue管理端) Java毕业设计…

视频监控管理平台LntonCVS监控视频汇聚融合云平台主要功能应用场景介绍

随着网络技术的不断发展和万物互联时代的到来&#xff0c;视频融合在一些系统集成项目及综合管理应用中变得日益重要。本文以LntonCVS视频融合云平台为案例&#xff0c;探讨视频融合的对象及其应用场景。 1. 视频监控设备 视频监控摄像设备是各种视频应用项目的基础部分。在视…

亚马逊卖家账号注册复杂吗?需要什么辅助工具吗?

在当今数字化的商业世界中&#xff0c;亚马逊作为全球最大的电商平台之一&#xff0c;吸引着无数的卖家和买家。对于想要进入亚马逊销售市场的卖家来说&#xff0c;首先要完成的一项重要任务就是注册亚马逊卖家账号。本文将详细介绍亚马逊注册的步骤、所需时间&#xff0c;以及…

入门四认识HTML

一、HTML介绍 1、Web前端三大核心技术 HTML&#xff1a;负责网页的架构 CSS&#xff1a;负责网页的样式、美化 JS&#xff1a;负责网页的行动 2、什么是HTML HTML是用来描述网页的一种语言。 3、Html标签 单标签<html> 双标签<h>内容</h> 4、标…

【译】组复制和 Percona XtraDB 集群: 常见操作概述

原文地址&#xff1a;Group Replication and Percona XtraDB Cluster: Overview of Common Operations 在这篇博文中&#xff0c;我将概述使用 MySQL Group Replication 8.0.19&#xff08;又称 GR&#xff09;和 Percona XtraDB Cluster 8 (PXC)&#xff08;基于 Galera&…

服务器数据恢复—EVA存储多块硬盘离线导致部分LUN丢失的数据恢复案例

服务器数据恢复环境&#xff1a; 1台某品牌EVA4400控制器3台EVA4400扩展柜28块FC硬盘。 服务器故障&#xff1a; 由于两块磁盘掉线导致存储中某些LUN不可用&#xff0c;某些LUN丢失&#xff0c;导致存储崩溃。 服务器数据恢复过程&#xff1a; 1、由于EVA4400存储故障是某些磁…

Java 对外API接口开发 java开发api接口如何编写

Java API API&#xff08;Application Programming Interface&#xff09;是指应用程序编程接口&#xff0c;的JavaAPI是指JDK提供的各种功能的Java类 String类 String类的初始化&#xff1a; &#xff08;1&#xff09;使用字符串常量直接初始化 初始化&#xff1a;String s…

闲话 .NET(4):为什么要跨平台?

前言 .NET Core 有一个关键词就是跨平台&#xff0c;为什么要跨平台呢&#xff1f;Windows 操作系统不香吗&#xff1f;今天我们来聊聊这个 原因一&#xff1a;安全考虑 Windows OS 是闭源的&#xff0c;而 Linux 是开源的&#xff0c;因此有些公司的技术负责人就认为 Linux…

笔记:weblogic配置内存启动参数

可以在控制台配置 参数值 -Xms2048m -Xmx2048m -XX:PermSize512m -XX:MaxPermSize512m -Xss128k激活更改。修改完之后&#xff0c;节点需要重启才能生效。 参数说明&#xff1a; -Xms 为JVM启动时分配的内存 -Xmx 为JVM运行过程中分配的最大内存 -XX:PermSize 为JVM初始分配…

Qt笔记:动态处理多个按钮点击事件以更新UI

问题描述 在开发Qt应用程序时&#xff0c;经常需要处理多个按钮的点击事件&#xff0c;并根据点击的按钮来更新用户界面&#xff08;UI&#xff09;&#xff0c;如下图。例如&#xff0c;你可能有一个包含多个按钮的界面&#xff0c;每个按钮都与一个文本框和一个复选框相关联…

python从0开始学习(十二)

目录 前言 1、字符串的常用操作 2、字符串的格式化 2.1 格式化字符串的详细格式&#xff08;针对format形式&#xff09; ​编辑 总结 前言 上一篇文章我们讲解了两道关于组合数据类型的题目&#xff0c;本篇文章我们将学习新的章节&#xff0c;学习字符串及正则表达式。 …

react实现把pc网站快捷添加到桌面快捷方式

文章目录 1. 需求2. 实现效果3. 核心逻辑4. 完整react代码 1. 需求 这种需求其实在国外一些游戏网站和推广网站中经常会用到&#xff0c;目的是为了让客户 快捷方便的保存网站到桌面 &#xff0c;网站主动尽量避免下次找不到网站地址了&#xff0c;当然精确的客户自己也可以使…

打印安全:防止打印过程中的商业机密泄露

在数字化办公日益普及的今天&#xff0c;打印安全常常成为企业信息保护中被忽视的一环。商业机密在打印过程中泄露&#xff0c;可能会给企业带来巨大的损失。本文将探讨如何通过一系列措施&#xff0c;确保打印过程中的商业机密安全。 一、打印安全的重要性 打印设备作为企业中…

hbase版本从1.2升级到2.1 spark读取hive数据写入hbase 批量写入类不存在问题

在hbase1.2版本中&#xff0c;pom.xml中引入hbase-server1.2…0和hbase-client1.2.0就已经可以有如下图的类。但是在hbase2.1.0版本中增加这两个不行。hbase-server2.1.0中没有mapred包&#xff0c;同时mapreduce下就2个类。版本已经不支持。 <dependency><groupId>…

两步将 CentOS 6.0 原地升级并迁移至 RHEL 7.9

《OpenShift / RHEL / DevSecOps 汇总目录》 说明 本文介绍如何将一个 CentOS 6.0 的系统升级并转换迁移到 RHEL 7.9。 本文是《在离线环境中将 CentOS 7.X 原地升级并迁移至 RHEL 7.9》阶进篇。 所有被测软件的验证操作可参见上述前文中对应章节的说明。 准备 CentOS 6.…