【深度学习】四种天气分类 模版函数 从0到1手敲版本

引入该引入的库

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import torch.optim as optim
%matplotlib inline
import os
import shutil
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

注意:os.environ[“KMP_DUPLICATE_LIB_OK”]=“TRUE” 必须要引入否则用plt出错

数据集整理

img_dir = r"F:\播放器\1、pytorch全套入门与实战项目\课程资料\参考代码和部分数据集\参考代码\参考代码\29-42节参考代码和数据集\四种天气图片数据集\dataset2"
base_dir = r"./dataset/4weather"img_list = glob.glob(img_dir+"/*.*")
test_dir = "test"
train_dir = "train"
species = ["cloudy","rain","shine","sunrise"]
for idx,img_path in enumerate(img_list):_,img_name = os.path.split(img_path)if idx%5==0:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(test_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)else:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(train_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)shutil.copy(img_path,dst_path)

生成测试和训练的文件夹,
目录结构如下:
在这里插入图片描述
rain 下面就是图片了
在这里插入图片描述

构建ds和dl

from torchvision import transforms
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
train_ds=torchvision.datasets.ImageFolder(train_dir,transform)
test_ds = torchvision.datasets.ImageFolder(train_dir,transform)

在这里插入图片描述
在这里插入图片描述
一张图片效果,这是rain图片 这里需要转换维度,把channel放到最后。同时把数据拉到0-1之间,原本std 和mean 【0.5,0,5】数据在-0.5~0.5之间
在这里插入图片描述
类的映射
在这里插入图片描述

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):img = (img.permute(1, 2, 0).numpy() + 1)/2plt.subplot(2, 3, i+1)plt.title(id_to_class.get(label.item()))plt.imshow(img)

这个方法要学会
在这里插入图片描述

定义网络

class Net(nn.Module):def __init__(self) -> None:super().__init__()self.conv1 = nn.Conv2d(3,16,3)self.conv2 = nn.Conv2d(16,32,3)self.conv3 = nn.Conv2d(32,64,3)self.pool = nn.MaxPool2d(2,2)self.dropout = nn.Dropout(0.3)self.fc1 = nn.Linear(64*10*10,1024)self.fc2 = nn.Linear(1024,4)def forward(self,x):x = F.relu(self.conv1(x))x = self.pool(x)x = F.relu(self.conv2(x))x = self.pool(x)x = F.relu(self.conv3(x))x = self.pool(x)x = self.dropout(x)# print(x.size()) 这里是可以计算出来的,需要掌握计算方法x = x.view(-1,64*10*10)x = F.relu(self.fc1(x))x = self.dropout(x)return self.fc2(x)
model = Net()        
preds = model(imgs)
preds.shape, preds

在这里插入图片描述
定义损失函数和优化函数:

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),lr=0.001)

定义网络

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0for x, y in trainloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)optim.zero_grad()loss.backward()optim.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim=1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()epoch_loss = running_loss / len(trainloader.dataset)epoch_acc = correct / totaltest_correct = 0test_total = 0test_running_loss = 0 with torch.no_grad():for x, y in testloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)y_pred = torch.argmax(y_pred, dim=1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_loss = test_running_loss / len(testloader.dataset)epoch_test_acc = test_correct / test_totalprint('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy:', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy:', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

训练:

epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,model,train_dl,test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)
epoch:  0 loss:  0.043 accuracy: 0.714 test_loss:  0.029 test_accuracy: 0.809
epoch:  1 loss:  0.03 accuracy: 0.807 test_loss:  0.023 test_accuracy: 0.867
epoch:  2 loss:  0.024 accuracy: 0.857 test_loss:  0.018 test_accuracy: 0.888
epoch:  3 loss:  0.021 accuracy: 0.869 test_loss:  0.017 test_accuracy: 0.894
epoch:  4 loss:  0.018 accuracy: 0.886 test_loss:  0.014 test_accuracy: 0.921
epoch:  5 loss:  0.017 accuracy: 0.897 test_loss:  0.022 test_accuracy: 0.869
epoch:  6 loss:  0.013 accuracy: 0.923 test_loss:  0.008 test_accuracy: 0.944
epoch:  7 loss:  0.009 accuracy: 0.947 test_loss:  0.011 test_accuracy: 0.924
epoch:  8 loss:  0.006 accuracy: 0.966 test_loss:  0.004 test_accuracy: 0.988
epoch:  9 loss:  0.004 accuracy: 0.979 test_loss:  0.002 test_accuracy: 0.998
epoch:  10 loss:  0.004 accuracy: 0.979 test_loss:  0.005 test_accuracy: 0.966

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
比较重要的点,
1.分类的数据集布局要记住
2.图片经过conv2 多次后的值要会算 todo
3.图片展示的方法要会

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/284219.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQLiteC/C++接口详细介绍sqlite3_stmt类(六)

返回:SQLite—系列文章目录 上一篇:SQLiteC/C接口详细介绍sqlite3_stmt类(五) 下一篇: SQLiteC/C接口详细介绍sqlite3_stmt类(七) 17. sqlite3_clear_bindings函数 sqlite3_clear_bindings函…

微服务高级篇(一):微服务保护+Sentinel

文章目录 一、初识Sentinel1.1 雪崩问题及解决方案1.2 微服务保护技术对比1.3 Sentinel介绍与安装1.4 微服务整合Sentinel 二、Sentinel的流量控制三、Sentinel的隔离与降级四、Sentinel的授权规则五、规则持久化5.1 规则管理模式【原始模式、pull模式、push模式】5.2 实现push…

Spark-Scala语言实战(4)

在之前的文章中,我们学习了如何在scala中定义无参,带参以及匿名函数。想了解的朋友可以查看这篇文章。同时,希望我的文章能帮助到你,如果觉得我的文章写的不错,请留下你宝贵的点赞,谢谢。 Spark-Scala语言…

基础:TCP四次挥手做了什么,为什么要挥手?

1. TCP 四次挥手在做些什么 1. 第一次挥手 : 1)挥手作用:主机1发送指令告诉主机2,我没有数据发送给你了。 2)数据处理:主机1(可以是客户端,也可以是服务端)&#xff0c…

LM studio使用gemmar聊天小试

通过LM studio可以方便的使用各种模型,使用LM提供的chat界面或者是使用python代码。 试试代码 在windows下使用python简单一试,例子直接复制LM界面上的代码: 用pip安装 openai包在LM界面 Start Server 需要安装 openai包。 本地电脑是I7…

图像处理ASIC设计方法 笔记12 图像旋转ASIC中心控制器状态机

P109 1 流水线图像旋转ASIC整体架构 中心控制器负责各个模块的状态控制和数据调度,接收到外部启动信号后,进人芯片初始化阶段,片上FIFO接收外部输入的图像旋转参数、接收完毕后,再利用接收到的旋转角度到查找表中找到对应的正弦和正切值。 中心控制器将接收到的行列信息…

计算机网络——26通用转发和SDN

通用转发和SDN 网络层功能: 转发: 对于从某个端口 到来的分组转发到合适的 输出端口路由: 决定分组从源端 到目标端的路径 网络层 传统路由器的功能 每个路由器(Per Route)的控制平面 (传统) 每个路由器上都有实…

阿里云原生:如何熟悉一个系统

原文地址:https://mp.weixin.qq.com/s/J8eK-qRMkmHEQZ_dVts9aQ?poc_tokenHMA-_mWjfcDmGVW6hXX1xEDDvuJPE3pL9-8uSlyY 导读:本文总结了熟悉系统主要分三部分:业务学习、技术学习、实战。每部分会梳理一些在学习过程中需要解答的问题,这些问题…

linux下线程分离属性

linux下线程分离属性 一、线程的属性---分离属性二、线程属性设置2.1 线程创建前设置分离属性2.2 线程创建后设置分离属性 一、线程的属性—分离属性 什么是分离属性? 首先分离属性是线程的一个属性,有了分离属性的线程,不需要别的线程去接合…

STM32---DHT11温湿度传感器与BH1750FVI光照传感器(HAL库、含源码)

写在前面:本节我们学习使用两个常见的传感器模块,分别为DHT11温湿度传感器以及BH1750FVI光照传感器,这两种传感器在对于环境监测中具有十分重要的作用,因为其使用简单方便,所以经常被用于STM32的项目之中。今天将使用分享给大家&a…

相交链表:寻找链表的公共节点

目录 一、公共节点 二、题目 三、思路 四、代码 五、代码解析 1.计算长度 2.等长处理 3.判断 六、注意点 1.leetcode的尿性 2.仔细观察样例 3.经验总结 一、公共节点 链表不会像两直线相交一样,相交之后再分开。 由于单链表只有一个next指针&#xff0…

STM32 CAN的工作模式

STM32 CAN的工作模式 正常模式 正常模式下就是一个正常的CAN节点,可以向总线发送数据和接收数据。 静默模式 静默模式下,它自己的输出端的逻辑0数据会直接传输到它自己的输入端,逻辑1可以被发送到总线,所以它不能向总线发送显性…

FANUC机器人零点标定的基本步骤(出厂数据)

FANUC机器人零点标定的基本步骤(出厂数据) FANUC 零点数据存在问题的机器人通常会出现以下几种报警: (1)SRVO-062报警 - 脉冲编码器数据丢失,机器人完全不能动,具体消除方法可参考以下链接中的内容: FANUC机器人SRVO-062报警原因分析及处理对策 (2)SRVO-075报警 -…

北京中科富海低温科技有限公司确认出席2024第三届中国氢能国际峰会

会议背景 随着全球对清洁能源的迫切需求,氢能能源转型、工业应用、交通运输等方面具有广阔前景,氢能也成为应对气候变化的重要解决方案。根据德勤的报告显示,到2050年,绿色氢能将有1.4万亿美元市场。氢能产业的各环节的关键技术突…

Hive SQL必刷练习题:排列组合问题【通过join不等式】

排列组合问题【通过join不等式】 这种问题,就是数学的排列不等式,一个队伍只能和其余队伍比一次,不能重复 方法1:可以直接通过join,最后on是一个不等式【排列组合问题的解决方式】 方法2:也可以是提前多加…

Docker安装配置

1. 安装docker-ce sudo yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo yum -y install docker-ce sudo systemctl enable docker 2. 设置代理 参照:https://docs.docker.com/config/daemon/systemd/#httpht…

【Flask】Flask数据迁移操作

Flask数据迁移操作 前提条件 安装第三方包: # ORM pip install flask-sqlalchemy # 数据迁移 pip install flask-migrate # MySQL驱动 pip install pymysql # 安装失败,指定如下镜像源即可 # pip install flask-sqlalchemy https://pypi.tuna.tsinghu…

【大屏设计】如何进行软件系统网站大屏页面设计?不限于智慧城市、物联网、电商、园区领域

【大屏设计】如何进行软件系统网站大屏页面设计?不限于智慧城市、物联网、电商、园区领域 一、什么是网站大屏设计二、网站大屏设计原型素材三、网站大屏设计设计素材四、他山之石 一、什么是网站大屏设计 网站大屏设计是网站设计中至关重要的一部分,因…

Webman全局异常捕获处理

最近在使用webman这个框架做项目开发,涉及到需要统一处理异常捕获。由于官网给的并不详细,于是自己实现了一下全局异常处理类。 一、配置效果 例如:我要在项目中统一返回json 格式数据,并不想在业务层写try,catch逻辑。 或者在业务…

查看文件内容的指令:cat,tac,nl,more,less,head,tail,写入文件:echo

目录 cat 介绍 输入重定向 选项 -b -n -s tac 介绍 输入重定向 nl 介绍 示例 more 介绍 选项 less 介绍 搜索文本 选项 head 介绍 示例 选项 -n tail 介绍 示例 选项 echo 介绍 输出重定向 追加重定向 cat 介绍 将标准输入(键盘输入)的内容打…