P7——pytorch马铃薯病害识别

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.检查GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
warnings.filterwarnings("ignore")            
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

​​

2.查看数据

import os,PIL,random,pathlibdata_dir='data/PotatoPlants'
data_path=pathlib.Path(data_dir)
data_paths=list(data_path.glob('*'))
classNames=[str(path).split('\\')[2] for path in data_paths]
classNames

​​​​​​

3.划分数据集

train_transforms=transforms.Compose([transforms.Resize([224,224]),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.482,0.456,0.406],std=[0.229,0.224,0.225])
])
test_transforms=transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize(mean=[0.482,0.456,0.406],std=[0.229,0.224,0.225])
])
total_data=datasets.ImageFolder("data/PotatoPlants/",transform=train_transforms)
total_datatotal_data.class_to_idxtrain_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_data,test_data=torch.utils.data.random_split(total_data,[train_size,test_size])
train_data,test_databatch_size=32
train_dl=torch.utils.data.DataLoader(train_data,batch_size,shuffle=True,num_workers=1)
test_dl=torch.utils.data.DataLoader(test_data,batch_size,shuffle=True,num_workers=1)for X,y in train_dl:print(X.shape)print(y.shape)break

​​​

4.创建模型

import torch.nn.functional as Fclass vgg16(nn.Module):def __init__(self):super(vgg16, self).__init__()# 卷积块1self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))# 卷积块2self.block2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))# 卷积块3self.block3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))# 卷积块4self.block4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))# 卷积块5self.block5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))# 全连接网络层,用于分类self.classifier = nn.Sequential(nn.Linear(in_features=512*7*7, out_features=4096),nn.ReLU(),nn.Linear(in_features=4096, out_features=4096),nn.ReLU(),nn.Linear(in_features=4096, out_features=3))def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
model = vgg16().to(device)
modelimport torchsummary as summary
summary.summary(model, (3, 224, 224))


​​​​

5.动态调整学习率函数

#调用官方动态学习率接口
learning_rate = 1e-4
lambda1=lambda epoch:0.92**(epoch//4)
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)

6.编译及训练模型

def train(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_loss,train_acc=0,0for X,y in dataloader:X,y =X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)#反向传播optimizer.zero_grad()loss.backward()optimizer.step()train_loss+=loss.item()train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()train_acc/=sizetrain_loss/=num_batchesreturn train_acc,train_lossdef test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)test_loss,test_acc=0,0with torch.no_grad():for imgs,target in dataloader:imgs,target=imgs.to(device),target.to(device)target_pred=model(imgs)loss=loss_fn(target_pred,target)test_loss+=loss.item()test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=sizetest_loss/=num_batchesreturn test_acc,test_lossimport copy
optimizer  = torch.optim.Adam(model.parameters(), lr= 1e-4)
loss_fn    = nn.CrossEntropyLoss()
import copy
loss_fn=nn.CrossEntropyLoss()
epochs=5
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0
for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)#更新学习率scheduler.step()model.eval()epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)if epoch_test_acc>=best_acc:best_acc=epoch_test_accbest_model=copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)lr=optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
PATH='./best_model.pth'
torch.save(best_model.state_dict(),PATH)
print('Finished Training')

​​​​​​​​

7.结果可视化

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
plt.rcParams['figure.dpi']=100epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​​

8.预测本地图片

from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):test_img=Image.open(image_path).convert('RGB')plt.imshow(test_img)test_img=transform(test_img)img=test_img.to(device).unsqueeze(0)model=model.eval()output=model(img)_,pred=torch.max(output,1)pred_class=classes[pred]print('预测结果是:',pred_class)predict_one_image(image_path='data/PotatoPlants/Early_blight/1.JPG', model=model, transform=train_transforms, classes=classes)best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)epoch_test_acc, epoch_test_loss

​​​​​​​​​​

​​总结:

1.VGG-16

VGG-16(Visual Geometry Group-16)是由牛津大学视觉几何组(Visual Geometry Group)提出的一种深度卷积神经网络架构,用于图像分类和对象识别任务。VGG-16在2014年被提出,是VGG系列中的一种。VGG-16之所以备受关注,是因为它在ImageNet图像识别竞赛中取得了很好的成绩,展示了其在大规模图像识别任务中的有效性。

以下是VGG-16的主要特点:

  1. 深度:VGG-16由16个卷积层和3个全连接层组成,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
  2. 卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都接有ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量,从而降低了过拟合的风险。
  3. 池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
  4. 全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

VGG-16结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示;
  • 3个全连接层(Fully connected Layer),用classifier表示;
  • 5个池化层(Pool layer)。

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/495673.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2021-04-08 VSC++: 降序折半查找。

void 降序折半查找(int* a, int aa, int aaa) {//缘由https://bbs.csdn.net/topics/399166569int aaaa aaa / 2; bool k 0;if (a[0] aa){cout << 0, cout << ends << "查找&#xff1a;" << aa << endl;k 1;return;}else if (a[aa…

【蓝桥杯——物联网设计与开发】拓展模块3 - 温度传感器模块

目录 一、温度传感器模块 &#xff08;1&#xff09;资源介绍 &#x1f505;原理图 &#x1f505;STS30-DIS-B &#x1f319;引脚分配 &#x1f319;通信 &#x1f319;时钟拉伸&#xff08;Clock Stretching&#xff09; &#x1f319;单次触发模式 &#x1f319;温度数据转…

关于Edge浏览器的设置

这里记录几条个人比较习惯的使用浏览器方式的设置&#xff0c;主要是edge浏览器 1. 黑背景色 修改整个浏览器的背景色为黑色&#xff0c;而不是主题&#xff0c;只有边框颜色改变地址栏输入edge://flags/#enable-force-dark&#xff0c;将Default 改为 Enabled&#xff1b;如…

给bmp和png,设置BLENDFUNCTION的AlphaFormat不同参数的效果

BLENDFUNCTION是AlphaBlend用控制透明效果的重要参数。 选择一个32位的png图片&#xff0c;设置AlphaFormat 为 AC_SRC_ALPHA&#xff0c;效果如上图。 选择一个32位的png图片&#xff0c;设置AlphaFormat 为 0&#xff0c;效果如上图。 选择一个24位的bmp图片&#xff0c;设置…

Nginx界的天花板-Oracle 中间件OHS 11g服务器环境搭建

环境信息 服务器基本信息 如下表&#xff0c;本次安装总共使用2台服务器&#xff0c;具体信息如下&#xff1a; 服务器IP DNS F5配置 OHS1 172.xx.xx.xx ohs01.xxxxxx.com ohs.xxxxxx.com OHS2 172.xx.xx.xx ohs02.xxxxxx.com 服务器用户角色信息均为&#xff1a;…

智驾感知「大破局」!新一轮混战开启

随着智能驾驶搭载率的攀升&#xff0c;舱外传感器赛道迎来新变局。 一方面&#xff0c;从近几年智驾传感器的配置变化来看&#xff0c;摄像头的主导地位显而易见。 12月10-12日&#xff0c;由德赛西威总冠名的2024&#xff08;第八届&#xff09;高工智能汽车年会暨年度金球奖…

Hive其四,Hive的数据导出,案例展示,表类型介绍

目录 一、Hive的数据导出 1&#xff09;导出数据到本地目录 2&#xff09;导出到hdfs的目录下 3&#xff09;直接将结果导出到本地文件中 二、一个案例 三、表类型 1、表类型介绍 2、内部表和外部表转换 3、两种表的区别 4、练习 一、Hive的数据导出 数据导出的分类&…

基于WEB的房屋出租管理系统设计

摘 要 随着城市化程度的推进&#xff0c;越来越多的人涌入城市&#xff0c;同时也带来的旺盛的租房需求&#xff0c;传统的房屋出租管理依赖人 工记录的方式难以满足人们对房屋出租管理的需求。因此&#xff0c;本文根据房屋出租信息化的需求设计一款基于房屋出租 的管理系统。…

操作系统(26)数据一致性控制

前言 操作系统数据一致性控制是确保在计算机系统中&#xff0c;数据在不同的操作和处理过程中始终保持正确和完整的一种机制。 一、数据一致性的重要性 在当今数字化的时代&#xff0c;操作系统作为计算机系统的核心&#xff0c;负责管理和协调各种资源&#xff0c;以确保计算机…

C语言项目 天天酷跑(上篇)

前言 这里讲述这个天天酷跑是怎么实现的&#xff0c;我会在天天酷跑的下篇添加源代码&#xff0c;这里会讲述天天酷跑这个项目是如何实现的每一个思路&#xff0c;都是作者自己学习于别人的代码而创作的项目和思路&#xff0c;这个代码和网上有些许不一样&#xff0c;因为掺杂了…

Docker数据库的主从复制

有很多种方式,我这里使用的是docker镜像配置两个mysql容器做主从复制,一个做主服务器&#xff0c;另一个做从服务器&#xff0c;前提是有docker。 创建容器: 第一个容器: mkdir mysql cd mysql docker run -id \ -p 3307:3306 \ --namels \ -v /root/mysql/conf:/etc/mysql/co…

CH32V307VCT6---工程template创建

一、硬件&#xff1a;沁恒官网申请的CH32V307VCT6开发板 二、开发环境&#xff1a;Mounriver 三、最终效果 1.PB9连接LED1&#xff0c;使其闪烁 2.OLED屏幕显示&#xff1a;软件IIC&#xff0c;PB10----SDA&#xff0c;PB11---SCL 3.工程链接&#xff1a;CH32V307VCT6 lo…

mac中idea菜单工具栏没有git图标了

1.右击菜单工具栏 2.选中VCS&#xff0c;点击添加 3.搜索你要的工具&#xff0c;选中点击确定就添加了 4.回到上面一个界面&#xff0c;选中你要放到工具栏的工具&#xff0c;点击应用就好了 5.修改图标&#xff0c;快捷键或者右击选中编辑图标 6.选择你要的图标就好了

使用 Python 创建多栏 Word 文档 – 详解

目录 引言 一、工具与安装 二、Python 在 Word 中创建简单的多栏布局 三、Python 在 Word 文档的栏间添加分隔线 四、Python 从Word文档的指定位置开启多栏设置 五、Python 为多栏 Word 文档的各栏添加页码 引言 在文档设计中&#xff0c;排版不仅决定了内容的呈现方式&…

JZ31 栈的压入、弹出序列

题目来源&#xff1a;栈的压入、弹出序列_牛客题霸_牛客网 题目&#xff1a;如下 输入两个整数序列&#xff0c;第一个序列表示栈的压入顺序&#xff0c;请判断第二个序列是否可能为该栈的弹出顺序。假设压入栈的所有数字均不相等。例如序列1,2,3,4,5是某栈的压入顺序&#xf…

Spring源码_05_IOC容器启动细节

前面几章&#xff0c;大致讲了Spring的IOC容器的大致过程和原理&#xff0c;以及重要的容器和beanFactory的继承关系&#xff0c;为后续这些细节挖掘提供一点理解基础。掌握总体脉络是必要的&#xff0c;接下来的每一章都是从总体脉络中&#xff0c; 去研究之前没看的一些重要…

【Linux】进程间通信 -> 匿名管道命名管道

进程间通信的目的 数据传输&#xff1a;一个进程许需要将它的数据发送给另外一个进程。资源共享&#xff1a;多个进程之间共享同样的资源。通知事件&#xff1a;一个进程需要向另一个或一组进程发送消息&#xff0c;通知它们发生了某种事件&#xff08;如进程终止时要通知父进程…

大模型-使用Ollama+Dify在本地搭建一个专属于自己的聊天助手与知识库

大模型-使用OllamaDify在本地搭建一个专属于自己的知识库 1、本地安装Dify2、本地安装Ollama并解决跨越问题3、使用Dify搭建聊天助手4、使用Dify搭建本地知识库 1、本地安装Dify 参考往期博客&#xff1a;https://guoqingru.blog.csdn.net/article/details/144683767 2、本地…

黑盒测试/白盒测试知识总结

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 对于很多刚开始学习软件测试的小伙伴来说&#xff0c;如果能尽早将黑盒、白盒测试弄明白&#xff0c;掌握两种测试的结论和基本原理&#xff0c;将对自己后期的学习…

数据结构之线性表之顺序表

定义&#xff1a; 由n&#xff08;n>0&#xff09;个数据特性相同的元素构成的有限序列称为线性表 简单来说n个相同数据类型的数据组wsw合在一起的这么一个集合就是一个线性表 线性表包括顺序表和链表 1. 顺序表&#xff08;我们所有的代码实现都用函数来封装&#xff09…