深度学习实战模拟——softmax回归(图像识别并分类)

目录

1、数据集:

2、完整代码


1、数据集:

1.1 Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

1.2 Fashion‐MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据 集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集 不会用于训练,只用于评估模型性能。

以下函数用于在数字标签索引及其文本名称之间进行转换。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

2、完整代码

import torch
import torchvision
import pylab
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt
from d2l import torch as d2l
import timebatch_size = 256
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
num_epochs = 5class Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def cross_entropy(y_hat, y):return -torch.log(y_hat[range(len(y_hat)), y])def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp/partitiondef net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)def get_dataloader_workers():"""使用一个进程来读取的数据"""return 1def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""#共10个类别text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""画一系列图片"""figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)for i, (img, label) in enumerate(zip(imgs, titles)):xloc, yloc = i//num_cols, i % num_colsif torch.is_tensor(img):# 图片张量axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())else:# PIL图片axes[xloc, yloc].imshow(img)# 设置标题并取消横纵坐标上的刻度axes[xloc, yloc].set_title(label)plt.xticks([], ())axes[xloc, yloc].set_axis_off()pylab.show()def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = transforms.ToTensor()if resize:trans.insert(0, transforms.Resize(resize))mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))def evaluate_accuracy(net, data_iter):"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]def updater(batch_size):lr = 0.1return d2l.sgd([W, b], lr, batch_size)def train_epoch_ch3(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()metric = Accumulator(3)for X, y in train_iter:y_hat = net(X)lo = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()lo.backward()updater.step()metric.add(float(lo)*len(y), accuracy(y_hat, y), y.size().numel())else:lo.sum().backward()updater(X.shape[0])metric.add(float(lo.sum()), accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]class Animator:  #@save"""绘制数据"""def __init__(self, legend=None):self.legend = legendself.X = [[], [], []]self.Y = [[], [], []]def add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nfor i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)def show(self):plt.plot(self.X[0], self.Y[0], 'r--')plt.plot(self.X[1], self.Y[1], 'g--')plt.plot(self.X[2], self.Y[2], 'b--')plt.legend(self.legend)plt.xlabel('epoch')plt.ylabel('value')plt.title('Visual')plt.show()def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型"""animator = Animator(legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)train_loss, train_acc = train_metricstest_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))print(f'epoch: {epoch+1},train_loss:{train_loss:.4f}, train_acc:{train_acc:.4f}, test_acc:{test_acc:.4f}')animator.show()def predict_ch3(net, test_iter, n=12):"""预测标签"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]show_images(X[0:n].reshape((n, 28, 28)), 2, int(n/2), titles=titles[0:n])if __name__ == '__main__':train_iter, test_iter = load_data_fashion_mnist(batch_size)train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)predict_ch3(net, test_iter)

分类效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/278689.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文笔记:Llama 2: Open Foundation and Fine-Tuned Chat Models

导语 Llama 2 是之前广受欢迎的开源大型语言模型 LLaMA 的新版本,该模型已公开发布,可用于研究和商业用途。本文记录了阅读该论文的一些关键笔记。 链接:https://arxiv.org/abs/2307.09288 1 引言 大型语言模型(LLMs&#xff…

深入探索C与C++的混合编程

实现混合编程的技术细节 混合使用C和C可能由多种原因驱动。一方面,现有的大量优秀C语言库为特定任务提供了高效的解决方案,将这些库直接应用于C项目中可以节省大量的开发时间和成本。另一方面,C的高级特性如类、模板和异常处理等,…

ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案

ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案 文章目录 ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案摘要📑引言 🌟正文📚一、ONLYOFFICE文档概述 📊二、ONLYOFFI…

2024同城招标投标微网站微信小程序版本源码

2024同城招标投标微网站&微信小程序版本源码 功能介绍: 同城招投标这套程序主要是为了解决招投标问题 用户缴纳保证金发布起来招标,然后商家进行认证成功后可以对招标发起投标,投标过程也需要缴纳保证金,单招标结束或者下架保…

第五篇:数字视频广告格式概述 - IAB视频广告标准《数字视频和有线电视广告格式指南》

第五篇:第五篇:数字视频广告格式概述 - IAB视频广告标准《数字视频和有线电视广告格式指南 --- 我为什么要翻译介绍美国人工智能科技公司IAB系列技术标准(2) ​​​​​​​翻译计划 第一篇序言第二篇简介和目录第三篇概述- IA…

完成系统支持Github三方登录

文章目录 1、需求2、在对接系统中完成客户端注册3、创建客户端应用4、CommonOAuth2Provider SpringSecurity OAuth2.0文档: https://docs.spring.io/spring-security/reference/servlet/oauth2/index.html 1、需求 对接Github,在自己系统实现支持Githu…

Mit6.s081 前置开发环境: 虚拟机ubuntu + ssh + vscode

虚拟机 ssh vscode 前置条件 下载VMware Download VMware Workstation ProUbuntuUbuntu系统下载 | Ubuntu vscode Visual Studio Code - Code Editing. Redefined Ubuntu版本:20.04 Ubuntu基本操作 ubuntu 安装 ssh 服务 sudo apt-get install openssh-serv…

TT-100K数据集,YOLO格式

TT-100K数据集YOLO格式,分为train、val和test,其中train中共有6793张图片,val中共有1949张图片,test中共有996张图片。数据集只保留包含图片数超过100的类别。共计46类。

基于支持向量机(svm)的人脸识别

注意:本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 ([www.aideeplearning.cn]) 数据集加载与可视化 from sklearn.datasets import fetch_lfw_people faces fetch_lfw_people(min_faces_per_person60) # Check out sample…

MS12_020 漏洞利用与安全加固

文章目录 环境说明1 MS12_020 简介2 MS12_020 复现过程3 MS12_020 安全加固 环境说明 渗透机操作系统:kali-linux-2024.1-installer-amd64漏洞复现操作系统: cn_win_srv_2003_r2_enterprise_with_sp2_vl_cd1_X13-46432 1 MS12_020 简介 MS12_020 漏洞全称为&#x…

【视频异常检测】Delving into CLIP latent space for Video Anomaly Recognition 论文阅读

Delving into CLIP latent space for Video Anomaly Recognition 论文阅读 ABSTRACT1. Introduction2. Related Works3. Proposed approach3.1. Selector model3.2. Temporal Model3.3. Predictions Aggregation3.4. Training 4. Experiments4.1. Experiment Setup4.2. Evaluat…

linux安装WordPress问题汇总,老是提示无法连接到FTP服务器解决方案

最近在做一些建站相关的事情,遇到一些大大小小的问题都整理在这里 1.数据库密码和端口,千万要复杂一点,不要使用默认的3306端口 2.wordpress算是一个php应用吧,所以安装流程一般是 apache http/nginx——php——mysql——ftp &…

ASP.NET使用Applicaiton状态来存储和检索数据

目录 背景: 实例: Button1事件: Button2事件: 效果展示: 总结: 背景: 在ASP.NET Web From应用程序中,Appliciton是一个内置对象,用于在整个Web应用程序范围内存储和检索数据,这意味着存储在Application对象中的数据可以被应用程序中的…

算法笔记p154最大公约数和最小公倍数

目录 最大公约数辗转相除法证明例子代码实现 最小公倍数代码实现 最大公约数 正整数a与b的最大公约数是指a与b的所有公约数中最大的那个公约数,一般用gcd(a, b)表示a和b的最大公约数。 辗转相除法 设a、b均为正整数,则gcd(a, b) gcd(b, a % b)。即被…

huawei 华为交换机 配置手工模式链路聚合示例

组网需求 如 图 3-21 所示, SwitchA 和 SwitchB 通过以太链路分别都连接 VLAN10 和 VLAN20 的网络,SwitchA 和 SwitchB 之间有较大的数据流量。 用户希望SwitchA 和 SwitchB 之间能够提供较大的链路带宽来使相同 VLAN 间互相通信。 同时用户也希望能够提…

数据结构 二叉树 力扣例题AC——代码以及思路记录

LCR 175. 计算二叉树的深 某公司架构以二叉树形式记录,请返回该公司的层级数。 AC int calculateDepth(struct TreeNode* root) {if (root NULL){return 0;}else{return 1 fmax(calculateDepth(root->left), calculateDepth(root->right));} } 代码思路 …

华为配置终端定位基本实验配置

配置终端定位基本示例 组网图形 图1 配置终端定位基本服务示例 组网需求数据准备配置思路配置注意事项操作步骤配置文件 组网需求 如图1所示,某公司网络中,中心AP直接与RU连接。 管理员希望通过RU收集Wi-Fi终端信息,并提供给定位服务器进行定…

基于单片机的家庭烟雾报警系统

摘要:本文主要针对家庭等小型应用场所, 提出基于以单片机CC2530 作为控制器的智能烟雾报警系统,通过MQ-2 气体传感器来检测烟雾浓度,在单片机的A/D模块转化后,并配合蜂鸣元器件实现声音报警功能。 【关键词】烟雾报警 单片机 烟雾传感器 由于科技的发展以及各类家电走入…

【博客7.4】缤果Qt5_TWS串口调试助手V2.0 (高级篇)

超级好用的Qt5_TWS耳机串口调试助手 开发工具: qt-opensource-windows-x86-5.14.2 (编程语言C) 目录 前言 一、软件概要: 二、软件界面: 1.App演示 三、获取 >> 源码以及Git记录: 总结 前言 串口调试助手支持常用的50bps - 10M…

机器人在果园内行巡检仿真

文章目录 创建工作空间仿真果园场景搭建小车模型搭建将机器人放在仿真世界中创建工作空间 mkdir -p ~/catkin_ws/src cd ~/catkin_ws仿真果园场景搭建 cd ~/catkin_ws/src git clone https://gitcode.com/clearpathrobotics/cpr_gazebo.git小车模型搭建 DiffBot是一种具有两个…