深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()d2l.use_svg_display()# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/196321.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVAEE 初阶 多线程基础(一)

多线程基础 一.线程的概念二.为什么要有线程三.进程和线程的区别和关系四.JAVA的线程和操作系统线程的关系五.第一个多线程程序1.继承Thread类 一.线程的概念 一个线程就是一个 “执行流”. 每个线程之间都可以按照顺讯执行自己的代码. 多个线程之间 “同时” 执行着多份代码 同…

CV计算机视觉每日开源代码Paper with code速览-2023.11.14

点击CV计算机视觉,关注更多CV干货 论文已打包,点击进入—>下载界面 点击加入—>CV计算机视觉交流群 1.【基础网络架构:Transformer】Aggregate, Decompose, and Fine-Tune: A Simple Yet Effective Factor-Tuning Method for Vision…

流媒体协议

◆ RTP(Real-time Transport Protocol),实时传输协议。 ◆ RTCP(Real-time Transport Control Protocol),实时传输控制协议。 ◆ RTSP(Real Time Streaming Protocol),实时流协议。 ◆ RTMP(Real Time Messaging Protocol),实时…

【Proteus仿真】【Arduino单片机】LM35温度计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器,使用PCF8574、LCD1602液晶、LM35传感器等。 主要功能: 系统运行后,LCD1602显示传感器检测温度。 二、软件设计 /* 作者&a…

单片机的冷启动、热启动、复位

一文看懂STC单片机冷启动和复位有什么区别-电子发烧友网 单片机的冷启动、热启动和复位是不同的启动或重置方式,它们在系统状态和初始化方面有所不同: 1.冷启动(Cold Start): 定义: 冷启动是指系统从完全关…

【火炬之光-魔灵装备】

文章目录 装备天赋追忆石板技能魂烛刷图策略 装备 头部胸甲手套鞋子武器盾牌项链戒指腰带神格备注盾牌其余的装备要么是召唤物生命,要么是技能等级,鞋子的闪电技能等级加2不是核心,腰带的话主要是要冷却有冷却暗影的技能是不会断的&#xff…

揭示CDN加速的局限性与探讨其小众化原因

在网络加速领域,CDN(内容分发网络)被认为是提升性能的关键技术之一。然而,尽管其在某些方面表现出色,CDN在广泛应用中仍然相对小众。本文将从CDN加速的局限性出发,深入探讨为何这项技术尚未迎来大规模的应用…

.NET 8.0 中有哪些新的变化?

1性能提升 .NET 8在整个堆栈中带来了数千项性能改进 。默认情况下会启用一种名为动态配置文件引导优化 (PGO) 的新代码生成器,它可以根据实际使用情况优化代码,并且可以将应用程序的性能提高高达 20%。现在支持的 AVX-512 指令集能够对 512 位数据向量执…

计算机毕业设计选题推荐-掌心办公微信小程序/安卓APP-项目实战

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

竞赛选题 疫情数据分析与3D可视化 - python 大数据

文章目录 0 前言1 课题背景2 实现效果3 设计原理4 部分代码5 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 大数据全国疫情数据分析与3D可视化 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐&#xff0…

websocket详解

一、什么是Websocket WebSocket 是一种在单个 TCP 连接上进行 全双工 通信的协议,它可以让客户端和服务器之间进行实时的双向通信。 WebSocket 使用一个长连接,在客户端和服务器之间保持持久的连接,从而可以实时地发送和接收数据。 在 Web…

Alibaba Nacos注册中心实战

为什么需要注册中心 思考:网络请求,如果服务提供者发生变动,服务调用者如何感知服务提供者的ip和端口变化? // 微服务之间通过RestTemplate调用,ip:port写死,如果ip或者port变化呢? String ur…

DRF纯净版项目搭建和配置

一、安装模块和项目 1.安装模块 pip install django pip install djangorestframework pip install django-redis # 按需安装 2.开启项目和api (venv) PS D:\pythonProject\env_api> django-admin startproject drf . (venv) PS D:\pythonProject\env_api> python ma…

elementui 实现树形控件单选

实现&#xff1a; <!--author: itmacydesc: 树节点单选 --> <template><div class"about"><el-tree :data"data"ref"tree":props"defaultProps"node-key"id"show-checkboxcheck-strictlycheck-change…

第七部分:Maven(项目管理工具)

目录 Maven简介 7.1&#xff1a;为什么学习Maven&#xff1f; 7.1.1、Maven是一个依赖管理工具 7.1.2&#xff1a;Maven是一个构建工具 7.1.3&#xff1a;结论 7.2&#xff1a;Maven介绍 7.3&#xff1a;Maven的优点 Maven安装和配置 7.4&#xff1a;安装教程及环境配置 …

记一次服务器配置文件获取OSS

一、漏洞原因 由于网站登录口未做双因子校验,导致可以通过暴力破解获取管理员账号,成功进入系统;未对上传的格式和内容进行校验,可以任意文件上传获取服务器权限;由于服务器上配置信息,可以进一步获取数据库权限和OSS管理权限。二、漏洞成果 弱口令获取网站的管理员权限通…

科研学习|研究方法——python T检验

一、单样本T检验 目的&#xff1a;检验单样本的均值是否和已知总体的均值相等前提条件&#xff1a; &#xff08;1&#xff09;总体方差未知&#xff0c;否则就可以利用 Z ZZ 检验&#xff08;也叫 U UU 检验&#xff0c;就是正态检验&#xff09;&#xff1b; &#xff08;2&a…

vscode 配置 lua

https://luabinaries.sourceforge.net/ 官网链接 主要分为4个步骤 下载压缩包&#xff0c;然后解压配置系统环境变量配置vscode的插件测试 这里你可以选择用户变量或者系统环境变量都行。 不推荐空格的原因是 再配置插件的时候含空格的路径 会出错&#xff0c;原因是空格会断…

纯CSS自定义滚动条样式

.my-carousel{height: 474px;overflow-y: auto; } /*正常情况下滑块的样式*/ .my-carousel::-webkit-scrollbar {width: 5px; } .my-carousel::-webkit-scrollbar-thumb {border-radius: 8px;background-color: #ccc; } .my-carousel::-webkit-scrollbar-track {border-radius:…

zabbix告警 邮件告警 钉钉告警

邮件告警添加主机组添加模板添加主机在模板中添加监控项在模板中添加触发器添加动作&#xff0c;远程执行命令给用户绑定告警媒介类型 钉钉告警安装python依赖模块python-requests配置钉钉告警配置脚本zabbix_ding.conf在目录/var/log/zabbix中创建钉钉告警日志文件zabbix_ding…