深度学习入门:自建数据集完成花鸟二分类任务

自建数据集完成二分类任务(参考文章)

1 图片预处理

1 .1 统一图片格式

找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错

RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1

pytorch的torchvision.transforms模块提供了许多用于图片变换/增强的函数。

1.1.1 把图片不等比例压缩为固定大小
transforms.Resize((600,600)),
1.1.2 裁剪保留核心区

因为主体要识别的图像一般在中心位置,所以使用CenterCrop,这里设置为(400, 400)

transforms.CenterCrop((400,400)),
1.1.3 处理成统一数据类型

这里统一成torch.float64方便神经网络计算,也可以统一成其他比如uint32等类型

transforms.ConvertImageDtype(torch.float64),
1.1.4 归一化进一步缩小图片范围

对于图片来说0~255的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)

  • 我们可以使用[0.5,0.5,0.5]mean,std来把数据归一化至[-1,1]
  • 也可以手动计算出所有的图片mean,std来归一化至均值为0,标准差为1的正态分布,
  • 一些深度学习代码常常使用mean=[0.485, 0.456, 0.406] ,std=[0.229, 0.224, 0.225]的归一化数据,这是在ImageNet的几百万张图片数据计算得出的结果
  • BN等方法也具有很出色的归一化表现,我们也会使用到

Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)

我们这里使用简单的[0.5,0.5,0.5]归一化方法,更新cls_dataset,加入transform操作 ,作为图片裁剪的预处理。

transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

关于transforms的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。

1.2 数据增强

当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms中的函数 ,也可以混合使用来根据已有数据创造新数据

        self.data_enhancement = transforms.Compose([transforms.RandomHorizontalFlip(p=1),transforms.RandomRotation(30)])

2 创建自制数据集

2.1 以Dataset类接口为模版

class cls_dataset(Dataset):def __init__(self) -> None:# initializationdef __getitem__(self, index):# return data,label in set def __len__(self):# return the length of the dataset

2.2 创建set

2.2.1定义两个空列表data_list和target_list
2.2.2遍历文件夹
2.2.3读取图片对象,对每一个图片对象预处理后,分别将图片对象和对应的标签加入data_list和target_list中
2.2.4将data_list和target_list加入h5df_ile中
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_imagetrain_pic_path = 'test-set'
test_pic_path = 'training-set'def create_h5_file(file_name):all_type = ['flower', 'bird']h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空#图片统一化处理transform = transforms.Compose([transforms.Resize((600, 600)),transforms.CenterCrop((400, 400)),transforms.ConvertImageDtype(torch.float64),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])#数据增强data_list = []   #建立一个保存图片张量的空列表target_list = [] #建立一个保存图片标签的空列表#遍历文件夹建立数据集'''文件夹组成| —— train|   | —— flower|   |   | —— 图片1|   | —— bird|   | —— | —— 图片2| —— test|   | —— flower|   | —— bird'''dataset_kind = file_name.split('.')[0]#先判断缺失的文件是训练集还是测试集if dataset_kind == 'train':pic_file_name = train_pic_pathelse:pic_file_name = test_pic_path#再循环遍历文件夹for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):target = file_name_dir.split('/')[-1]if target in all_type:for file in files:pic = read_image(os.path.join(file_name_dir, file))  #以张量形式读取图片对象pic = transform(pic)    #预处理图片pic = np.array(pic).astype(np.float64)data_list.append(pic)   #将pic对象添加到列表里target_list.append(target.encode()) #将target编码后添加到列表里h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)h5df_file.close()class h5py_dataset(Dataset):def __init__(self, file_name) -> None:super().__init__()self.file_name = file_name    #指向文件的路径名#如果file_name指向的h5文件不存在,就新建一个if not os.path.exists(file_name):create_h5_file(file_name)def __getitem__(self, index):with h5py.File(self.file_name, 'r') as f:if f['target'][index].decode() == 'bird':   #如果在f文件的target列表中查找到index下标对应的标签是birdtarget = torch.tensor(0)else:target = torch.tensor(1)return f['image'][index], targetdef __len__(self):with h5py.File(self.file_name, 'r') as f:return len(f['target'])def h5py_loader():train_file = 'train.hdf5'test_file = 'test.hdf5'train_dataset = h5py_dataset(train_file)test_dataset = h5py_dataset(test_file)train_data_loader = DataLoader(train_dataset, batch_size=4)test_data_loader = DataLoader(test_dataset, batch_size=4)return train_data_loader, test_data_loader

2.3 创建loader

实例化set对象后利用torch.utils.data.DataLoader

3 搭建网络

3.1 网络结构

在这里插入图片描述

3.2 参数计算

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1

参考文章

3.3 不成文规定

池化参数一般就是(2, 2)

中间的channel数量都是自己设定的,二的次方就行

kernelsize一般3或者5之类的

4 训练

加深对前面数据集组成理解

    for _, data in enumerate(train_loader):if isinstance(data, list):image = data[0].type(torch.FloatTensor).to(device)target = data[1].to(device)elif isinstance(data, dict):image = data['image'].type(torch.FloatTensor).to(device)target = data['target'].to(device)else:print(type(data))raise TypeError

for 循环中data的组成来源于构建set时,

    h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)

写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合

在这里插入图片描述

5 测试

6 保存模型

改进

投影概率放到网络里面

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/200085.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VRTK】【VR开发】【Unity】7-配置交互能力和向量追踪

【前情提要】 目前为止,我们虽然设定了手模型和动画,还能够正确根据输入触发动作,不过还未能与任何物体互动。要互动,需要给手部设定相应的Interactor能力。 【配置Interactor的抓取功能】 在Hierarchy中选中[VRTK_CAMERA_RIGS_SETUP] ➤ Camera Rigs, Tracked Alias ➤ …

VBA技术资料MF85:将工作簿批量另存为PDF文件

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…

UE4 基础篇十四:自定义插件

文末有视频地址和git地址 一、概念 虚幻里插件都是用C++写的,C++包括.h文件和.cpp文件,.h头文件通常包含函数类型和函数声明,cpp文件包含这些类型和函数的实现, 你为项目编写的所有代码文件都必须位于模块中,模块就是硬盘里的一个文件夹,包含名为“Build.cs”的C#文件…

BGP笔记实验

IGP(Interior Gateway Protocol)——内部网关协议 OSPF RIP IS-IS IGRP EIGRP EGP(External Gateway Protocol)——外部网关协议 EGP BGP——边界网关协议 AS——自治系统 由单一组织or机构独立维护的网络设备&网络资源的集合 网络范围太大 自治 AS号 为了区分不同…

API网关那些事【架构新知系列】

目前随着云原生ServiceMesh和微服务架构的不断演进,网关领域新产品不断出现,各种网关使用的技术,功能和应用领域也不断扩展,在各有所长的前提下也有很多功能重合,网上各种技术PR文章,评测资料和网关落地实践…

2024年山东省职业院校技能大赛中职组“网络安全”赛项竞赛试题-A

2024年山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题-A 一、竞赛时间 总计:360分钟 二、竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 A、B模块 A-1 登录安全加固 180分钟 200分 A-2 本地安全策略设置 A-3 流量完整性保护 A-4 …

Ubuntu18.04安装LeGO-LOAM保姆级教程

系统环境:Ubuntu18.04.6 LTS 1.LeGO-LOAM的安装前要求: 1.1 ROS安装:参考我的另一篇博客Ubuntu18.04安装ROS-melodic保姆级教程_灬杨三岁灬的博客-CSDN博客文章浏览阅读168次。Ubuntu18.04安装ROS-melodic保姆级教程https://blog.csdn.net/…

海外IP代理科普——API代理是什么?怎么用?

随着互联网的不断发展,越来越多的企业开始使用API(应用程序接口)来实现数据的共享和交流。而在API使用中,海外代理IP也逐渐普及。那么,什么是API代理IP呢?它有什么作用?API接口有何用处&#xf…

Ubuntu22.04 交叉编译GCC13.2.0 for Rv1126

一、安装Ubuntu22.04 sudo apt install vim net-tools openssh-server 二、安装必要项 sudo apt update sudo apt upgrade sudo apt install build-essential gawk git texinfo bison flex 三、下载必备软件包 1.glibc https://ftp.gnu.org/gnu/glibc/glibc-2.38.tar.gz…

vs code git问题:文件明明已加入忽略文件中,还是出现

vs code git问题:文件明明已加入忽略文件中,还是出现 原因: 因为之前这些文件都已经提交过,线上GIT已经存在,已存在就不能忽略, 解决办法: 先要删除这些文件提交上去,然后把这些文…

echarts 中如何添加左右滚动条 数据如何进行堆叠如何配置那些数据使用那个数据轴

左右滚动条的效果 此项的具体配置可参考 https://echarts.apache.org/zh/option.html#dataZoom-inside.moveOnMouseWheel dataZoom: [{id: dataZoomX,type: inside,// start: 0,// end: this.xAxis.length > 5 ? 10 : 100,startValue: this.xAxis.length > 5 ? 5 : 0,/…

RoCE、IB和TCP等网络的基本知识及差异对比

目前有三种RDMA网络,分别是Infiniband、RoCE(RDMA over Converged Ethernet)、iWARP。 其中,Infiniband是一种专为RDMA设计的网络,从硬件级别保证可靠传输 ,技术先进,但是成本高昂。 而RoCE 和 iWARP都是基于以太网的…

004 OpenCV akaze特征点检测匹配

目录 一、环境 二、akaze特征点算法 2.1、基本原理 2.2、实现过程 2.3、实际应用 2.4、优点与不足 三、代码 3.1、数据准备 3.2、完整代码 一、环境 本文使用环境为: Windows10Python 3.9.17opencv-python 4.8.0.74 二、akaze特征点算法 特征点检测算法…

使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的 语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧 2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果…

ElementUI及ElementUI Plus Axure RP高保真交互元件库及模板库

基于ElementUI2.0及ElementUI Plus3.0二次创作的ElementUI 元件库。2个版本的原型图内容会有所不同,ElementUI Plus3.0的交互更加丰富和高级。你可以同时使用这两个版本。 不仅包含Element UI 2.0版,还包含Element Plus 3版本。Element 2版支持Axure 8&…

接口自动化项目落地之HTTPBin网站

原文:https://www.cnblogs.com/df888/p/16011061.html 接口自动化项目落地系列 找个开源网站或开源项目,用tep实现整套pytest接口自动化项目落地,归档到电子书,作为tep完整教程的项目篇一部分。自从tep完整教程发布以后&#…

基于单片机的公共场所马桶设计(论文+源码)

1.系统设计 本课题为公共场所的马桶设计,其整个系统架构如图2.1所示,其采用STC89C52单片机为核心控制器,结合HC-SR04人体检测模块,压力传感器,LCD1602液晶,蜂鸣器,L298驱动电路等构成整个系统&…

RedisInsight——redis的桌面UI工具使用实践

下载 官网下载安装。下载地址在这里 填个邮箱地址就可以下载了。 安装使用。 安装成功后开始使用。 1. 你可以add一个地址。或者登录redis cloud 去auto-discover 2 . 新增你的redis库地址。注意index的取值 3。现在可以登录到redis了。看看结果 这是现在 在服务器上执行…

C#核心笔记——(二)C#语言基础

一、C#程序 1.1 基础程序 using System; //引入命名空间namespace CsharpTest //将以下类定义在CsharpTest命名空间中 {internal class TestProgram //定义TestProgram类{public void Test() { }//定义Test方法} }方法是C#中的诸多种类的函数之一。另一种函数*,还…

BLIP-2:冻结现有视觉模型和大语言模型的预训练模型

Li J, Li D, Savarese S, et al. Blip-2: Bootstrapping language-image pre-training with frozen image encoders and large language models[J]. arXiv preprint arXiv:2301.12597, 2023. BLIP-2,是 BLIP 系列的第二篇,同样出自 Salesforce 公司&…