pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/374620.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker安装遇到问题:curl: (7) Failed to connect to download.docker.com port 443: 拒绝连接

问题描述 首先&#xff0c;完全按照Docker官方文档进行安装&#xff1a; Install Docker Engine on Ubuntu | Docker Docs 在第1步&#xff1a;Set up Dockers apt repository&#xff0c;执行如下指令&#xff1a; sudo curl -fsSL https://download.docker.com/linux/ubu…

MybatisPlus 使用教程

MyBatisPlus使用教程 文章目录 MyBatisPlus使用教程1、使用方式1.1 引入依赖1.2 构建mapper接口 2、常用注解2.1 TableName2.2 TableId2.3 TableField MyBatisPlus顾名思义便是对MyBatis的加强版&#xff0c;但两者本身并不冲突(只做增强不做改变)&#xff1a; 引入它并不会对原…

[数据集][目标检测]护目镜检测数据集VOC+YOLO格式888张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;888 标注数量(xml文件个数)&#xff1a;888 标注数量(txt文件个数)&#xff1a;888 标注类别…

C语言基本概念

C语言是什么&#xff1f; 1.人与人之间 自然语言 2.人与计算机之间 计算机语言 例如C、Java、Go、Python 在计算机语言中 1.解释型语言&#xff1a;Python 2.编译型语言&#xff1a;C/C 编译和链接 C语言源代码都是文本文件.c&#xff0c;必须通过编译器的编译和链接器的…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第一篇 嵌入式Linux入门篇-第十八章 Linux编写第一个自己的命令

i.MX8MM处理器采用了先进的14LPCFinFET工艺&#xff0c;提供更快的速度和更高的电源效率;四核Cortex-A53&#xff0c;单核Cortex-M4&#xff0c;多达五个内核 &#xff0c;主频高达1.8GHz&#xff0c;2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

基于Python的哔哩哔哩数据分析系统设计实现过程,技术使用flask、MySQL、echarts,前端使用Layui

背景和意义 随着互联网和数字媒体行业的快速发展&#xff0c;视频网站作为重要的内容传播平台之一&#xff0c;用户量和内容丰富度呈现爆发式增长。本研究旨在设计并实现一种基于Python的哔哩哔哩数据分析系统&#xff0c;采用Flask框架、MySQL数据库以及echarts数据可视化技术…

昇思MindSpore学习入门-参数初始化

使用内置参数初始化 MindSpore提供了多种网络参数初始化的方式&#xff0c;并在部分算子中封装了参数初始化的功能。本节以Conv2d为例&#xff0c;分别介绍如何使用Initializer子类&#xff0c;字符串进行参数初始化。 Initializer初始化 Initializer是MindSpore内置的参数初…

硬件开发工具Arduino IDE

招聘信息共享社群 关联上篇文章乐鑫ESPRESSIF芯片开发简介 Arduino IDE&#xff08;集成开发环境&#xff09;是为Arduino硬件开发而设计的一款软件&#xff0c;它提供了一个易于使用的图形界面&#xff0c;允许用户编写、编辑、编译和上传代码到Arduino开发板。Arduino IDE的…

【前端】包管理器:npm、Yarn 和 pnpm 的全面比较

前端开发中的包管理器&#xff1a;npm、Yarn 和 pnpm 的全面比较 在现代前端开发中&#xff0c;包管理器是开发者必不可少的工具。它们不仅能帮我们管理项目的依赖&#xff0c;还能极大地提高开发效率。本文将详细介绍三种主流的前端包管理器&#xff1a;npm、Yarn 和 pnpm&am…

六、数据可视化—Echars(爬虫及数据可视化)

六、数据可视化—Echars&#xff08;爬虫及数据可视化&#xff09; Echarts应用 Echarts Echarts官网&#xff0c;很多图表等都是我们可以 https://echarts.apache.org/zh/index.html 是百度自己做的图表&#xff0c;后来用的人越来越多&#xff0c;捐给了orange组织&#xf…

相机光学(三十)——N5-N7-N8中性灰

GTI可提供N5/N7/N8中性灰涂料&#xff0c;用于不同的看色环境&#xff0c;N5/N7/N8代表深中浅不同的灰色程度&#xff0c;在成像、工业、印刷行业中&#xff0c;分别对周围观察环境有一定的要求&#xff0c;也出台了相应的标准文件&#xff0c;客户可以根据实际使用环境进行选择…

FiddlerScript Rules修改-更改发包中的cookie

直接在fiddler script editor中增加如下处理代码即可 推荐文档oSession -- 参数说明 测试笔记 看云

树莓派4B_OpenCv学习笔记19:OpenCV舵机云台物体追踪

今日继续学习树莓派4B 4G&#xff1a;&#xff08;Raspberry Pi&#xff0c;简称RPi或RasPi&#xff09; 本人所用树莓派4B 装载的系统与版本如下: 版本可用命令 (lsb_release -a) 查询: Opencv 版本是4.5.1&#xff1a; Python 版本3.7.3&#xff1a; ​​ 今日学习&#xff1…

RAG 工业落地方案框架(Qanything、RAGFlow、FastGPT、智谱RAG)细节比对!CVPR自动驾驶最in挑战赛赛道,全球冠军被算力选手夺走了

RAG 工业落地方案框架&#xff08;Qanything、RAGFlow、FastGPT、智谱RAG&#xff09;细节比对&#xff01;CVPR自动驾驶最in挑战赛赛道&#xff0c;全球冠军被算力选手夺走了。 本文详细比较了四种 RAG 工业落地方案 ——Qanything、RAGFlow、FastGPT 和智谱 RAG&#xff0c;重…

不仅是输出信息,console.log 也能玩出花

console.log 是 JavaScript 中一个常用的函数&#xff0c;用于向控制台输出信息。 console.log 虽然主要用于调试目的&#xff0c;但也包含了一些有趣的用法&#xff0c; console.log 不仅能输出文本&#xff0c;还能以更丰富的方式展示信息。 比如我们打开 B 站&#xff0c;然…

79. UE5 RPG 创建技能冷却和消耗

在这一篇里面&#xff0c;我们接着优化技能&#xff0c;现在角色添加的主动技能能够同步到ui上面。我们在这一篇文章里面&#xff0c;完善技能的消耗&#xff08;释放技能减少蓝量&#xff09;和冷却机制。 我们可以看到&#xff0c;在技能类默认值这里&#xff0c;可以设置它的…

【YashanDB知识库】YashanDB 开机自启

【问题分类】 YashanDB 开机自启 【关键字】 开机自启&#xff0c;依赖包 【问题描述】 数据库所在服务器重启后只拉起monit、yasom、yasom进程&#xff0c;缺少yasdb进程&#xff1a; 【问题原因分析】 数据库安装的时候未启动守护进程 【解决 / 规避方法】 进入数据库之前…

问题清除指南|Dell OptiPlex 7070 升级 win11 开启 TPM 2.0 教程

前言&#xff1a;最近想把实验室台式机的系统从 Windows 10 升级到 Windows 11&#xff0c;遇到一点小问题&#xff0c;在此记录一下解决办法。 ⚠️ 注&#xff1a;本教程仅在 Dell OptiPlex 7070 台式机系统中测试有效&#xff0c;并不保证其余型号机器适用此教程。 参考链接…

计算机网络体系结构解析

OSI参考模型 与 TCP/IP模型 如图所示 TCP/IP模型有几层 应用层&#xff1a;只需要专注于为用户提供应用功能 HTTP、SMTP、Telnet等&#xff0c;工作在操作系统中的用户态&#xff0c;传输层及以下工作在内核态传输层&#xff1a;为应用层提供网络支持&#xff08;TCP、UDP传…

谷粒商城实战-25-分布式组件-SpringCloud Alibaba-Nacos配置中心-加载多配置集

文章目录 一&#xff0c;拆分配置集二&#xff0c;配置文件中配置多配置集1&#xff0c;引用多配置集2&#xff0c;验证 三&#xff0c;多配置集总结1&#xff0c;使用场景2&#xff0c;优先级 这一节介绍如何加载多个配置集。 大多数情况下&#xff0c;我们把配置全部放在一个…