昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/367806.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【目标检测】DINO

一、引言 论文&#xff1a; DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection 作者&#xff1a; IDEA 代码&#xff1a; DINO 注意&#xff1a; 该算法是在Deformable DETR、DAB-DETR、DN-DETR基础上的改进&#xff0c;在学习该算法前&#…

决策树算法的原理与案例实现

一、绪论 1.1 决策树算法的背景介绍 1.2 研究决策树算法的意义 二、决策树算法原理 2.1 决策树的基本概念 2.2 决策树构建的基本思路 2.2 决策树的构建过程 2.3 决策树的剪枝策略 三、决策树算法的优缺点 3.1 决策树算法的优势 3.2 决策树算法的局限性 3.3 决策树算…

Vue报错:Component name “xxx” should always be multi-word vue/multi-word-component

问题&#xff1a;搭建脚手架时报错&#xff0c;具体错误如下&#xff1a; ERROR in [eslint] E:\personalProject\VueProjects\vueproject2\src\components\Student.vue10:14 error Component name "Student" should always be multi-word vue/multi-word-compon…

【分布式数据仓库Hive】常见问题及解决办法

目录 一、启动hive时发现log4j版本和hadoop的版本有冲突 解决办法&#xff1a;删除hive下高版本的slf4j 二、启动hive报错 Exception in thread "main" java.lang.NoSuchMethodError:com.google.common.base.Preconditions.checkArgument(ZLjava/lang/Object;)V …

Elasticsearch (1):ES基本概念和原理简单介绍

Elasticsearch&#xff08;简称 ES&#xff09;是一款基于 Apache Lucene 的分布式搜索和分析引擎。随着业务的发展&#xff0c;系统中的数据量不断增长&#xff0c;传统的关系型数据库在处理大量模糊查询时效率低下。因此&#xff0c;ES 作为一种高效、灵活和可扩展的全文检索…

分别使用netty和apache.plc4x测试读取modbus协议的设备信号

记录一下常见的工业协议数据读取方法 目录 前言Modbus协议说明Netty 读取测试使用plc4x 读取测试结束语 前言 Modbus 是一种通讯协议&#xff0c;用于在工业控制系统中进行数据通信和控制。Modbus 协议主要分为两种常用的变体&#xff1a;Modbus RTU 和 Modbus TCP/IP Modbus …

嵌入式Linux之Uboot简介和移植

uboot简介 uboot 的全称是 Universal Boot Loader&#xff0c;uboot 是一个遵循 GPL 协议的开源软件&#xff0c;uboot是一个裸机代码&#xff0c;可以看作是一个裸机综合例程。现在的 uboot 已经支持液晶屏、网络、USB 等高级功能。 也就是说&#xff0c;可以在没有系统的情况…

苹果手机收不到短信怎么恢复?90%的人都在这么做

在使用苹果手机的过程中&#xff0c;有时候会遇到无法接收短信的问题。这不仅影响正常的沟通&#xff0c;还可能错过重要的通知和验证码。那么&#xff0c;手机收不到短信怎么恢复呢&#xff1f;别担心&#xff0c;90%的人都在使用这些简单而有效的方法来解决这一问题。 本文将…

Halcon支持向量机

一 支持向量机 1 支持向量机介绍&#xff1a; 支持向量机(Support Vector Machine&#xff0c;SVM)是Corinna Cortes和Vapnik于1995年首先提出的&#xff0c;它在解决小样本、非线性及高维模式识别表现出许多特有的优势。 2 支持向量机原理: 在n维空间中找到一个分类超平面…

14 卡尔曼滤波及代码实现

文章目录 14 卡尔曼滤波及代码实现14.0 基本概念14.1 公式推导14.2 代码实现 14 卡尔曼滤波及代码实现 14.0 基本概念 卡尔曼滤波是一种利用线性系统状态方程&#xff0c;通过系统输入输出观测数据&#xff0c;对系统状态进行最优估计的算法。由于观测数据包括系统中的噪声和…

React Native V0.74 — 稳定版已发布

嗨,React Native开发者们, React Native 世界中令人兴奋的消息是,V0.74刚刚在几天前发布,有超过 1600 次提交。亮点如下: Yoga 3.0New Architecture: Bridgeless by DefaultNew Architecture: Batched onLayout UpdatesYarn 3 for New Projects让我们深入了解每一个新亮点…

移动智能终端数据安全管理方案

随着信息技术的飞速发展&#xff0c;移动设备已成为企业日常运营不可或缺的工具。特别是随着智能手机和平板电脑等移动设备的普及&#xff0c;这些设备存储了大量的个人和敏感数据&#xff0c;如银行信息、电子邮件等。员工通过智能手机和平板电脑访问企业资源&#xff0c;提高…

【vue3】【vant】 移动端中国传统文化和民间传说案例

更多项目点击&#x1f446;&#x1f446;&#x1f446;完整项目成品专栏 【vue3】【vant】 移动端中国传统文化和民间传说案例 获取源码方式项目说明&#xff1a;其中功能包括项目包含&#xff1a;项目运行环境运行截图和视频 获取源码方式 加Q群&#xff1a;632562109项目说…

Linux_管道通信

目录 一、匿名管道 1、介绍进程间通信 2、理解管道 3、管道通信 4、用户角度看匿名管道 5、内核角度看匿名管道 6、代码实现匿名管道 6.1 创建子进程 6.2 实现通信 7、匿名管道阻塞情况 8、匿名管道的读写原子性 二、命名管道 1、命名管道 1.1 命名管道通信 …

源代码层面分析Appium-inspector工作原理

Appium-inspector功能 Appium Inspector 基于 Appium 框架&#xff0c;Appium 是一个开源工具&#xff0c;用于自动化移动应用&#xff08;iOS 和 Android&#xff09;和桌面应用&#xff08;Windows 和 Mac&#xff09;。Appium 采用了客户端-服务器架构&#xff0c;允许用户通…

C++初学者指南-3.自定义类型(第一部分)-异常

C初学者指南-3.自定义类型(第一部分)-异常 文章目录 C初学者指南-3.自定义类型(第一部分)-异常简介什么是异常&#xff1f;第一个示例用途:报告违反规则的行为异常的替代方案标准库异常处理 问题和保证资源泄露使用 RAII 避免内存泄漏&#xff01;析构函数&#xff1a;不要让异…

Taogogo Taocms v3.0.2 远程代码执行漏洞(CVE-2022-25578)

前言 CVE-2022-25578 是一个存在于 Taogogo Taocms v3.0.2 中的代码注入漏洞。此漏洞允许攻击者通过任意编辑 .htaccess 文件来执行代码注入。 漏洞详情 漏洞描述&#xff1a;攻击者可以利用此漏洞上传一个 .htaccess 文件到网站&#xff0c;并在文件中注入恶意代码&#xf…

CesiumJS【Basic】- #058 绘制网格填充多边形(Entity方式)-使用shader

文章目录 绘制网格填充多边形(Entity方式)-使用shader1 目标2 代码2.1 main.ts绘制网格填充多边形(Entity方式)-使用shader 1 目标 使用Entity方式绘制绘制网格填充多边形 - 使用shader 2 代码 2.1 main.ts import * as Cesium from cesium;// 创建 Cesium Viewer 实例…

主流国产服务器操作系统技术分析

主流国产服务器操作系统 信创 "信创"&#xff0c;即信息技术应用创新&#xff0c;作为科技自立自强的核心词汇&#xff0c;在我国信息化建设的进程中扮演着至关重要的角色。自2016年起步&#xff0c;2020年开始蓬勃兴起&#xff0c;信创的浪潮正席卷整个信息与通信技…

程序员AI提效案例:统计B站课程耗时情况

文章目录 一&#xff0c;时长统计需求二&#xff0c;一波三折三&#xff0c;终极方案 AIJava总结 今天为了写一篇博客&#xff0c;这篇博客介绍了B站的一个Java项目&#xff0c;这个项目分为三个阶段&#xff1a; 初级篇高级篇运维篇 一&#xff0c;时长统计需求 我想根据每个…