【机器学习】—机器学习和NLP预训练模型探索之旅

目录

一.预训练模型的基本概念

1.BERT模型

2 .GPT模型

二、预训练模型的应用

1.文本分类

使用BERT进行文本分类

2. 问答系统

使用BERT进行问答

三、预训练模型的优化

 1.模型压缩

1.1 剪枝

权重剪枝

2.模型量化

2.1 定点量化

使用PyTorch进行定点量化

3. 知识蒸馏

3.1 知识蒸馏的基本原理

3.2 实例代码:使用知识蒸馏训练学生模型

四、结论


随着数据量的增加和计算能力的提升,机器学习和自然语言处理技术得到了飞速发展。预训练模型作为其中的重要组成部分,通过在大规模数据集上进行预训练,使得模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

一.预训练模型的基本概念

预训练模型是一种在大规模数据集上预先训练好的模型,可以作为其他任务的基础。预训练模型的优势在于其能够利用大规模数据集中的知识,提高模型的泛化能力和准确性。常见的预训练模型包括BERT(Bidirectional Encoder Representations from Transformers)、GPT(Generative Pre-trained Transformer)等。

1.BERT模型

BERT是由Google提出的一种双向编码器表示模型。BERT通过在大规模文本数据上进行掩码语言模型(Masked Language Model, MLM)和下一句预测(Next Sentence Prediction, NSP)的预训练,使得模型可以学习到深层次的语言表示。

2 .GPT模型

GPT由OpenAI提出,是一种基于Transformer的生成式预训练模型。GPT通过在大规模文本数据上进行自回归语言模型的预训练,使得模型可以生成连贯的文本。

二、预训练模型的应用

预训练模型在NLP领域有广泛的应用,包括但不限于文本分类、问答系统、机器翻译等。以下将介绍几个具体的应用实例。

1.文本分类

文本分类是将文本数据按照预定义的类别进行分类的任务。预训练模型可以通过在大规模文本数据上进行预训练,从而捕捉到丰富的语义信息,提高文本分类的准确性。

使用BERT进行文本分类

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)# 定义数据集
class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',return_attention_mask=True,return_tensors='pt',)return {'text': text,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 准备数据
texts = ["I love this!", "I hate this!"]
labels = [1, 0]
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.1)train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=32)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len=32)train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)# 训练模型
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(3):model.train()for batch in train_loader:optimizer.zero_grad()input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']outputs = model(input_ids=input_ids, attention_mask=attention_mask)_, predicted = torch.max(outputs.logits, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Validation Accuracy: {correct / total:.2f}')

2. 问答系统

问答系统是从文本中自动提取答案的任务。预训练模型可以通过在大规模问答数据上进行预训练,从而提高答案的准确性和相关性。

使用BERT进行问答

from transformers import BertForQuestionAnswering# 加载预训练的BERT问答模型
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')# 输入问题和上下文
question = "What is the capital of France?"
context = "Paris is the capital of France."# 编码输入
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')# 模型预测
outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits# 获取答案的起始和结束位置
start_idx = torch.argmax(start_scores)
end_idx = torch.argmax(end_scores) + 1# 解码答案
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx]))
print(f'Answer: {answer}')

三、预训练模型的优化

在实际应用中,预训练模型的优化至关重要。常见的优化方法包括模型压缩、量化和蒸馏等。

 1.模型压缩

模型压缩是通过减少模型参数数量和计算量来提高模型效率的方法。压缩后的模型不仅运行速度更快,还能减少存储空间和内存占用。常见的模型压缩技术包括剪枝、量化和知识蒸馏等。

1.1 剪枝

剪枝(Pruning)是一种通过删除模型中冗余或不重要的参数来减小模型大小的方法。剪枝可以在训练过程中或训练完成后进行。常见的剪枝方法包括:

  • 权重剪枝(Weight Pruning):删除绝对值较小的权重,认为这些权重对模型输出影响不大。
  • 结构剪枝(Structured Pruning):删除整个神经元或卷积核,减少模型的计算量和存储需求。

剪枝后的模型通常需要重新训练,以恢复或接近原始模型的性能。

权重剪枝
import torch
import torch.nn.utils.prune as prune# 定义一个简单的模型
class SimpleModel(torch.nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = torch.nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()# 对模型的全连接层进行权重剪枝
prune.l1_unstructured(model.fc, name='weight', amount=0.5)# 查看剪枝后的权重
print(model.fc.weight)

2.模型量化

模型量化是通过降低模型参数的精度来减少计算量的方法。量化通常通过将浮点数表示的权重和激活值转换为低精度表示(如8位整数)来实现。这可以显著减少模型的存储空间和计算开销,同时在硬件上加速模型推理。

2.1 定点量化

定点量化(Fixed-point Quantization)是将浮点数表示的权重和激活值转换为固定精度的整数表示。常见的定点量化包括8位整数量化(INT8),这种量化方法在不显著降低模型精度的情况下,可以大幅提升计算效率。

使用PyTorch进行定点量化
import torch
import torch.quantization# 加载预训练模型
model = SimpleModel()# 定义量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')# 准备量化模型
model = torch.quantization.prepare(model, inplace=True)# 模拟量化后的推理过程
# 这里应该使用训练数据对模型进行微调,但为了简单起见,省略此步骤
model = torch.quantization.convert(model, inplace=True)# 查看量化后的模型
print(model)

3. 知识蒸馏

知识蒸馏(Knowledge Distillation)是通过将大模型(教师模型,Teacher Model)的知识转移到小模型(学生模型,Student Model)的方法,从而提高小模型的性能和效率。知识蒸馏的核心思想是通过教师模型的软标签(soft labels)指导学生模型的训练。

3.1 知识蒸馏的基本原理

在知识蒸馏过程中,学生模型不仅学习训练数据的真实标签,还学习教师模型对训练数据的输出,即软标签。软标签包含了更多的信息,比如类别之间的相似性,使学生模型能够更好地泛化。

蒸馏损失函数通常由两部分组成:

  • 交叉熵损失:衡量学生模型输出与真实标签之间的差异。
  • 蒸馏损失:衡量学生模型输出与教师模型软标签之间的差异。

总体损失函数为这两部分的加权和。

3.2 实例代码:使用知识蒸馏训练学生模型

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset# 定义教师模型和学生模型
teacher_model = SimpleModel()
student_model = SimpleModel()# 加载示例数据
data = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义蒸馏训练函数
def distillation_train(student_model, teacher_model, data_loader, optimizer, temperature=2.0, alpha=0.5):teacher_model.eval()student_model.train()for data, labels in data_loader:optimizer.zero_grad()# 教师模型输出with torch.no_grad():teacher_logits = teacher_model(data)# 学生模型输出student_logits = student_model(data)# 计算蒸馏损失loss_ce = F.cross_entropy(student_logits, labels)loss_kl = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),F.softmax(teacher_logits / temperature, dim=1),reduction='batchmean') * (temperature ** 2)loss = alpha * loss_ce + (1.0 - alpha) * loss_klloss.backward()optimizer.step()# 定义优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)# 进行蒸馏训练
for epoch in range(10):distillation_train(student_model, teacher_model, data_loader, optimizer)# 验证学生模型
student_model.eval()
correct = 0
total = 0
with torch.no_grad():for data, labels in data_loader:outputs = student_model(data)_, predicted = torch.max(outputs, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Student Model Accuracy: {correct / total:.2f}')

四、结论

预训练模型在机器学习和自然语言处理领域具有重要意义。通过在大规模数据集上进行预训练,模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/329568.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS7安装Redis

安装Redis,并使用PHP连接Redis 一、准备工作 1、安装LNMP 参考:搭建LNMP服务器-CSDN博客文章浏览阅读876次,点赞14次,收藏4次。LNMP 架构通常用于构建高性能、可扩展的 Web 应用程序。Nginx 作为前端 Web 服务器,负…

正则表达式(知识总结篇)

本篇文章主要是针对初学者,对正则表达式的理解、作用和应用 正则表达式🌟 一、🍉正则表达式的概述二、🍉正则表达式的语法和使用三、 🍉正则表达式的常用操作符四、🍉re库主要功能函数 一、🍉正…

科技查新中医学科研项目查新点如何确立与提炼?案例讲解

一、前言 医学科技查新包括立项查新和成果查新两个部分,其中医学立项查新,它是指在医学科研项目申报开题之前,通过在一定范围内进行该课题的相关文献检索 ( 可以根据项目委托人的具体要求,进行国内检索或者进行国外检索 ) &#x…

深度学习之基于Matlab的BP神经网络交通标志识别

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 随着智能交通系统(ITS)的快速发展,交通标志识别&#xff0…

1941springboot VUE 服务机构评估管理系统开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 springboot VUE服务机构评估管理系统是一套完善的完整信息管理类型系统,结合springboot框架和VUE完成本系统,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发),系统具有完整的源代…

Python | Leetcode Python题解之第108题将有序数组转换为二叉搜索树

题目: 题解: class Solution:def sortedArrayToBST(self, nums: List[int]) -> TreeNode:def helper(left, right):if left > right:return None# 选择任意一个中间位置数字作为根节点mid (left right randint(0, 1)) // 2root TreeNode(nums…

linux命令中arj使用

arj 用于创建和管理.arj压缩包 补充说明 arj命令 是 .arj 格式的压缩文件的管理器,用于创建和管理 .arj 压缩包。 语法 arj(参数)参数 操作指令:对 .arj 压缩包执行的操作指令;压缩包名称:指定要操作的arj压缩包名称。 更多…

基于Matlab实现声纹识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 声纹识别,也称为说话人识别,是一种通过声音判别说话人身份的生物识别技…

不闭合三维TSP:蛇优化算法SO求解不闭合三维TSP(起点固定,终点不定,可以更改数据集),MATLAB代码

旅行商从城市1出发,终点城市由算法求解而定 部分代码 close all clear clc global data load(data.txt)%导入TSP数据集 Dimsize(data,1)-1;%维度 lb-100;%下界 ub100;%上界 fobjFun;%计算总距离 SearchAgents_no100; % 种群大小(可以修改) …

MySQL索引和视图

MySQL索引和视图是关系型数据库MySQL中的两个重要概念。索引用于优化数据库的查询性能,而视图用于提供一个逻辑上的表结构,方便用户查询和操作数据。 索引是一种数据结构,可以加速对数据库表中的数据进行查询的速度。通过创建索引&#xff0…

HTML用法介绍

文章目录 一、HTML概念和模版二、常用标签及用法1.p标签2.span标签3.h标签4.hr标签5.img标签6.a标签7.input标签8.table标签 一、HTML概念和模版 HTML的全称为超文本标记语言&#xff0c;它包括一系列标签组成&#xff0c;模版及各部分注释如下&#xff1a; <!--声明文档类…

轻量SEO分析报告程序网站已开心去授权

轻量SEO分析报告程序网站已开心去授权&#xff0c;可以让你生成有洞察力的、 简洁的、易于理解的SEO报告&#xff0c;帮助你的网页排名和表现更好 网站源码免费下载地址抄笔记 (chaobiji.cn)https://chaobiji.cn/

linux-配置服务器之间 ssh免密登录

前言 在管理多台Linux服务器时,为了方便操作和自动化任务,实现服务器之间的SSH免密登录是非常有必要的。SSH免密登录可以避免每次远程连接时输入密码,大大提高效率。本文将详细介绍SSH免密登录的原理和实现步骤。 一、原理解释 SSH免密登录的实现依赖于SSH密钥对,主要是利用…

企业知识库智能问答系统的实践

1、页面效果 PC端 2、页面效果 手机端 3、主要支持功能 新建会话 历史会话 2、智能问答 支持 文本分类和意图识别&#xff0c;支持基于大模型的对话理解&#xff0c;支持流式对话 3、支持手机端 语音识别 4、主要服务包括 向量库Milvus 向量计算和文本分类服务 …

Python 渗透测试:GhostScript 沙箱绕过.(CVE-2018-16509)

什么是 GhostScript 沙箱绕过 GhostScript 沙箱是一种安全机制,用于在受控环境中运行 GhostScript 解释器,以防止恶意代码的执行。GhostScript 是一个广泛使用的 PDF 和 PostScript 解释器,通常用于在服务器上处理和渲染这些文件格式。Tavis Ormandy 通过公开邮件列表&#xf…

20232803 2023-2024-2 《网络攻防实践》实践十报告

目录 1. 实践内容1.1 SEED SQL注入攻击与防御实验1.2 SEED XSS跨站脚本攻击实验(Elgg) 2. 实践过程2.1 SEED SQL注入攻击与防御实验2.1.1 熟悉SQL语句2.1.2 对SELECT语句的SQL注入攻击2.1.3 对UPDATE语句的SQL注入攻击2.1.4 SQL对抗 2.2 SEED XSS跨站脚本攻击实验(Elgg)2.2.1 发…

Elasticsearch的Index sorting 索引预排序会导致索引数据的移动吗?

索引预排序可以确保索引数据按照指定字段的指定顺序进行存储&#xff0c;这样在查询的时候&#xff0c;如果固定使用这个字段进行排序就可以加快查询效率。 我们知道数据写入的过程中&#xff0c;如果需要确保数据有序&#xff0c;可能需要在原数据的基础上插入新的数据&#…

[机缘参悟-185] - 《道家-水木然人间清醒1》读书笔记 - 真相本质 -8- 认知觉醒 - 逻辑谬误、认知偏差:幸存者偏差

目录 前言&#xff1a; 一、幸存者偏差 二、幸存者偏差在现实中的应用 第一个故事&#xff1a; 第二个故事&#xff1a; 三、生活中的幸存者偏差 四、迷恋成功者经验的原因&#xff1a;鸡汤、幻想、传奇、希望 备注&#xff1a; 前言&#xff1a; 幸存者偏差&#xff0…

Java 多线程抢红包

问题需求 一个人在群里发了1个100元的红包&#xff0c;被分成了8个&#xff0c;群里有10个人一起来抢红包&#xff0c;有抢到的金额随机分配。 红包功能需要满足哪些具体规则呢? 1、被分的人数抢到的金额之和要等于红包金额&#xff0c;不能多也不能少。 2、每个人至少抢到1元…

免费发布web APP的四个途径(Python和R)

免费发布数据分析类&#x1f310;web APP的几个途径&#x1f4f1; 数据分析类web APP目前用来部署生信工具&#xff0c;统计工具和预测模型等&#xff0c;便利快捷&#xff0c;深受大家喜爱。而一个免费的APP部署途径&#xff0c;对于开发和测试APP都是必要的。根据笔者的经验…