240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

今天基于GPT实现一个情感分类的功能,假设已经安装好了MindSpore环境。

# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com

导包导包

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# 加载IMDb数据集
imdb_ds = load_dataset('imdb', split=['train', 'test'])
# 获取训练集
imdb_train = imdb_ds['train']
# 获取测试集
imdb_test = imdb_ds['test']
# 调用get_dataset_size方法来获取训练集的大小
imdb_train.get_dataset_size()
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
# 导入来自mindnlp库transformers模块中的GPTTokenizer类
from mindnlp.transformers import GPTTokenizer# 初始化GPT分词器,使用预训练的'openai-gpt'模型
# 分词器
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# 定义一个特殊token字典,包括开始、结束和填充token
special_tokens_dict = {"bos_token": "<bos>",  # 开始符号"eos_token": "<eos>",  # 结束符号"pad_token": "<pad>",  # 填充符号
}# 向分词器中添加特殊token,并返回添加的token数量
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# 将训练数据集imdb_train分割成训练集和验证集
# 按照70%训练集和30%验证集的比例进行划分
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 调用create_tuple_iterator方法创建一个迭代器,并通过next函数获取迭代器的第一个元素
# 这里的目的是为了展示或测试迭代器是否能正常生成数据
# 对于参数和返回值的详细说明,需要查看create_tuple_iterator方法的文档或实现
next(dataset_train.create_tuple_iterator())
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 执行模型训练
trainer.run(tgt_columns="labels")
# 初始化Evaluator对象,用于评估模型性能
# 参数说明:
# network: 待评估的模型
# eval_dataset: 用于评估的测试数据集
# metrics: 评估指标
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)# 执行模型评估,指定目标列作为评估标签
# 该步骤将计算模型在测试数据集上的指定评估指标
evaluator.run(tgt_columns="labels")

打卡图片:

image-20240707192022631

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/371171.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

14-20 Vision Transformer用AI的画笔描绘新世界

概述 毫无疑问,目前最受关注且不断发展的最重要的主题之一是使用人工智能生成图像、视频和文本。大型语言模型 (LLM) 已展示出其在文本生成方面的卓越能力。它们在文本生成方面的许多问题已得到解决。然而,LLM 面临的一个主要挑战是它们有时会产生幻觉反应。 最近推出的新模…

【Unity小技巧】Unity字典序列化

字典序列化 在 Unity 中&#xff0c;标准的 C# 字典&#xff08;Dictionary<TKey, TValue>&#xff09;是不能直接序列化的&#xff0c;因为 Unity 的序列化系统不支持非 Unity 序列化的集合类型。可以通过手写字典实现 效果&#xff1a; 实现步骤&#xff1a; 继承ISe…

【TB作品】51单片机 Proteus仿真 MAX7219点阵驱动数码管驱动

1、8乘8点阵模块&#xff08;爱心&#xff09; 数码管测试程序与仿真 实验报告: MAX7219 数码管驱动测试 一、实验目的 通过对 MAX7219 芯片的编程与控制&#xff0c;了解如何使用单片机驱动数码管显示数字&#xff0c;并掌握 SPI 通信协议的基本应用。 二、实验器材 51…

AI中药处方模型构建与案例

在中医领域,人工智能(AI)可以生成各种指令来辅助诊断、治疗和研究。 1. 诊断辅助指令: 根据患者的症状和体征,自动分析并生成可能的中医证候诊断建议。利用中医望闻问切四诊信息,智能识别关键症状,提供对应的中医辨证思路。2. 治疗建议指令: 根据辨证结果,自动推荐相应…

JVM专题之垃圾收集算法

标记清除算法 第一步:标记 (找出内存中需要回收的对象,并且把它们标记出来) 第二步:清除 (清除掉被标记需要回收的对象,释放出对应的内存空间) 缺点: 标记清除之后会产生大量不连续的内存碎片,空间碎片太多可能会导致以后在程序运行过程中需 要分配较大对象时,无法找到…

Python代码设置Excel工作表背景色或背景图

Excel是工作中数据处理和分析数据的重要工具。面对海量的数据和复杂的表格&#xff0c;如何提高工作效率、减少视觉疲劳并提升数据的可读性是不容忽视的问题。而给工作表设置合适的背景是表格优化的一个有效方式。为Excel工作表设置背景色或背景图不仅能够美化工作表&#xff0…

chrome 谷歌浏览器插件打包

1、找到id对应的字符串去搜索 C:\Users\<你的用户名>\AppData\Local\Google\Chrome\User Data\Default\Extensions2、选择根目录 直接加载下面的路径扩展可用&#xff1a;

AI绘画Stable Diffusion【图生图教程】:图片高清修复的三种方案详解,你一定能用上!(附资料)

大家好&#xff0c;我是画画的小强 今天给大家分享一下用AI绘画Stable Diffusion 进行 高清修复&#xff08;Hi-Res Fix&#xff09;&#xff0c;这是用于提升图像分辨率和细节的技术。在生成图像时&#xff0c;初始的低分辨率图像会通过放大算法和细节增强技术被转换为高分辨…

qt 如何添加子项目

首先我们正常流程创建一个项目文件&#xff1a; 这是我已经创建好的&#xff0c;请无视红线 然后找到该项目的文件夹&#xff0c;在文件夹下创建一个文件夹&#xff0c;再到创建好的文件夹下面创建一个 .pri 文件&#xff1a; &#xff08;创建文件夹&#xff09; &#xff08…

中国石油大学(华东)24计算机考研数据速览,计科学硕复试线288分!

中国石油大学&#xff08;华东&#xff09;计算机与通信工程学院是中国石油大学(华东)十三个教学院部之一&#xff0c;其前身是创建于1984年的计算机科学系&#xff0c;2001年撤系建院。伴随着学校50多年的风雨历程&#xff0c;计算机与通信工程学院也已经有了20多年的发展历史…

Python【打包exe文件两步到位】

Python打包Exe 安装 pyinstaller&#xff08;pip install pyinstaller&#xff09; 执行打包命令&#xff08;pyinstaller demo.py&#xff09; 打完包会生成 dist 文件夹&#xff0c;如下如

04.ffmpeg打印音视频媒体信息

目录 1、相关头文件 2、相关结构体 3、相关函数 4、函数详解 5、源码附上 1、相关头文件 #include <libavformat/avformat.h> 包含格式相关的函数和数据结构 #include <libavutil/avutil.h> 包含一些通用实用函数 2、相关结构体 AV…

【代码管理的必备工具:Git的基本概念与操作详解】

一、Git 初识 1.提出问题 不知道你工作或学习时&#xff0c;有没有遇到这样的情况&#xff1a;我们在编写各种⽂档时&#xff0c;为了防止⽂档丢失&#xff0c;更改失误&#xff0c;失误后能恢复到原来的版本&#xff0c;不得不复制出⼀个副本&#xff0c;比如&#xff1a; “…

多元微分学中可微、连续、存在问题

一、偏导存在 与一元证明相同&#xff0c;利用偏导定义式&#xff0c;证明偏导数左右极限存在且相同。 二、偏导连续 与一元证明相同&#xff0c;证明 三、极限存在 1、找一条路径&#xff0c;一般地找 y kx 2、代入f(x,y)&#xff0c;得f(x,kx) 3、证明f(x,kx)极限存在 注意&…

基于SpringBoot的休闲娱乐代理售票系统

本系统主要包括管理员和用户两个角色组成&#xff1b;主要包括&#xff1a;首页、个人中心、用户管理、折扣票管理、分类管理、订单信息管理、退票信息管理、出票信息管理、系统管理等功能的管理系统。 &#x1f495;&#x1f495;作者&#xff1a;Weirdo &#x1f495;&#x…

【数据结构】链表带环问题分析及顺序表链表对比分析

【C语言】链表带环问题分析及顺序表链表对比分析 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;C语言学习之路 文章目录 【C语言】链表带环问题分析及顺序表链表对比分析前言一.顺序表和链表对比1.1顺序表和链表的区别1.2缓存利用率&#…

隔离级别-隔离级别中的锁协议、隔离级别类型、隔离级别的设置、隔离级别应用

一、引言 1、DBMS除了采用严格的两阶段封锁协议来保证并发事务的可串行化&#xff0c;实现事务的隔离性&#xff0c;也可允许用户选择一个可以保证应用程序正确执行并且能够使并发度最大的隔离性等级 2、通常用隔离级别来描述隔离性等级&#xff0c;以下将主要介绍ANSI 92标准…

【python技巧】parser传入参数

参考网址: https://lightning.ai/docs/pytorch/LTS/api/pytorch_lightning.utilities.argparse.html#pytorch_lightning.utilities.argparse.add_argparse_args 1. 简单传入参数. parse_known_args()方法的作用就是把不在预设属性里的参数也返回,比如下面这个例子, 执行pytho…

算法的空间复杂度(C语言)

1.空间复杂度的定义 算法在临时占用储存空间大小的量度&#xff08;就是完成这个算法所额外开辟的空间&#xff09;&#xff0c;空间复杂度也使用大O渐进表示法来表示 注&#xff1a; 函数在运行时所需要的栈空间(储存参数&#xff0c;局部变量&#xff0c;一些寄存器信息等)…

vue.js微商城后台管理系统

一.需要运行的效果 20240701-231456 二.代码&#xff08;解析&#xff09; 首先&#xff0c;为项目添加依赖&#xff1a; yarn add element-plus --save yarn add vue-router4 --save 新建一个项目包&#xff0c;然后命名为商品管理&#xff0c;在components中新建几个vue文件…