【NLP 9、实践 ① 五维随机向量交叉熵多分类】

目录

五维向量交叉熵多分类

规律:

实现:

1.设计模型

2.生成数据集

3.模型测试

4.模型训练

5.对训练的模型进行验证

调用模型


你的平静,是你最强的力量

                                —— 24.12.6

五维向量交叉熵多分类

规律:

x是一个五维(索引)向量,对x做五分类任务

改用交叉熵实现一个多分类任务,五维随机向量中最大的数字在哪维就属于哪一类


实现:

1.设计模型

Linear():模型函数中定义线性层

activation = nn.Softmax(dim=1):定义激活层为softmax激活函数

nn.CrossEntropyLoss() / nn.functional.cross_entropy:定义交叉熵损失函数

pyTorch中定义的交叉熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数,对数据进行归一化

经过 Softmax 归一化后,输出向量的每个元素可以被解释为样本属于相应类别的概率。这使得我们能够直接比较不同类别上的概率大小,并且与真实的类别概率分布(如one-hot编码)进行合理的对比。

例如,在一个三分类问题中,经过 Softmax 后的输出可能是[0.2,0.3,0.5],我们可以直观地说样本属于第三类的概率是 0.5,这是一个符合概率意义的解释

forward函数,前向计算,定义网络的使用方式,声明模型计算过程

# 1.设计模型
class TorchModel(nn.Module):def __init__(self, input_size):super(TorchModel, self).__init__()# 预测出一个五维的向量,五维向量代表五个类别上的概率分布self.linear = nn.Linear(input_size, 5)  # 线性层# 类交叉熵写法:CrossEntropyLoss()     函数交叉熵写法:cross_entropy# nn.CrossEntropyLoss() pycharm交叉的熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数self.loss = nn.functional.cross_entropy # loss函数采用交叉熵损失self.activation = nn.Softmax(dim=1)# 当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 输入过第一个网络层y_pred = self.linear(x)  # (batch_size, input_size) -> (batch_size, 1)if y is not None:return self.loss(y_pred, y)  # 预测值和真实值计算损失else:return self.activation(y_pred)  # 输出预测结果# return y_pred

2.生成数据集

由于题目要求,要在一个五维随机向量中查找标量最大的数所在维度,所以用np.random函数随机生成一个五维向量,然后通过np.argmax函数找出生成向量中最大标量所对应的维度,并将其作为数据 x标注 y 返回

当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到 "[ ]",换句话说当y是单独的一个数(标量)时,才需要加“[ ]”

而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起

# 2.生成数据集标签label   数据构建
# 生成一个样本, 样本的生成方法,代表了我们要学习的规律,随机生成一个5维向量,如果第一个值大于第五个值,认为是正样本,反之为负样本
def build_sample():x = np.random.random(5)# 获取最大值对应的索引max_index = np.argmax(x)return x, max_index# 随机生成一批样本
# 正负样本均匀生成
def build_dataset(total_sample_num):X = []Y = []# 随机生成样本,total_sample_num 生成的随机样本数for i in range(total_sample_num):x, y = build_sample()X.append(x)# 当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到"[]",换句话说当y是单独得一个数(标量)时,才需要加“[]”# 而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起Y.append(y)X_array = np.array(X)Y_array = np.array(Y)# 一般torch中的Long整形类型用来判定类型return torch.FloatTensor(X_array), torch.LongTensor(Y_array)

3.模型测试

用来测试每轮模型预测的精确度

model.eval():声明模型框架在这个函数中不做训练

with torch.no_grad():在模型测试的部分中,声明是测试函数,不计算梯度,增加模型训练效率

zip():zip 函数是一个内置函数,用于将多个可迭代对象(如列表、元组、字符串等)中对应的元素打包成一个个元组,然后返回由这些元组组成的可迭代对象(通常是一个 zip 对象)。如果各个可迭代对象的长度不一致,那么 zip 操作会以最短的可迭代对象长度为准。

# 3.模型测试
# 用来测试每轮模型的准确率
def evaluate(model):model.eval()test_sample_num = 100x, y = build_dataset(test_sample_num)print("本次预测集中共有%d个正样本,%d个负样本" % (sum(y), test_sample_num - sum(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)  # 模型预测 model.forward(x)for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比# np.argmax是求最大数所在维,max求最大数,torch.argmax是求最大数所在维if torch.argmax(y_p) == int(y_t):correct += 1  # 正确预测加一else:wrong += 1  # 错误预测加一print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))return correct / (correct + wrong)

4.模型训练

① 配置参数        

② 建立模型

③ 选择优化器(Adam)

④ 读取训练集

⑤ 训练过程

        Ⅰ、model.train():设置训练模式

        Ⅱ、对训练集样本开始循环训练(循环取出训练数据)

        Ⅲ、根据模型函数和损失函数的定义计算模型损失

        Ⅳ、计算梯度

        Ⅴ、通过梯度用优化器更新权重

        Ⅵ、计算完一轮训练数据后梯度进行归零,下一轮重新计算

torch.save(model.state_dict(), "model.pt"):模型保存model.pt文件

一般任务不同只需更改数据读取(步骤③)模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同

# 4.模型训练
def main():# 配置参数epoch_num = 20  # 训练轮数batch_size = 20  # 每次训练样本个数train_sample = 5000  # 每轮训练总共训练的样本总数input_size = 5  # 输入向量维度learning_rate = 0.001  # 学习率# ① 建立模型model = TorchModel(input_size)# ② 选择优化器optim = torch.optim.Adam(model.parameters(), lr=learning_rate)log = []# ③ 创建训练集,正常任务是读取训练集train_x, train_y = build_dataset(train_sample)# 训练过程# 轮数进行自定义for epoch in range(epoch_num):model.train()watch_loss = []# ④ 读取数据集for batch_index in range(train_sample // batch_size):x = train_x[batch_index * batch_size : (batch_index + 1) * batch_size]y = train_y[batch_index * batch_size : (batch_index + 1) * batch_size]# ⑤ 计算lossloss = model(x, y)  # 计算loss  model.forward(x,y)# ⑥ 计算梯度loss.backward()  # 计算梯度# ⑦ 权重更新optim.step()  # 更新权重# ⑧ 梯度归零optim.zero_grad()  # 梯度归零watch_loss.append(loss.item())# 一般任务不同只需更改数据读取(步骤③)和模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model)  # 测试本轮模型结果log.append([acc, float(np.mean(watch_loss))])# 保存模型torch.save(model.state_dict(), "model.pt")# 画图print(log)plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线plt.legend()plt.show()return

5.对训练的模型进行验证

调用main函数

if __name__ == "__main__":main()


调用模型

model.eval():声明模型框架在这个函数中不做训练

predict("model.pt", test_vec):调用模型存储的文件model.pt,通过调用模型对数据进行预测

# 使用训练好的模型做预测
def predict(model_path, input_vec):input_size = 5model = TorchModel(input_size)# 加载训练好的权重model.load_state_dict(torch.load(model_path, weights_only=True))# print(model.state_dict())model.eval()  # 测试模式,不计算梯度with torch.no_grad():# 输入一个真实向量转成Tensor,让模型forward一下result = model.forward(torch.FloatTensor(input_vec))  # 模型预测for vec, res in zip(input_vec, result):# python中,round函数是对浮点数进行四舍五入print("输入:%s, 预测类别:%s, 概率值:%s" % (vec, torch.argmax(res), res))  # 打印结果if __name__ == "__main__":test_vec = [[0.97889086,0.15229675,0.31082123,0.03504317,0.88920843],[0.74963533,0.5524256,0.95758807,0.95520434,0.84890681],[0.00797868,0.67482528,0.13625847,0.34675372,0.19871392],[0.09349776,0.59416669,0.92579291,0.41567412,0.1358894]]predict("model.pt", test_vec)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/485570.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

01-Chromedriver下载与配置(mac)

下载地址: 这里我用的最后一个,根据自己chrome浏览器选择相应的版本号即可 ChromeDriver官网下载地址:https://sites.google.com/chromium.org/driver/downloads ChromeDriver官网最新版下载地址:https://googlechromelabs.git…

Redis(上)

Redis 基础 什么是 Redis? Redis (REmote DIctionary Server)是一个基于 C 语言开发的开源 NoSQL 数据库(BSD 许可)。与传统数据库不同的是,Redis 的数据是保存在内存中的(内存数据库&#xf…

JS学习(1)(基本概念与作用、与HTML、CSS区别)

目录 一、JavaScript是什么? (1)基本介绍 (2)简称:JS? 二、JavaScript的作用。 三、HTML、CSS、JS之间的关系。 (1)html、css。 (2)JavaScript。 …

使用AI工具Screenshot to Code将UI设计图翻译成代码

一、获取openAI apikey。 一般有两种方式,一种是到openAI官网注册账号,付费申请GPT4的apikey。另一种是某宝买代理。我这里采用第二种。 二、安装Screenshot to Code 1.到github下载源码。 2.启动,两种方式:源码启动和docker启动…

ETCD的封装和测试

etcd是存储键值数据的服务器 客户端通过长连接watch实时更新数据 场景: 当主机A给服务器存储 name: 小王 主机B从服务器中查name ,得到name-小王 当主机A更改name 小李 服务器实时通知主机B name 已经被更改成小李了。 应用:服务注册与发…

Github 2024-12-01 开源项目月报 Top20

根据Github Trendings的统计,本月(2024-12-01统计)共有20个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目10TypeScript项目9Go项目2HTML项目1Shell项目1Jupyter Notebook项目1屏幕截图转代码应用 创建周期:114 天开发语言:TypeScript, Py…

李飞飞:Agent AI 多模态交互的前沿探索

发布于:2024 年 11 月 27 日 星期三 北京 #RAG #李飞飞 #Agent #多模态 #大模型 Agent AI在多模态交互方面展现出巨大潜力,通过整合各类技术,在游戏、机器人、医疗等领域广泛应用。如游戏中优化NPC行为,机器人领域实现多模态操作等。然而,其面临数据隐私、偏见、可解释性…

C语言期末考试——重点考点

目录 1.C语言的结构 2.三种循环结构 3.逻辑真假判断 4. printf函数 5. 强制类型转化 6. 多分支选择结构 7. 标识符的定义 8. 三目运算符 1.C语言的结构 选择结构、顺序结构、循环结构 2.三种循环结构 for、while、do-while 3.逻辑真假判断 C语言用0表示false,用非0(不…

ansible基础教程(下)

一、playbook 简介: playbook 是 ansible 用于配置,部署,和管理被控节点的剧本。 通过 playbook 的详细描述,执行其中的一系列 tasks ,可以让远端主机达到预期的状态。 使用场景: 像执行shell命令与写…

mid360使用cartorapher进行3d建图导航

1. 添加urdf配置文件&#xff1a; 添加IMU配置关节点和laser关节点 <!-- imu livox --> <joint name"livox_frame_joint" type"fixed"> <parent link"base_link" /> <child link"livox_frame" /> <o…

聚合支付系统/官方个人免签系统/三方支付系统稳定安全高并发 附教程

聚合支付系统/官方个人免签系统/三方支付系统稳定安全高并发 附教程 系统采用FastAdmin框架独立全新开发&#xff0c;安全稳定,系统支持代理、商户、码商等业务逻辑。 针对最近一些JD&#xff0c;TB等业务定制&#xff0c;子账号业务逻辑API 非常详细&#xff0c;方便内置…

CV工程师专用键盘开源项目硬件分析

1、前言 作为一个电子发烧友&#xff0c;你是否有遇到过这样的问题呢。当我们去查看函数定义的时候&#xff0c;需要敲击鼠标右键之后选择go to definition。更高级一些&#xff0c;我们使用键盘的快捷键来查看定义&#xff0c;这时候可以想象一下&#xff0c;你左手按下ALT&a…

【Redis】深入解析Redis缓存机制:全面掌握缓存更新、穿透、雪崩与击穿的终极指南

文章目录 一、Redis缓存机制概述1.1 Redis缓存的基本原理1.2 常见的Redis缓存应用场景 二、缓存更新机制2.1 缓存更新的策略2.2 示例代码&#xff1a;主动更新缓存 三、缓存穿透3.1 缓存穿透的原因3.2 缓解缓存穿透的方法3.3 示例代码&#xff1a;使用布隆过滤器 四、缓存雪崩4…

ABAP - 系统集成之SAP的数据同步到OA(泛微E9)服务器数据库

需求背景 项目经理说每次OA下单都需要调用一次SAP的接口获取数据&#xff0c;导致效率太慢了&#xff0c;能否把SAP的数据保存到OA的数据库表里&#xff0c;这样OA可以直接从数据库表里获取数据效率快很多。思来想去&#xff0c;提供了两个方案。 在集群SAP节点下增加一个SQL S…

2023年华数杯数学建模A题隔热材料的结构优化控制研究解题全过程文档及程序

2023年华数杯全国大学生数学建模 A题 隔热材料的结构优化控制研究 原题再现&#xff1a; 新型隔热材料 A 具有优良的隔热特性&#xff0c;在航天、军工、石化、建筑、交通等高科技领域中有着广泛的应用。   目前&#xff0c;由单根隔热材料 A 纤维编织成的织物&#xff0c;…

MongoDB性能监控工具

mongostat mongostat是MongoDB自带的监控工具&#xff0c;其可以提供数据库节点或者整个集群当前的状态视图。该功能的设计非常类似于Linux系统中的vmstat命令&#xff0c;可以呈现出实时的状态变化。不同的是&#xff0c;mongostat所监视的对象是数据库进程。mongostat常用于…

SpringBoot 赋能:精铸超稳会员制医疗预约系统,夯实就医数据根基

1绪论 1.1开发背景 传统的管理方式都在使用手工记录的方式进行记录&#xff0c;这种方式耗时&#xff0c;而且对于信息量比较大的情况想要快速查找某一信息非常慢&#xff0c;对于会员制医疗预约服务信息的统计获取比较繁琐&#xff0c;随着网络技术的发展&#xff0c;采用电脑…

电子商务人工智能指南 3/6 - 聊天机器人和客户服务

介绍 81% 的零售业高管表示&#xff0c; AI 至少在其组织中发挥了中等至完全的作用。然而&#xff0c;78% 的受访零售业高管表示&#xff0c;很难跟上不断发展的 AI 格局。 近年来&#xff0c;电子商务团队加快了适应新客户偏好和创造卓越数字购物体验的需求。采用 AI 不再是一…

mock.js介绍

mock.js http://mockjs.com/ 1、mock的介绍 *** 生成随机数据&#xff0c;拦截 Ajax 请求。** 通过随机数据&#xff0c;模拟各种场景&#xff1b;不需要修改既有代码&#xff0c;就可以拦截 Ajax 请求&#xff0c;返回模拟的响应数据&#xff1b;支持生成随机的文本、数字…

优化LabVIEW数据运算效率的方法

在LabVIEW中进行大量数据运算时&#xff0c;提升计算效率并减少时间占用是开发过程中常遇到的挑战。为此&#xff0c;可以从多个角度着手优化&#xff0c;包括合理选择数据结构与算法、并行处理、多线程技术、硬件加速、内存管理和界面优化等。通过采用这些策略&#xff0c;可以…