【PyTorch】softmax回归

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
  • 2. Q&A

1. 模型与代码实现

1.1. 模型

  • 背景
    在分类问题中,模型的输出层是全连接层,每个类别对应一个输出。我们希望模型的输出 y ^ j \hat{y}_j y^j可以视为属于类 j j j的概率,然后选择具有最大输出值的类别作为我们的预测。
    softmax函数能够将未规范化的输出变换为非负数并且总和为1,同时让模型保持可导的性质,而且不会改变未规范化的输出之间的大小次序
  • softmax函数
    y ^ = s o f t m a x ( o ) \mathbf{\hat{y}}=\mathrm{softmax}(\mathbf{o}) y^=softmax(o)其中 y ^ j = e x p ( o j ) ∑ k e x p ( o k ) \hat{y}_j=\frac{\mathrm{exp}({o_j})}{\sum_{k}\mathrm{exp}({o_k})} y^j=kexp(ok)exp(oj)
  • softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定,因此,softmax回归是一个线性模型
  • 为了避免将softmax的输出直接送入交叉熵损失造成的数值稳定性问题,将softmax和交叉熵损失结合在一起,具体做法是:不将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的输出,并同时计算softmax及其对数。因此定义交叉熵损失函数时也进行了softmax运算

1.2. 代码实现

import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter# 全局参数设置
batch_size = 256
num_workers = 0
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')writer = SummaryWriter()# 加载数据集
root = "./dataset"
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = FashionMNIST(root=root, train=True, transform=transform, download=True
)
mnist_test = FashionMNIST(root=root, train=False, transform=transform, download=True
)
dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers
)
dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers
)# 定义神经网络
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)# 初始化模型参数
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.constant_(m.bias, val=0)
net.apply(init_weights)# 定义损失函数
criterion = nn.CrossEntropyLoss(reduction='none')# 定义优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)class Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())for epoch in range(num_epochs):# 训练模型net.train()train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in dataloader_train:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)optimizer.zero_grad()loss.mean().backward()optimizer.step()train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())train_loss = train_metrics[0] / train_metrics[2]train_acc = train_metrics[1] / train_metrics[2]# 测试模型net.eval()with torch.no_grad():    test_metrics = Accumulator(2)   # 测试准确度总和、样本数for X, y in dataloader_test:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)test_metrics.add(accuracy(y_hat, y), y.numel())test_acc = test_metrics[0] / test_metrics[1]writer.add_scalars("metrics", {'train_loss': train_loss, 'train_acc': train_acc, 'test_acc': test_acc}, epoch)
writer.close()   

输出结果:
tensorboard

2. Q&A

  • 运行过程中出现以下警告:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

该警告的大致意思是给定的NumPy数组不可写,并且PyTorch不支持不可写的张量。这意味着你可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。在本程序的其余部分,此类警告将被抑制。因此需要修改C:\Users\%UserName%\anaconda3\envs\%conda_env_name%\lib\site-packages\torchvision\datasets\mnist.py的第498行,将return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)中的False改成True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/209528.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于你对 Zookeeper 的理解

看看普通人和高手是如何回答这个问题的? 普通人 Zookeeper 是一种开放源码的分布式应用程序协调服务 是一个分布式的小文件存储系统 一般对开发者屏蔽分布式应用开发过过程种的底层细节 用来解决分布式集群中应用系统的一致性问题 高手 对于 Zookeeper 的理解…

win10使用copilot(尝试中)

一、 Microsoft account | Sign In or Create Your Account Today – Microsoft 一路next全部点好【1】 二、 查看当前win10的版本,cmd输入命令winver 三、 修改区域为美国 四、更新和安全 Reference 【1】完美|在 Win10 强行开启 Win11 的独有功能…

IDEA2023找不到 Allow parallel run

我的idea版本:2023.1.4 第一步:点击Edit Configrations 第二步:点击Modify options 第三步:勾选Allow multiple instances 最后点击Apply应用一下 ok,问题解决!

COMP4121Advanced Algorithms

COMP4121Advanced Algorithms WeChat:yj4399_ Sina Visitor System

使用凌鲨进行内网穿透

为了方便在本地进行开发和调试工作,有时候需要安全地连接内网或Kubernetes集群中的服务。 在net proxy server中可以限制访问用户,也可以设置端口转发的密码。 使用 连接端口转发服务 列出可转发端口 可转发端口是服务端设置的,不会暴露真…

开会做笔记的时候用什么软件比较好?

在工作生涯中,会经历很多大大小小的会议,而如何快速准确记录下会议上重要的内容,成了很多上班族的必修课。在会上做笔记,选择什么样的工具才能事半功倍,成了一个值得深思的问题。而经过一段时间的测评后,我…

策略梯度简明教程

策略梯度方法 (PG:Policy Gradient) 是强化学习 (RL:Reinforcement Learning) 中常用的算法。 1、从库里的本能开始 PG的原理很简单:我们观察,然后行动。人类根据观察采取行动。 引用斯蒂芬库里的一句话: 你必须依靠…

【C++初阶】六、类和对象(初始化列表、static成员、友元、内部类)

相关代码gitee自取: C语言学习日记: 加油努力 (gitee.com) 接上期: 【C初阶】五、类和对象 (日期类的完善、流运算符重载函数、const成员、“&”取地址运算符重载)-CSDN博客 目录 ​​​​​​​一 . 初始化列表 构造函数…

UiPath学习笔记

文章目录 前言RPA介绍UiPath下载安装组件内容 前言 最近有一个项目的采集调研涉及到了客户端的采集,就取了解了一下RPA和UIPATH,记录一下 RPA介绍 RPA(Robotic Process Automation:机器人处理自动化),是…

Ubuntu systemd-analyze命令(系统启动性能分析工具:分析系统启动时间,找出可能导致启动缓慢的原因)

文章目录 Ubuntu systemd-analyze命令剖析目录简介systemd与systemd-analyze工作原理 安装和使用命令参数详解用例与示例显示启动时间(systemd-analyze time)返回信息解读 列出启动过程中各个服务的启动时间(systemd-analyze blame&#xff0…

OpenCvSharp从入门到实践-(06)创建图像

目录 1、创建图像 1.1实例1-创建黑色图像 1.2实例2-创建白色图像 1.3实例3-创建随机像素的雪花点图像 2、图像拼接 2.1水平拼接图像 2.2垂直拼接图像 2.3实例4-垂直和水平两种方式拼接两张图像 在OpenCV中,黑白图像其实就是一个二维数组,彩色图像…

并发编程笔记

1、前言 这篇笔记是我花的20多天跟着⿊⻢的并发编程学习做的笔记,地址是b站 ⿊⻢ 并发编程 ,也是我第⼀次 学习并发 编程,所以如果有很多原理讲不清楚的,知识点不全的,也请多多包涵 中间多数图都是直接截⽼师的笔记…

Fiddler抓包工具之fiddler的composer可以简单发送http协议的请求

一,composer的详解 右侧Composer区域,是测试接口的界面: 相关说明: 1.请求方式:点开可以勾选请求协议是get、post等 2.url地址栏:输入请求的url地址 3.请求头:第三块区域可以输入请求头信息…

全球与中国工业冰箱市场:增长趋势、竞争格局与前景展望

工业制冷机是用来维持储运容器内低温冷藏环境,以防止食品饮料、药品、化学品等对温度敏感的产品腐败变质的系统。此外,冷冻机、热交换器等冷冻系统也用于在工业机械运作过程中保持冷却。工业冷冻系统的需求成长主要是由食品和饮料产业的成长所推动的。 冷…

蓝桥杯每日一题2023.12.3

题目描述 1.移动距离 - 蓝桥云课 (lanqiao.cn) 题目分析 对于此题需要对行列的关系进行一定的探究,所求实际上为曼哈顿距离,只需要两个行列的绝对值想加即可,预处理使下标从0开始可以更加明确之间的关系,奇数行时这一行的数字需…

排序算法介绍(一)插入排序

0. 简介 插入排序(Insertion Sort) 是一种简单直观的排序算法,它的工作原理是通过构建有序序列,对于未排序数据,在已排序序列中从后向前扫描,找到相应位置并插入。插入排序在实现上,通常…

计算机组成学习-指令系统总结

复习本章时,思考以下问题: 1)什么是指令?什么是指令系统?为什么要引入指令系统?2)一般来说,指令分为哪些部分?每部分有什么用处?3)对于一个指令系统来说,寻址方式多和少…

龙芯loongarch64服务器编译安装maturin

前言 maturin 是一个构建和发布基于 Rust 的 Python 包的工具,但是在安装maturin的时候,会出现如下报错:error: cant find Rust compiler 这里记录问题解决过程中遇到的问题: 1、根据错误的提示需要安装Rust Compiler,首先去其官网按照官网给的解决办法提供进行安装 curl…

SocialSelling社交销售1+5+1方法论系列:社交销售价值何在

社交销售作为差异化的社媒社交营销模式,在实际运营操作过程中需要实现个人、部门与系统的融合打磨,并非只是应用工具那么简单。为了更高效地服务赋能客户,傲途基于实战提炼出社交销售151方法论,在近期系列内容中,我们将…

浅谈安科瑞无线测温设备在挪威某项目的应用

摘要:安科瑞无线温度设备装置通过无线温度收发器和各无线温度传感器直接进行温度值的传输,并采用液晶显示各无线温度传感器所测温度。 Absrtact:Acre wireless temperature device directly transmits the temperature value through the wireless temp…