【深度学习实验】卷积神经网络(六):卷积神经网络模型(VGG)训练、评价

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

b. CIFAR10Dataset

2. 构建模型(FeedForward)

3.整合训练、评估、预测过程(Runner)

4. __main__

代码整合


一、实验介绍

        本实验实现了一个简化版VGG网络,并基于此完成图像分类任务。(包括模型训练、评价)
       

        VGG网络是深度卷积神经网络中的经典模型之一,由牛津大学计算机视觉组(Visual Geometry Group)提出。它在2014年的ImageNet图像分类挑战中取得了优异的成绩(分类任务第二,定位任务第一),被广泛应用于图像分类、目标检测和图像生成等任务。

        VGG网络的主要特点是使用了非常小的卷积核尺寸(通常为3x3)和更深的网络结构。该网络通过多个卷积层和池化层堆叠在一起,逐渐增加网络的深度,从而提取图像的多层次特征表示。VGG网络的基本构建块是由连续的卷积层组成,每个卷积层后面跟着一个ReLU激活函数。在每个卷积块的末尾,都会添加一个最大池化层来减小特征图的尺寸。VGG网络的这种简单而有效的结构使得它易于理解和实现,并且在不同的任务上具有很好的泛化性能。

        VGG网络有几个不同的变体,如VGG11、VGG13、VGG16和VGG19,它们的数字代表网络的层数。这些变体在网络深度和参数数量上有所区别,较深的网络通常具有更强大的表示能力,但也更加复杂。

二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、计算机视觉和模式识别等领域。它的设计灵感来自于生物学中视觉皮层的工作原理。

        卷积神经网络通过多个卷积层、池化层全连接层组成。

  • 卷积层主要用于提取图像的局部特征,通过卷积操作和激活函数的处理,可以学习到图像的特征表示。
  • 池化层则用于降低特征图的维度,减少参数数量,同时保留主要的特征信息。
  • 全连接层则用于将提取到的特征映射到不同类别的概率上,进行分类或回归任务。

        卷积神经网络在图像处理方面具有很强的优势,它能够自动学习到具有层次结构的特征表示,并且对平移、缩放和旋转等图像变换具有一定的不变性。这些特点使得卷积神经网络成为图像分类、目标检测、语义分割等任务的首选模型。除了图像处理,卷积神经网络也可以应用于其他领域,如自然语言处理和时间序列分析。通过将文本或时间序列数据转换成二维形式,可以利用卷积神经网络进行相关任务的处理。

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

        从CSV文件中读取标签信息并返回一个标签字典。

def read_csv_labels(fname):"""读取fname来给标签字典返回一个文件名"""with open(fname, 'r') as f:# 跳过文件头行(列名)lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name, label) for name, label in tokens))
  •  使用open函数打开指定文件名的CSV文件,并将文件对象赋值给变量f。这里使用'r'参数以只读模式打开文件。

  • 使用文件对象的readlines()方法读取文件的所有行,并将结果存储在名为lines的列表中。通过切片操作[1:],跳过了文件的第一行(列名),将剩余的行存储在lines列表中。

  • 列表推导式(list comprehension):对lines列表中的每一行进行处理。对于每一行,使用rstrip()方法去除行末尾的换行符,并使用split(',')方法将行按逗号分割为多个标记。最终,将所有行的标记组成的子列表存储在tokens列表中。

  • 使用字典推导式(dictionary comprehension)将tokens列表中的子列表转换为字典。对于tokens中的每个子列表,将子列表的第一个元素作为键(name),第二个元素作为值(label),最终返回一个包含这些键值对的字典。

b. CIFAR10Dataset

class CIFAR10Dataset(Dataset):def __init__(self, folder_path, fname):self.labels = read_csv_labels(os.path.join(folder_path, fname))self.folder_path = os.path.join(folder_path, 'train')def __len__(self):return len(self.labels)def __getitem__(self, idx):img = read_image(self.folder_path + '/' + str(idx + 1) + '.png')label = self.labels[str(idx + 1)]return img, torch.tensor(int(label))
  • 构造函数:

    • 接受两个参数

      • folder_path表示数据集所在的文件夹路径

      • fname表示包含标签信息的文件名。

    • 调用read_csv_labels函数,传递folder_pathfname作为参数,以读取CSV文件中的标签信息,并将返回的标签字典存储在self.labels变量中。

    • 通过拼接folder_path和字符串'train'来构建数据集的文件夹路径,将结果存储在self.folder_path变量中。

  • def __len__(self)

    • 这是CIFAR10Dataset类的方法,用于返回数据集的长度,即样本的数量。

  • def __getitem__(self, idx): 这是CIFAR10Dataset类的方法,用于根据给定的索引idx获取数据集中的一个样本。它首先根据索引idx构建图像文件的路径,并调用read_image函数来读取图像数据,将结果存储在img变量中。然后,它通过将索引转换为字符串,并使用该字符串作为键来从self.labels字典中获取相应的标签,将结果存储在label变量中。最后,它返回一个元组,包含图像数据和经过torch.tensor转换的标签。

2. 构建模型(FeedForward)

        参考前文:

【深度学习实验】卷积神经网络(五):深度卷积神经网络经典模型——VGG网络(卷积层、池化层、全连接层)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133350927?spm=1001.2014.3001.5501

3.整合训练、评估、预测过程(Runner)

        参考前文:

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133219448?spm=1001.2014.3001.5501

        (略有改动:)

class Runner(object):def __init__(self, model, optimizer, loss_fn, metric=None):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0# 模型训练阶段def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path','best_model.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif step%log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if eval_steps != 0 :if (epoch+1) % eval_steps ==  0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch+1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式   self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss/len(train_loader)).item()self.train_epoch_losses.append((global_step,train_loss))self.save_model(f'{save_path}.pth')   print('[Train] Train done')# 模型评价阶段def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step',-1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += loss self.metric.update(logits, y)dev_loss = (total_loss/len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss# 模型预测阶段,def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits# 保存模型的参数def save_model(self, save_path):torch.save(self.model.state_dict(), save_path)# 读取模型的参数def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

4. __main__

 batch_size = 20# 构建训练集train_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')train_iter = DataLoader(train_data, batch_size=batch_size)# 构建测试集num_classes = 10# 定义模型model = VGG_S(num_classes)# 定义损失函数loss_fn = F.cross_entropy# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.1)runner = Runner(model, optimizer, loss_fn, metric=None)runner.train(train_iter, num_epochs=10, save_path='chapter_5')

本文有待进一步完善……

代码整合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/144600.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python监控ES索引数量变化

文章目录 1, datafram根据相同的key聚合2, 数据合并:获取采集10,20,30分钟es索引数据脚本测试验证 1, datafram根据相同的key聚合 # 创建df1 > json {key:A, value:1 } {key:B, value:2 } data1 {key: [A, B], value: [1, 2]} df1 pd.DataFrame(data1)# 创建d…

rhel8 网络操作学习

一、查询dns服务器地址汇总 1.查询dns服务器地址: (1)方法一:执行命令 cat /etc/resolv.conf 执行结果如下: nameserver后面就是dns服务器的ip地址。 (2)方法2:查看/etc/syscon…

蓝海彤翔亮相2023新疆网络文化节重点项目“新疆动漫节”

9月22日上午,2023新疆网络文化节重点项目“新疆动漫节”(以下简称“2023新疆动漫节”)在克拉玛依科学技术馆隆重开幕,蓝海彤翔作为国内知名的文化科技产业集团应邀参与此次活动,并在美好新疆e起向未来动漫展映区设置展…

【C++数据结构】二叉树搜索树【完整版】

目录 一、二叉搜索树的定义 二、二叉搜索树的实现: 1、树节点的创建--BSTreeNode 2、二叉搜索树的基本框架--BSTree 3、插入节点--Insert 4、中序遍历--InOrder 5、 查找--Find 6、 删除--erase 完整代码: 三、二叉搜索树的应用 1、key的模型 &a…

工具学习--easyexcel-3.x 使用--写入基本使用,自定义转换--动态表头以及宽设置-

写在前面: easyexcel是alibaba开发简单导出未excel的工具。使用的情况还是比较多的。 文章目录 依赖导入写Excel快速入门对象设置ExcelProperty设置列属性ExcelIgnore 忽视列宽、行高格式转换时间格式化数字格式化自定义格式化 合并单元格其他更加个性化需求动态表…

【Java 进阶篇】MySQL多表关系详解

MySQL是一种常用的关系型数据库管理系统,它允许我们创建多个表格,并通过各种方式将这些表格联系在一起。在实际的数据库设计和应用中,多表关系是非常常见的,它能够更好地组织和管理数据,实现数据的复杂查询和分析。本文…

react+IntersectionObserver实现页面丝滑帧动画

实现效果: 加入帧动画前: 普通的静态页面 加入帧动画后: 可以看到,加入帧动画后,页面效果还是比较丝滑的。 技术实现 加入animation动画类 先用 **scss **定义三种动画类: .withAnimation {.fade1 {ani…

JavaScript Web APIs第二天笔记

Web APIs - 第2天 学会通过为DOM注册事件来实现可交互的网页特效。 能够判断函数运行的环境并确字 this 所指代的对象理解事件的作用,知道应用事件的 3 个步骤 学习会为 DOM 注册事件,实现简单可交互的网页特交。 事件 事件是编程语言中的术语&#xff…

Word | 简单可操作的快捷公式编号、右对齐和引用方法

1. 问题描述 在理工科论文的写作中,涉及到大量的公式输入,我们希望能够按照章节为公式进行编号,并且实现公式居中,编号右对齐的效果。网上有各种各样的方法来实现,操作繁琐和简单的混在一起,让没有接触过公…

Visual Studio 代码显示空格等空白符

1.VS2010: 快捷键:CtrlR,W 2.VS2017、VS2019、VS2022: 工具 -> 选项 -> 文本编辑器 -> 显示 -> 勾选查看空白

解决webpack报错:You forgot to add ‘mini-css-extract-plugin‘ plugin

现象: 原因: webpack5.72跟mini-css-extract-plugin有兼容性问题 解决办法:把 new MiniCssExtractPlugin()放在webpack配置文件中plugins数组的第一项: plugins: [ // 此处解决报错:You forgot to add mini-css-extra…

Java项目-文件搜索工具

目录 项目背景 项目效果 SQLite的下载安装 使用JDBC操作SQLite 第三方库pinyin4j pinyin4j的具体使用 封装pinyin4j 数据库的设计 创建实体类 实现DBUtil 封装FileDao 设计scan方法 多线程扫描 周期性扫描 控制台版本的客户端 图形化界面 设计图形化界面 项目…

最新AI创作系统源码ChatGPT源码+附详细搭建部署教程+AI绘画系统+支持国内AI提问模型

一、AI系统介绍 SparkAi创作系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧&am…

二十二,加上各种贴图

使用pbr的各种贴图,albedo,金属度,ao,法线,粗糙度,可以更好的控制各个片元 1,先加上纹理坐标 texCoords->push_back(osg::Vec2(xSegment, ySegment)); geom->setVertexAttribArray(3, texCoords, osg::Array::BI…

WebPack-打包工具

从图中我们可以看出,Webpack 可以将多种静态资源 js、css、less 转换成一个静态文件,减少了页面的请求. 下面举个例子 : main.js 我们只命名导出一个变量 export const name"老六"index.js import { name } from "./tset/…

聊聊并发编程——Condition

目录 一.synchronized wait/notify/notifyAll 线程通信 二.Lock Condition 实现线程通信 三.Condition实现通信分析 四.JUC工具类的示例 一.synchronized wait/notify/notifyAll 线程通信 关于线程间的通信,简单举例下: 1.创建ThreadA传入共享…

Vue之ElementUI实现登陆及注册

目录 ​编辑 前言 一、ElementUI简介 1. 什么是ElementUI 2. 使用ElementUI的优势 3. ElementUI的应用场景 二、登陆注册前端界面开发 1. 修改端口号 2. 下载ElementUI所需的js依赖 2.1 添加Element-UI模块 2.2 导入Element-UI模块 2.3 测试Element-UI是否能用 3.编…

【VUE复习·9】v-for 基础用法(循环渲染也叫列表渲染)

总览 1.v-for 都能循环什么 2.用法 一、v-for 都能遍历什么 能循环的东西包括:数组、对象、字符串(和java里面的3个引用数据类型一样)、纯粹循环数量(少用) 二、用法 1.用法1:简单循环(遍历…

Activiz 9.2 for Linux Crack

Activiz 9.2 在 C#、.Net 和 Unity 软件中为您的 3D 内容释放可视化工具包的强大功能。 ActiViz 允许您轻松地将 3D 可视化集成到您的应用程序中。 ActiViz 功能 用 C# 封装的 3D 可视化软件系统 允许在 .NET 环境中快速开发可投入生产的交互式3D 应用程序 支持窗口演示基础 (…

SI3262:国产NFC+MCU+防水触摸按键三合一SoC芯片

目录 SI3262简介特点结构框图芯片特性 SI3262简介 Si3262是高度集成ACD低功耗MCUNFC15通道防水触摸按键的SoC芯片。 其MCU模块具有低功耗、Low Pin Count、宽电压工作范围,集成了13/14/15/16位精度的ADC、LVD、UART、SPI、I2C、TIMER、WUP、IWDG、RTC、TSC等丰富的…