从0书写一个softmax分类 李沐pytorch实战

 输出维度

在softmax 分类中 我们输出与类别一样多。 数据集有10个类别,所以网络输出维度为10。

 初始化权重和偏置

torch.norma 生成一个均值为 0,标准差为0.01,一个形状为size=(num_inputs, num_outputs)的张量

偏置生成一个num_outputs =10 的一维张量,并用0初始化 

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)###
b = torch.zeros(num_outputs, requires_grad=True)

requires_grad=True,PyTorch 会在后向传播过程中自动计算该张量的梯度,这对于优化模型参数非常重要。

sum 运算符工作机制:

X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdim=True), X.sum(1, keepdim=True)

sum = 0 张量按列求和 sum = 1 张量按行求和 

定义softmax函数 

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

测试

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
print(X_prob)
print(X_prob.sum(1))

定义sofrmax模型:

softmax 回归模型 定义

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

 W.shape[0]表示权重的第一维大小

reshape 函数会根据原始张量 X 的元素总数和你提供的其他维度来计算出 -1 代表的维度

定义交叉熵损失函数:

回顾

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat[[0, 1], y])

y 张量是两个真实类别,第0类和第二类

y_hat 是对两个类别 在三种类别上的预测,真实的第0类预测结果为0.1,第2类预测结果为0.5

输出,取出y_hat的指定索引,

  • 对于第一个样本(索引 0),取 y_hat[0, 0],即 0.1
  • 对于第二个样本(索引 1),取 y_hat[1, 2],即 0.5
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])print(f'交叉熵损失为{cross_entropy(y_hat, y)}')

交叉熵损失为tensor([2.3026, 0.6931])

定义分类精度:

计算出 正确预测数量与总预测数量之比

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
print(accuracy(y_hat, y) / len(y))
  • y_hat 的形状为 (2, 3),这意味着:
    • 第一维的大小是 2,表示有 2 个样本(行)。
    • 第二维的大小是 3,表示每个样本有 3 个类别的预测概率(列)。

因此,y_hat.shape[1] 返回的是第二维的大小,也就是 3。这个值表示每个样本的类别数。在这个例子中,y_hat 中的每一行包含了对应样本对 3 个类别的预测概率。

print :y_hat.argmax(axis=1) 预测值张量为[2,2],与y = torch.tensor([0, 2])做对比

将布尔张量转换为整型张量

将[False,True],转换为0,1形式

(cmp.type(y.dtype).sum())

报错RuntimeError: DataLoader worker (pid(s) 12452, 3084, 29000, 29444) exited unexpectedly解决方法:

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
test_iter.num_workers = 0
train_iter.num_workers = 0

再训练迭代器和测试迭代器后加入

定义评估模型准确率函数:

def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]

 if isinstance(net, torch.nn.Module)是一个Python内置函数,用于检查对象net是否是torch.nn.Module类的实例。

  • 创建一个Accumulator对象,用于累加正确预测的数量和预测的总数量。Accumulator类会存储两个值。
  • 调用模型net对输入X进行预测,得到预测结果,然后使用accuracy函数计算预测的准确数量。y.numel()返回标签y中的元素总数(即样本数量),这两个值一起传递给metric.add()进行累加。

代码分析:

    def add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]
  • zip函数将self.data(当前存储的累积值)和args(输入的多个参数)配对。假设self.data是一个包含n个元素的列表,而args也是一个包含n个元素的可变参数列表。
  • 例如,如果self.data = [3.0, 5.0],而args = (2, 1)zip会生成[(3.0, 2), (5.0, 1)]的迭代器。
  • [a + float(b) for a, b in zip(self.data, args)]是一个列表推导式,用于遍历zip生成的配对。在这个过程中:
    • aself.data中的当前元素。
    • bargs中的当前元素。
  • 对每一对(a, b),该表达式计算a + float(b),将b转换为浮点数并与a相加。

 预测:

def predict_ch3(net, test_iter, n=9):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)
d2l.plt.show()

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/425505.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计网】数据链路层:概述之位置|地位|链路|数据链路|帧

✨ Blog’s 主页: 白乐天_ξ( ✿>◡❛) 🌈 个人Motto:他强任他强,清风拂山岗! 💫 欢迎来到我的学习笔记! ① ② ③ ④ ⑤ ⑥ ⑦ ⑧ ⑨ ⑩ 1. 在OSI体系结构中的位置 1. 位置:数…

ICMP

目录 1. 帧格式2. ICMPv4消息类型(Type = 0,Code = 0)回送应答 /(Type = 8,Code = 0)回送请求(Type = 3)目标不可达(Type = 5)重定向(Type = 11)ICMP超时(Type = 12)参数3. ICMPv6消息类型回见TCP/IP 对ICMP协议作介绍 ICMP(Internet Control Message Protocol…

即插即用!高德西交的PriorDrive:统一的矢量先验地图编码,辅助无图自动驾驶

Driving with Prior Maps: Unified Vector Prior Encoding for Autonomous Vehicle Mapping 论文主页:https://misstl.github.io/PriorDrive.github.io/ 论文链接:https://arxiv.org/pdf/2409.05352 代码链接:https://github.com/missTL/Pr…

基于python+django+vue的学生成绩管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于协同过滤pythondjangovue…

用Python解决综合评价问题_模糊综合评价,决策树与灰色关联分析

一:模糊综合评价 模糊综合评价是一种有效的处理不确定性和模糊性的评价方法,特别是在人才评价等领域。它允许我们综合考虑多个评价指标,并给出一个综合的评价结果。以下是利用模糊综合评价对人才进行评价的步骤: 确定评价指标&am…

Git常用指令整理【新手入门级】【by慕羽】

Git 是一个分布式版本控制系统,主要用于跟踪和管理源代码的更改。它允许多名开发者协作,同时提供了强大的功能来管理项目的历史记录和不同版本。本文主要记录和整理,个人理解的Git相关的一些指令和用法 文章目录 一、git安装 & 创建git仓…

【AI大模型】ChatGPT模型原理介绍(上)

目录 🍔 什么是ChatGPT? 🍔 GPT-1介绍 2.1 GPT-1模型架构 2.2 GPT-1训练过程 2.2.1 无监督的预训练语言模型 2.2.2 有监督的下游任务fine-tunning 2.2.3 整体训练过程架构图 2.3 GPT-1数据集 2.4 GPT-1模型的特点 2.5 GPT-1模型总结…

深度学习-神经网络

文章目录 一、基本组成单元:神经元二、神经网络层三、偏置与权重四、激活函数1.激活函数的作用2.常见的激活函数1).Sigmoid2).Tanh函数3).ReLU函数 五、优点与缺点六、总结 神经网络(Neural Network, NN)是一种模拟人类大脑工作方式的计算模型…

Debian11.9镜像基于jre1.8的Dockerfile

Debian11.9基于jre1.8的Dockerfile编写 # 使用Debian 11.9作为基础镜像 FROM debian:11.9 # 维护者信息(建议使用LABEL而不是MAINTAINER,因为MAINTAINER已被弃用) LABEL maintainer"caibingsen" # 创建一个目录来存放jre …

LabVIEW提高开发效率技巧----VI服务器和动态调用

VI服务器(VI Server)和动态调用是LabVIEW中的两个重要功能,可以有效提升程序的灵活性、模块化和可扩展性。通过这两者的结合,开发者可以在运行时动态加载和调用VI(虚拟仪器),实现更为复杂的应用…

【我的 PWN 学习手札】Unsortedbin Attack

前言 Unsortedbin Attack不能说是一种getshell的方式,而只能说是一种利用手法。在glibc2.28之前有效,条件是存在uaf,效果是能在某一地址写一个大数(glibc上的一个地址)。 一、Unsortedbin的unlink过程 unsortedbin …

Android Framework(六)WMS-窗口显示流程——窗口内容绘制与显示

文章目录 窗口显示流程明确目标 窗户内容绘制与显示流程窗口Surface状态完整流程图 应用端处理finishDrawingWindow 的触发 system_service处理WindowState状态 -- COMMIT_DRAW_PENDING本次layout 流程简述 窗口显示流程 目前窗口的显示到了最后一步。 在 addWindow 流程中&…

C语言中数据类型

一、C 语言中数据类型 基本数据类型: 整型(int):用于存储整数,如:1、2、3等。字符型(char):用于存储单个字符,如:‘a’、‘b’、c’等。浮点型&a…

中秋献礼!2024年中科院一区极光优化算法+分解对比!VMD-PLO-Transformer-LSTM多变量时间序列光伏功率预测

中秋献礼!2024年中科院一区极光优化算法分解对比!VMD-PLO-Transformer-LSTM多变量时间序列光伏功率预测 目录 中秋献礼!2024年中科院一区极光优化算法分解对比!VMD-PLO-Transformer-LSTM多变量时间序列光伏功率预测效果一览基本介…

一种多策略改进小龙虾智能优化算法MSCOA 改进策略:种群混沌映射初始化+透镜成像反向学习+黄金正弦变异策略

一种多策略改进小龙虾智能优化算法MSCOA 改进策略:种群初始化精英反向透镜成像反向学习黄金正弦变异策略 文章目录 一、小龙虾COA基本原理二、改进策略2.1种群初始化 映射2.2 透镜成像反向学习2.3 黄金正弦变异策略 三、实验结果四、核心代码五、代码获取六、总结 一…

每日一个数据结构-跳表

文章目录 什么是跳表?示意图跳表的基本原理跳表的操作跳表与其他数据结构的比较 跳表构造过程 什么是跳表? 跳表(Skip List)是一种随机化的数据结构,它通过在有序链表上增加多级索引来实现快速查找、插入和删除操作。…

react hooks--useState

概述 useState 可以使函数组件像类组件一样拥有 state,也就说明函数组件可以通过 useState 改变 UI 视图。那么 useState 到底应该如何使用,底层又是怎么运作的呢,首先一起看一下 useState 。 问题:Hook 是什么? 一个 Hook 就是…

TensorRT-LLM——优化大型语言模型推理以实现最大性能的综合指南

引言 随着对大型语言模型 (LLM) 的需求不断增长,确保快速、高效和可扩展的推理变得比以往任何时候都更加重要。NVIDIA 的 TensorRT-LLM 通过提供一套专为 LLM 推理设计的强大工具和优化,TensorRT-LLM 可以应对这一挑战。TensorRT-LLM 提供了一系列令人印…

Double Write

优质博文:IT-BLOG-CN 一、存在的问题 为什么需要Double Write: InnoDB的PageSize是16kb,其数据校验也是针对这16KB来计算的,将数据写入磁盘是以Page为单位的进行操作的。而计算机硬件和操作系统,写文件是以4KB作为基…

Python基础语法(1)上

常量和表达式 我们可以把 Python 当成一个计算器,来进行一些算术运算。 print(1 2 - 3) print(1 2 * 3) print(1 2 / 3) 这里我们可能会有疑问,为什么不是1.6666666666666667呢? 其实在编程中,一般没有“四舍五入”这样的规则…