基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sysimport torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdmdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def get_train_loader(image_path):train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = train_transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,shuffle=True, num_workers= 0)return train_loaderdef get_val_loader(image_path):val_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),transform = val_transform)val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,shuffle = False, num_workers = 0)return val_loaderdef train(train_loader,net):net.train()train_correct = 0.0train_loss = 0.0  # 初始化训练损失train_bar = tqdm(train_loader, file=sys.stdout)loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)optimizer = optim.Adam(net.parameters(), lr=0.001)for step, data in enumerate(train_bar):images, labels = dataimages, labels = images.to(device),labels.to(device)# 梯度清零optimizer.zero_grad()# 训练outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 反向传播loss.backward()# 更新权重optimizer.step()# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()train_correct += correcttrain_loss += loss.item()  # 累加损失值train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * train_loader.batch_size + len(images),total_samples=len(train_loader.dataset))train_correct = (100. * train_correct) / len(train_loader.dataset)train_loss /= len(train_loader)  # 计算平均损失值return train_correct, train_loss  # 返回训练正确率和平均损失值def val(val_loader,net):net.eval()val_correct = 0.0val_loss = 0.0  # 初始化验证损失loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)val_bar = tqdm(val_loader, file=sys.stdout)for step, data in enumerate(val_bar):images, labels = dataimages, labels = images.to(device), labels.to(device)with torch.no_grad():# 验证outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()val_correct += correctval_loss += loss.item()  # 累加损失值val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * val_loader.batch_size + len(images),total_samples=len(val_loader.dataset))val_correct = (100. * val_correct) / len(val_loader.dataset)val_loss /= len(val_loader)  # 计算平均损失值return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net

(2)resnet34 

def get_resnet34(class_num):net_name = "resnet34"net = torchvision.models.resnet34(pretrained=True)net.fc = Linear(in_features=512, out_features=class_num, bias=True)net = net.to(device)return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):net_name = "mobilenet_v2"net = torchvision.models.mobilenet_v2(pretrained=True)net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)net = net.to(device)return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader# 模型选择
def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# 训练主函数
if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#1 加载数据image_path = r"/kaggle/input/fruits3"train_loader = get_train_loader(image_path)val_loader = get_val_loader(image_path)#2 加载模型net_name,net = get_resnet34(class_num=5)#3 训练epochs = 5best_acc = 0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(epochs):train_acc,train_loss = train(train_loader, net)val_acc,val_loss = val(val_loader, net)train_losses.append(train_loss)val_losses.append(val_loss)train_accs.append(train_acc.item())val_accs.append(val_acc.item())if best_acc<val_acc:best_acc = val_acctorch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))# 画图save_path="/kaggle/working/" # 图片保存路径plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.4训练效果与模型文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337204.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用Spring Cache优化后端接口?

Spring Cache是Spring框架提供的一种缓存抽象,它可以很方便地集成到应用程序中,用于提高接口的性能和响应速度。使用Spring Cache可以避免重复执行耗时的方法,并且还可以提供一个统一的缓存管理机制,简化缓存的配置和管理。 本文将详细介绍如何使用Spring Cache来优化接口,…

【前端】Mac安装node14教程

在macOS上安装Node.js版本14.x的步骤如下&#xff1a; 打开终端。 使用Node Version Manager (nvm)安装Node.js。如果你还没有安装nvm&#xff0c;可以使用以下命令安装&#xff1a; curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.1/install.sh | bash 然后关…

基于NANO 9K 开发板加载PICORV32软核,并建立交叉编译环境

目录 0. 环境准备 1. 安装交叉编译器 2. 理解makefile工作机理 3. 熟悉示例程序的代码结构&#xff0c;理解软核代码的底层驱动原理 4. 熟悉烧录环节的工作机理&#xff0c; 建立下载环境 5. 编写例子blink&#xff0c; printf等&#xff0c; 加载运行 6. 后续任务 0.…

卷积网络迁移学习:实现思想与TensorFlow实践

摘要&#xff1a;迁移学习是一种利用已有知识来改善新任务学习性能的方法。 在深度学习中&#xff0c;迁移学习通过迁移卷积网络&#xff08;CNN&#xff09;的预训练权重&#xff0c;实现了在新领域或任务上的高效学习。 下面我将详细介绍迁移学习的概念、实现思想&#xff0c…

【成品设计】基于STM32单片机的饮水售卖机

基于STM32单片机的饮水售卖机 所需器件&#xff1a; STM32最小系统板。RFID&#xff1a;MFRC-522用于IC卡检测。OLED屏幕&#xff1a;用于显示当前水容量、系统状态等。水泵软管&#xff1a;用于抽水。水位传感器&#xff08;3个&#xff09;&#xff1a;用于分别标定&#x…

低代码赋能企业数字化转型:数百家软件公司的成功实践

本文转载于葡萄城公众号&#xff0c;原文链接&#xff1a;https://mp.weixin.qq.com/s/gN8Rq9TDmkMpCtNMMsBUXQ 导读 在当今的软件开发时代&#xff0c;以新技术助力企业数字化转型已经成为一个热门话题。如何快速适应技术变革&#xff0c;构建符合时代需求的技术能力和业务模…

【STM32F103】HC-SR04超声波测距

【STM32F103】HC-SR04超声波测距 一、HC-SR041、工作原理2、其他参数及时序图 二、代码编写思路三、HAL配置四、代码实现五、实验结果 前言 本次实验主要实现用stm32f103HC-SR04实现超声波测距&#xff0c;将测距数值通过串口上传到上位机串口助手 一、HC-SR04 1、工作原理 (…

【Unity知识点详解】Addressables的资源加载

今天来简单介绍一下Addressables&#xff0c;并介绍一下如何通过AssetName加载单个资源、如何通过Label加载多个资源、以及如何通过List<string>加载多个资源。由于Addressables的资源加载均为异步加载&#xff0c;所以今天给大家介绍如何使用StartCoroutine、如何使用As…

计算机算法中的数字表示法——浮点数

目录 1.前言2.浮点数的形式3.举例说明4.浮点数四则运算 微信公众号含更多FPGA相关源码&#xff1a; 1.前言 前面讲了定点表示法&#xff0c;定点表示法有一个主要的限制&#xff0c;那就是它不能有效地表示非常大或非常小的数&#xff0c;因为小数点的位置是固定的。为了解决这…

ios:文本框默认的copy、past改成中文复制粘贴

问题 ios 开发&#xff0c;对于输入框的一些默认文案展示&#xff0c;如复制粘贴是英文的&#xff0c;那么如何改为中文的呢 解决 按照路径找到这个文件 ios/项目/Info.plist&#xff0c;增加 <key>CFBundleAllowMixedLocalizations</key> <true/> <…

Echarts报警告Legend data should be same with series name or data name.

问题排查&#xff1a; 1. 确保 legend中的data中名字和series中每一项的name要匹配。 2. 仔细查看报警规律发现次数有在变化&#xff0c;因此找到代码中是动态修改legend,series的位置&#xff0c;检查一下这两个list的赋值逻辑。 果然&#xff0c;检查发现问题出现在了遍历里…

数据分析案例-在线食品订单数据可视化分析与建模分类

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

2.5Bump Mapping 凹凸映射

一、Bump Mapping 介绍 我们想要在屏幕上绘制物体的细节&#xff0c;从尺度上讲&#xff0c;一个物体的细节分为&#xff1a;宏观、中观、微观宏观尺度中其特征会覆盖多个像素&#xff0c;中观尺度只覆盖几个像素&#xff0c;微观尺度的特征就会小于一个像素宏观尺度是由顶点或…

在鲲鹏服务器搭建k8s高可用集群分享

高可用架构 本文采用kubeadm方式搭建k8s高可用集群&#xff0c;k8s高可用集群主要是对apiserver、etcd、controller-manager、scheduler做的高可用&#xff1b;高可用形式只要是为&#xff1a; 1. apiserver利用haproxykeepalived做的负载&#xff0c;多apiserver节点同时工作…

【主动均衡和被动均衡】

文章目录 1.被动均衡2.主动均衡1.被动均衡 被动均衡一般通过电阻放电的方式,对电压较高的电池进行放电,以热量形式释放电量,为其他电池争取更多充电时间。这样整个系统的电量受制于容量最少的电池。充电过程中,锂电池一般有一个充电上限保护电压值,当某一串电池达到此电压…

docker+vue云服务器打包镜像相关操作

dockervue云服务器打包镜像相关操作 容器化部署似乎成了当前一个非常主流的趋势&#xff0c;无论是前端还是后端&#xff0c;流行的操作就是给你一个镜像地址&#xff0c;让你自己去拉取镜像并运行镜像。这似乎是运维的工作&#xff0c;但是在没有专有运维的情况下&#xff0c…

Vue中,点击提交按钮,路由多了个问号

问题 当点击提交按钮是路由多了问号&#xff1a; http://localhost:8100/#/ 变为 http://localhost:8100/?#/原因 路由中出现问号通常是由于某些路径或参数处理不当造成的。在该情况下&#xff0c;是因为表单的默认行为导致的。提交表单时&#xff0c;如果没有阻止表单的默…

【CH32V305FBP6】调试入坑指南

1. 无法烧录程序 现象 MounRiver Studio WXH-LinkUtility 解决方法 前提&#xff1a;连接复位引脚 或者 2. 无法调试 main.c 与调试口冲突&#xff0c;注释后调试 // USART_Printf_Init(115200);

2024年5月31日 (周五) 叶子游戏新闻

《Granblue Fantasy: Relink》版本更新 新增可操控角色及功能世嘉股份有限公司现已公开《Granblue Fantasy: Relink》&#xff08;以下简称 Relink&#xff09;免费版本更新ver.1.3.1于5月31日&#xff08;周五&#xff09;上线的消息。该作是由Cygames Inc.&#xff08;下称Cy…

【class18】人工智能初步----语音识别(4)

【class17】 上节课&#xff0c;我们学习了: 语音端点检测的相关概念&#xff0c;并通过代码切分和保存了音频。 本节课&#xff0c;我们将学习这些知识点&#xff1a;1. 序列到序列模型2. 循环神经网络3. 调用短语音识别接口 知其然&#xff0c;知其所以然 在调用语…