Pytorch从零开始实战02

Pytorch从零开始实战——彩色图像识别

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——彩色图像识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 数据可视化

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。
老规矩,还是导入常用包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本文使用CIFAR-10数据集,CIFAR-10数据集包含60000张 32x32的彩色图片,共分为10种类别,每种类别6000张。其中训练集包含50000张图片,测试机包含10000张图片。
我们使用torchvision.datasets下载数据集。

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True)

随机展示五张图片,下面这个函数是从别人那里抄的,挺好用的。

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_ds)

展示出现有点模糊
在这里插入图片描述
还是之前的套路,使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=True)

查看dataloader的一个批次

imgs, label = next(iter(train_dl))
imgs.shape # 3通道 32 * 32

模型选择

本文还是使用自定义简单的卷积神经网络,下面是两个函数原型,卷积函数默认步幅是1,padding是0,最大池化函数默认步长跟最大池化窗口大小一致。

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

最后输出的图像尺寸是2 * 2,128通道,将它们拉平,线性层需要填入512来接收。
在这里插入图片描述

num_classes = 10
class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(64, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(64, 128, kernel_size=3)self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.pool3(F.relu(self.conv3(x)))x = torch.flatten(x, start_dim=1) # 从第一个维度开始拉平x = F.relu(self.fc1(x))x = self.fc2(x)return x

使用summary查看模型,这次用的是torchsummary,展示的也不错,可以看见每层结构。

from torchsummary import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model, input_size=(3, 32, 32))

在这里插入图片描述

模型训练

定义损失函数、学习率、优化算法。

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

跟以前的文章一样,编写训练函数。

def train(dataloader, model, loss_fn, opt):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)opt.zero_grad()loss.backward()opt.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

测试函数。

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

开始训练,这次epochs设置为10轮。

import time
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []T1 = time.time()for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")
T2 = time.time()
print('程序运行时间:%s毫秒' % ((T2 - T1)*1000))

运行结果:在训练集上准确率达到59.7%,测试集达到59.9%,效果不是很好,说明还是欠拟合,需要未来改造模型或者增加训练轮数。
在这里插入图片描述

数据可视化

使用matplotlib进行训练数据、测试数据的可视化

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/133675.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Day 01 web前端基础知识

首先我们要了解什么事前端? 先简单用文字介绍一下: 一、入门知识 Web前端是指网站或应用程序的用户界面部分。它包括HTML、CSS、JavaScript等语言和技术,用于创建用户可浏览和交互的网页。Web前端的特点在于其交互性和动态性,可…

非常详细的trunk-based分支管理流程配置及使用

非常详细的trunk-based分支管理流程配置及使用。 目前业界主流的版本管理流程是Gitflow 和 trunk-based。 Gitflow流行的比较早。但是目前的流行度要低于 trunk-based模式工作流。trunk-based模式被誉为是现代化持续集成的最佳实践。 他俩的核心区别是,Gitflow是一个更严格…

如何优化你的Vue.js应用以获得最佳性能

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

vue3+ts+uniapp小程序封装获取授权hook函数

vue3tsuniapp小程序封装获取授权hook函数 小程序授权的时候,如果点击拒绝授权,然后就再也不会出现授权了,除非用户手动去右上角…设置打开 通过uni官方api自己封装一个全局的提示: uni.getSetting :http://uniapp.dcloud.io/api/other/settin…

前端综合练手小项目

导读 本篇文章主要以小项目的方式展开,其中给出的代码中均包含详细地注释,大家可以参照理解。下面4个小项目中均包含有 HTML、CSS、JavaScript 等相关知识,可以拿来练手,系统提升一下自己的前端开发能力。 废话少说,…

《Envoy 代理:云原生时代的流量管理》

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🐅🐾猫头虎建议程序员必备技术栈一览表📖: 🛠️ 全栈技术 Full Stack: &#x1f4da…

SCRUM产品负责人(CSPO)认证培训课程

课程简介 Scrum是目前运用最为广泛的敏捷开发方法,是一个轻量级的项目管理和产品研发管理框架。产品负责人是Scrum的三个角色之一,产品负责人在Scrum产品开发当中扮演舵手的角色,他决定产品的愿景、路线图以及投资回报,他需要回答…

微服务如何改变软件开发:实战经验与最佳实践分享

文章目录 什么是微服务?微服务实战经验1. 定义明确的服务边界2. 使用API网关3. 自动化部署和持续集成4. 监控和日志记录 微服务最佳实践1. 文档和通信2. 弹性设计3. 安全性4. 版本控制5. 监控和警报 微服务的未来 🎉欢迎来到架构设计专栏~微服务如何改变…

C++版本的OpenCV实现二维图像的卷积定理(通过傅里叶变换实现二维图像的卷积过程,附代码!!)

C版本的OpenCV库实现二维图像的卷积定理过程详解 前言一、卷积定理简单介绍二、不同卷积过程对应的傅里叶变换过程1、“Same”卷积2、“Full”卷积3、“Valid”卷积 三、基于OpenCV库实现的二维图像卷积定理四、基于FFTW库实现的二维图像卷积定理五、总结与讨论 前言 工作中用…

计算机视觉--距离变换算法的实战应用

前言: Hello大家好,我是Dream。 计算机视觉CV是人工智能一个非常重要的领域。 在本次的距离变换任务中,我们将使用D4距离度量方法来对图像进行处理。通过这次实验,我们可以更好地理解距离度量在计算机视觉中的应用。希望大家对计算…

QML android 采集手机传感器数据 并通过udp 发送

利用 qt 开发 安卓 app &#xff0c;采集手机传感器数据 并通过udp 发送 #ifndef UDPLINK_H #define UDPLINK_H#include <QObject> #include <QUdpSocket> #include <QHostAddress>class UdpLink : public QObject {Q_OBJECT public:explicit UdpLink(QObjec…

C#,数值计算——Hashfn1的计算方法与源程序

1 文本格式 using System; using System.Collections; using System.Collections.Generic; namespace Legalsoft.Truffer { public class Hashfn1 { private Ranhash hasher { get; set; } new Ranhash(); private int n { get; set; } public Hash…

【C++11保姆级教程】列表初始化(Literal types)和委派构造函数(delegating))

文章目录 前言一、列表初始化 (List Initialization)1.1数组初始化1.2结构体初始化1.3容器初始化1.4列表初始化的优势 二、委派构造函数 (Delegating Constructors)2.1委派构造函数是什么&#xff1f;2.2委派构造函数示例代码2.3调用顺序2.3委派构造函数优势 总结 前言 C11引入…

SpringCLoud——Nacos配置中心

Nacos实现配置管理 统一配置管理 配置更新热更新 统一配置的创建是在UI界面中完成的&#xff1a; 首先我们点击【配置管理】然后点击【配置列表】&#xff1a; 然后我们就看到了配置管理界面&#xff0c;但是此时这里是空的&#xff0c;我们可以创建一些配置文件&#xff1a…

Springboot 实践(17)spring boot整合Nacos配置中心

前文我们讲解了Nacos服务端的下载安装&#xff0c;本文我们降价spring boot整合nacos&#xff0c;实现Nacos服务器配置参数的访问。 一、启动Nacos服务&#xff0c;创建三个配置文件&#xff0c;如下所示 Springboot-Nacos-Client-dev.yaml文件配置参数 Springboot-Nacos-Clie…

微信小程序自动化发布

参考:https://developers.weixin.qq.com/miniprogram/dev/devtools/ci.html 参考:https://www.npmjs.com/package/miniprogram-ci npm install miniprogram-ci -S上传文件 xx.js const isNodeJs typeof process ! undefined && process.release ! null &&…

智能文本纠错API的应用与工作原理解析

引言 在数字时代&#xff0c;文本撰写和传播变得日益重要&#xff0c;无论是在学校里写论文、在职场中发送邮件&#xff0c;还是在社交媒体上发表观点。然而&#xff0c;文字错误、标点符号错误、语法问题和不当的表达常常会削弱文本的质量&#xff0c;降低信息传达的效果。为…

数据结构-leetcode-数组形式的整数加法

解题图解&#xff1a; 下面是代码&#xff1a; /*** Note: The returned array must be malloced, assume caller calls free().*/ int* addToArrayForm(int* num, int numSize, int k, int* returnSize){int k_tem k;int klen0;while(k_tem){//看看k有几位k_tem /10;klen;}i…

第三节:在WORD为应用主窗口下关闭EXCEL的操作(2)

【分享成果&#xff0c;随喜正能量】凡事好坏&#xff0c;多半自作自受&#xff0c;既不是神为我们安排&#xff0c;也不是天意偏私袒护。业力之前&#xff0c;机会均等&#xff0c;毫无特殊例外&#xff1b;好坏与否&#xff0c;端看自己是否能应机把握&#xff0c;随缘得度。…

免费和开源的机器翻译软件LibreTranslate

什么是 LibreTranslate &#xff1f; LibreTranslate 免费开源机器翻译 API&#xff0c;完全自托管。与其他 API 不同&#xff0c;它不依赖于 Google 或 Azure 等专有提供商来执行翻译。它的翻译引擎由开源 Argos Translate 库提供支持。 这个软件在 2022 年 3 月的时候折腾过&…