卷积神经网络实现天气图像分类 - P3

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P3周:彩色图片识别:天气识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
    • 数据准备
    • 模型设计
    • 模型训练
    • 结果展示
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

首先是包引用

import torch # pytorch主包
import torch.nn as nn # 模型相关的包,创建一个别名少打点字
import torch.optim as optim # 优化器包,创建一个别名
import torch.nn.functional as F # 可以直接调用的函数,一般用来调用里面在的激活函数from torch.utils.data import DataLoader, random_split # 数据迭代包装器,数据集切分
from torchvision import datasets, transforms # 图像类数据集和图像转换操作函数import matplotlib.pyplot as plt # 图表库
from torchinfo import summary # 打印模型结构

查询当前环境的GPU是否可用

print(torch.cuda.is_available())

GPU可用情况
创建一个全局的设备对象,用于使各类数据处于相同的设备中

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 当GPU不可用时,使用CPU# 如果是Mac系统可以多增加一个if条件,启用mps
if torch.backends.mps.is_available():device = torch.device('mps')

数据准备

这次的天气图像是由K同学提供的,我提前下载下来放在了当前目录下的data文件夹中
加载文件夹中的图像数据集,要求文件夹按照不同的分类并列存储,一个简要的文件树为

datacloudyrainshinesunrise

使用torchvisio.datasets中的方法加载自定义图像数据集,可以免除一些文章中推荐的自己创建Dataset,个人感觉十分方便,而且这种文件的存储结构也兼容keras框架。

首先我们使用原生的PythonAPI来遍历一下文件夹,收集一下分类信息

import pathlibdata_lib = pathlib.Path('data')
class_names = [f.parts[-1] for f in data_lib.glob('*')] # 将data下级文件夹作为分类名
print(class_names)

打印分类信息
在所有的图片中随机选择几个文件打印一下信息。

import numpy as np
from PIL import Image
import randomimage_list = list(data_lib.glob('*/*'))
for _ in range(10):print(np.array(Image.open(random.choice(image_list))).shape)

打印图像信息
通过打印图像信息,发现图像的大小并不一致,需要在创建数据集时对图像进行缩放到统一的大小。

transform = transforms.Compose([transforms.Resize([224, 224]), # 将图像都缩放到224x224transforms.ToTensor(), # 将图像转换成pytorch tensor对象
]) # 定义一个全局的transform, 用于对齐训练验证以及测试数据

接下来就可以正式从文件夹中加载数据集了

dataset = datasets.ImageFolder('data', transform=tranform)

现在把整文件夹下的所有文件加载为了一个数据集,需要根据一定的比例划分为训练和验证集,方便模型的评估

train_size = int(len(dataset) *0.8) # 80% 训练集 20% 验证集
eval_size = len(dataset) - train_sizetrain_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

创建完数据集,打印一下数据集中的图像

plt.figure(figsize=(20, 4))
for i in range(20):image, label = train_dataset[i]plt.subplot(2, 10, i+1)plt.imshow(image.permute(1,2,0)) # pytorch的tensor格式为N,C,H,W,在imshow展示需要将格式变成H,W,C格式,使用permute切换一下plt.axis('off')plt.title(class_names[label])

预览数据集
最后用DataLoader包装一下数据集,方便遍历

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_loader, batch_size=batch_size)

模型设计

使用一个带有BatchNorm的卷积神经网络来处理分类问题

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 12, kernel_size=5, strides=1)self.conv2 = nn.Conv2d(12, 12, kernel_size=5, strides=1)self.conv3 = nn.Conv2d(12, 24, kernel_size=5, strides=1)self.conv4 = nn.Conv2d(24, 24, kernel_size=5, strides=1)self.maxpool = nn.MaxPool2d(2)self.bn1 = nn.BatchNorm2d(12)self.bn2 = nn.BatchNorm2d(12)self.bn3 = nn.BatchNorm2d(24)self.bn4 = nn.BatchNorm2d(24)# 224 [-> 220 -> 216 -> 108] [-> 104 -> 100 -> 50]self.fc1 = nn.Linear(50*50*24, num_classes)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.maxpool(x)x = F.relu(self.bn3(self.conv3(x)))x = F.relu(self.bn4(self.conv4(x)))x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc1(x)return xmodel = Network(len(class_names)).to(device) # 别忘了把定义的模型拉入共享中
summary(model, input_size=(32, 3, 224, 224))

模型结构

模型训练

首先定义一下每个epoch内训练和评估的逻辑

def train(train_loader, model, loss_fn, optimizer):train_size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= train_sizereturn train_loss, train_accdef eval(eval_loader, model, loss_fn):eval_size = len(eval_loader.dataset)num_batches = len(eval_loader)eval_loss, eval_acc = 0, 0for x, y in eval_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)eval_loss += loss.item()eval_acc += (preds.argmax(1) == y).type(torch.float).sum().item()eval_loss /= num_batcheseval_acc /= eval_sizereturn eval_loss, eval_acc

然后编写代码进行训练

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 10train_loss, train_acc = [], []
eval_loss, eval_acc =[], []
for epoch in range(epochs):model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval()model.no_grad():epoch_eval_loss, epoch_eval_acc = test(eval_loader, model, loss_fn)

结果展示

训练结果
基于训练和测试数据展示结果

range_epochs = range(len(train_loss))
plt.figure(figsize=(12, 4))
plt.subplot(1,2,1)
plt.plot(range_epochs, train_loss, label='train loss')
plt.plot(range_epochs, eval_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.subplot(1,2,2)
plt.plot(range_epochs, train_acc, label='train accuracy')
plt.plot(range_epochs, eval_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练历史图表

总结与心得体会

通过对训练过程的观察,训练过程中的数据波动很大,并且验证集上的最好正确率只有82%。
目前行业都流行小卷积核,于是我把卷积核调整为了3x3,并且每次卷积后我都进行池化操作,直到通道数为64,由于天气识别时,背景信息也比较重要,高层的卷积操作后我使用平均池化代替低层使用的最大池化,加大了全连接层的Dropout惩罚比重,用来抑制过拟合问题。最后的模型如下:

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3)self.conv2 = nn.Conv2d(16, 32, kernel_size=3)self.conv3 = nn.Conv2d(32, 64, kernel_size=3)self.conv4 = nn.Conv2d(64, 64, kernel_size=3)self.bn1 = nn.BatchNorm2d(16)self.bn2 = nn.BatchNorm2d(32)self.bn3 = nn.BatchNorm2d(64)self.bn4 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(2)self.avgpool = nn.AvgPool2d(2)self.dropout = nn.Dropout(0.5)# 224 -> 222-> 111 -> 109 -> 54 -> 52 -> 50 -> 25self.fc1 = nn.Linear(25*25*64, 128)self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = self.maxpool(x)x = F.relu(self.bn2(self.conv2(x)))x = self.avgpool(x)x = F.relu(self.bn3(self.conv3(x)))x = F.relu(self.bn4(self.conv4(x)))x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.dropout(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

然后增大训练的epochs为30,学习率降低为1e-4

optimizer = optim.Adam(model.parameters(), lr=1e-4)
epochs = 30

训练结果如下
训练过程
可以看到,验证集上的正确率最高达到了95%以上
训练过程图示

在数据集中随机选取一个图像进行预测展示

image_path = random.choice(image_list)
image_input = transform(Image.open(image_path))
image_input = image_input.unsqueeze(0).to(device)
model.eval()
pred = model(image_input)plt.figure(figsize=(5, 5))
plt.imshow(image_input.cpu().squeeze(0).permute(1,2,0))
plt.axis('off')
plt.title(class_names[pred.argmax(1)])

结果如下
预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/100572.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot + MyBatisPlus中乐观锁的实现 (精简demo)

乐观锁加注解Version后不需要手动进行加1操作。乐观锁是一种用于解决并发冲突的机制,在数据库中用于保护数据的一致性。Version注解是MyBatisPlus框架中的乐观锁注解,它会在更新数据时自动检查版本号是否一致,如果一致则进行更新操作&#xf…

如何选择 DCDC 降压型开关电源的电感

选择合适的电感是开关电源电路设计的关键之一。本文将帮助您理解电感值和电路性能之间的关系。 降压转换器(buck converter),也称为降压转换器(step-down converter),是一种开关模式稳压器(voltage regulator&#xf…

220V转5V芯片三脚芯片-AH8652

220V转5V芯片三脚芯片是一种非常常见的电源管理芯片,它通常被用于将高压交流输入转为稳定的直流5V输出。芯片型号AH8652是一款支持交流40V-265V输入范围的芯片,采用了SOT23-3三脚封装。该芯片内部集成了650V高压MOS管,能够稳定地将输入电压转…

R语言APSIM模型进阶应用与参数优化、批量模拟实践技术

随着数字农业和智慧农业的发展,基于过程的农业生产系统模型在模拟作物对气候变化的响应与适应、农田管理优化、作物品种和株型筛选、农田固碳和温室气体排放等领域扮演着越来越重要的作用。APSIM (Agricultural Production Systems sIMulator)模型是世界知名的作物生…

vite 项目搭建

1. 创建 vite 项目 npm create vite@latest 2. 安装sass/less ( 一般我使用sass ) cnpm add -D sasscnpm add -D less 3. 自动导入 两个插件 使用之后,不用导入vue中hook reactive ref cnpm install -D unplugin-vue-components unplugin-auto-import 在 vite.config.…

记录一个用C#实现的windows计时执行任务的服务

记录一个用C#实现的windows计时执行任务的服务 这个服务实现的功能是每天下午六点统计一次指定路径的文件夹大小 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Diagnostics; using System.IO; using Syst…

【实战】十一、看板页面及任务组页面开发(三) —— React17+React Hook+TS4 最佳实践,仿 Jira 企业级项目(二十五)

文章目录 一、项目起航:项目初始化与配置二、React 与 Hook 应用:实现项目列表三、TS 应用:JS神助攻 - 强类型四、JWT、用户认证与异步请求五、CSS 其实很简单 - 用 CSS-in-JS 添加样式六、用户体验优化 - 加载中和错误状态处理七、Hook&…

基于Spring Boot的智慧团支部建设网站的设计与实现(Java+spring boot+MySQL)

获取源码或者论文请私信博主 演示视频: 基于Spring Boot的智慧团支部建设网站的设计与实现(Javaspring bootMySQL) 使用技术: 前端:html css javascript jQuery ajax thymeleaf 微信小程序 后端:Java sp…

stm32的命令规则

stm32型号的说明:以STM32F103RBT6这个型号的芯片为例,该型号的组成为7个部分,其命名规则如下:

生成式AI和大语言模型 Generative AI LLMs

在“使用大型语言模型(LLMs)的生成性AI”中,您将学习生成性AI的基本工作原理,以及如何在实际应用中部署它。 通过参加这门课程,您将学会: 深入了解生成性AI,描述基于LLM的典型生成性AI生命周期中的关键步骤&#xff…

OpenAI首席执行官谈马斯克退出的影响;关于 Stable Diffusion 扩展功能的内容总结

🦉 AI新闻 🚀 OpenAI首席执行官谈马斯克退出的影响 摘要:OpenAI首席执行官萨姆奥特曼在接受采访时表示,埃隆马斯克退出OpenAI给公司带来了巨大震惊,导致奥特曼不得不重新调整生活和时间安排以确保足够资金。OpenAI是…

hive中get_json_object函数不支持解析json中文key

问题 今天在 Hive 中 get_json_object 函数解析 json 串的时候,发现函数不支持解析 json 中文 key。 例如: select get_json_object({ "姓名":"张三" , "年龄":"18" }, $.姓名);我们希望的结果是得到姓名对应…

七夕节日表白:七大网页风格与其适用人群

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

SVM详解

公式太多了,就用图片用笔记呈现,SVM虽然算法本质一目了然,但其中用到的数学推导还是挺多的,其中拉格朗日约束关于α>0这块证明我看了很长时间,到底是因为悟性不够。对偶问题也是,用了一个简单的例子才明…

文心一言 VS 讯飞星火 VS chatgpt (79)-- 算法导论7.4 4题

四、请用go语言,证明:RANDOMIZED-QUICKSORT 期望运行时间是 Ω(nlgn)。 文心一言: 首先,为了证明RANDOMIZED-QUICKSORT的期望运行时间是Ω(nlg n),我们需要证明在最坏的情况下,该算法的运行时间是O(nlg n)。然后&…

1.0的星火2.0必将燎原——图文声影PPT全测试

一、前言 大家好,勇哥又来分享AI模型了,前几天讯飞发布的星火大模型2.0迅速的进入了我们圈子里,为了有更多更好的模型分享给大家,分享星火大模型2.0是必须做的,我做一个传递着,希望大家也星火相传啊。 我…

解决执行 spark.sql 时版本不兼容的一种方式

场景描述 hive 数据表的导入导出功能部分代码如下所示,使用 assemble 将 Java 程序和 spark 相关依赖一起打成 jar 包,最后 spark-submit 提交 jar 到集群执行。 public class SparkHiveApplication {public static void main(String[] args){long sta…

操作系统——shell编程

文章目录 shell入门什么是 Shell?Shell 编程的 Hello World Shell 变量Shell 编程中的变量介绍Shell 字符串入门Shell 字符串常见操作Shell 数组 Shell 基本运算符算数运算符关系运算符逻辑运算符布尔运算符字符串运算符文件相关运算符 shell流程控制if 条件语句for…

市面上那里有稳定L2股票行情数据接口?

随着市场的发展和技术的进步,level2股票行情数据接口已经成为股票交易软件的标准配置之一。虽然这些券商软件的功能在很大程度上相似,但它们仍然有自己的特点和优势。 例如:通过股票交易所以其专业的研究报告和丰富的信息服务而受到广泛关注&…

Shell编程基础02

0目录 1.case语法 2.grep 3.sed 4.awk 5.linux安装mysql 1.case语法 创建一个txt文档 执行 查询用户名 case 用法 写一个计算器脚本 加入函数 补充查看进程命名 2.find grep命令 Find 查询当前目录下 以sh结尾的文件 Grep 查询义开头的 或者加入正则表达…