PyTorch深度学习网络(二:CNN)

卷积神经网络(CNN)是一种专门用于处理具有类似网格结构数据的深度学习模型,例如图像(2D网格的像素)和时间序列数据(1D网格的信号强度)。CNN在图像识别、图像分类、物体检测、语音识别等领域有着广泛的应用。

CNN的核心特点包括局部连接和权值共享:

  1. 局部连接意味着每个神经元只与输入数据的一个局部区域相连,这大大减少了参数的数量,提高了计算效率;
  2. 权值共享是指在卷积层中,相同的卷积核被用于整个输入数据,这不仅进一步减少了参数数量,还使得模型具有平移不变性,即无论物体出现在图像的哪个位置,都能被识别出来。

CNN的基本结构包括输入层、卷积层、激活函数、池化层、全连接层和输出层:

  1. 输入层接收原始图像数据;
  2. 卷积层通过卷积操作提取图像的特征;
  3. 激活函数引入非线性,增强模型的表达能力;
  4. 池化层通过下采样减少数据量,同时保留重要特征,增强模型的鲁棒性,此外多层卷积和池化层的堆叠使得模型能够逐层提取更高层次的特征;
  5. 全连接层将前面各层提取的特征综合起来,用于最终的分类或回归任务。

这种结构使得CNN特别适合处理图像数据,能够自动学习图像中的复杂特征,实现高效准确的图像识别。

本文展示了几种CNN网络结构在图像或文本分类中的应用,包含以下内容:

  1. LeNet的搭建和应用
  2. 微调预训练的VGG16网络
  3. TextCNN的搭建和应用

一、LeNet的搭建和应用

LeNet是早期最经典的卷积神经网络,由 Yann LeCun 等人在 1998 年提出,最初用于手写数字识别(MNIST 数据集),取得了十分显著的效果,其网络结构如图所示:

图片来自 LeNet - Wikipedia

 代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data as Data
import torchvision
from torchvision import models, transforms, datasetsfrom process import classify  # procss的代码见:https://blog.csdn.net/moyao_miao/article/details/141466047class LeNet(nn.Module):"""LeNet模型"""def __init__(self, size):super().__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.pool2 = nn.MaxPool2d(2, 2)s = size // 4 - 3self.fc1 = nn.Linear(16 * s * s, 120)self.fc2 = nn.Linear(120, 84)self.output = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))output = self.output(x)return output

使用MNIST数据集训练模型:

if __name__ == '__main__':train_data = torchvision.datasets.MNIST(root=r"C:\Users\57158\data\MNIST",train=True,transform=transforms.ToTensor(),download=False,)test_data = torchvision.datasets.MNIST(root=r"C:\Users\57158\data\MNIST",train=False,transform=transforms.ToTensor(),download=False,)model = LeNet(28)optimizer = Adam(model.parameters(), lr=0.0003)criterion = nn.CrossEntropyLoss()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")classify((train_data, test_data),model,optimizer,criterion,batch_size=64,epochs=5,device=device,)

分类效果:

 二、微调预训练的VGG16网络

VGG网络是由牛津大学视觉几何组(Visual Geometry Group)在2014年提出的一种卷积神经网络架构。它在当年的ImageNet图像分类挑战赛中取得了优异的成绩,并因其简洁的架构和良好的性能而广受欢迎。其网络结构如图所示:

VGG系列网络结构,图片来自 VGGNet-16 Architecture: A Complete Guide (kaggle.com)

VGG网络的主要特点:

  1. 深度:VGG网络以其深度著称,最深的版本(VGG16和VGG19)分别有16层和19层。这种深度使得网络能够学习到更复杂的特征。
  2. 小卷积核:VGG网络使用3x3的小卷积核,而不是之前常用的更大的卷积核(如7x7)。小卷积核的优势在于可以减少参数数量,同时通过叠加多个3x3卷积层可以模拟更大的感受野。
  3. 固定结构:VGG网络的结构非常规整,主要由卷积层和全连接层组成。卷积层通常使用ReLU激活函数,全连接层后面通常接一个softmax层用于分类。
  4. 池化层:VGG网络在每几个卷积层之后会插入一个最大池化层(Max Pooling),用于降低特征图的尺寸,减少计算量,并增强特征的平移不变性。

以VGG16为例,其结构如图: 

VGG16网络结构,图片来自 VGGNet-16 Architecture: A Complete Guide (kaggle.com)

尽管VGG网络具有简洁的结构和良好的性能,但由于其网络较深、参数较多,VGG网络的计算量和内存消耗都比较大,导致从头开始训练比较费时,而且对算力的要求也比较高。幸运的是PyTorch提供了预训练好的网络模型可供调用,开发者在其基础上微调即可快速搭建自己的网络。

一个通过微调预训练的VGG16网络用于分类10种猴子的图像分类器代码如下:

class MyVggModel(nn.Module):"""自定义的VGG16模型"""def __init__(self):super().__init__()# 加载预训练的vgg16模型vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features# 冻结参数for param in vgg.parameters():param.requires_grad_(False)# 预训练的vgg16的特征提取层self.vgg = vgg# 自定义的全连接层self.classifier = nn.Sequential(nn.Linear(25088, 512),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(256, 10),nn.Softmax(dim=1),)# 定义网络的向前传播路径def forward(self, x):x = self.vgg(x)x = x.view(x.size(0), -1)output = self.classifier(x)return output

数据集来源:10 Monkey Species (kaggle.com)

数据预处理:

if __name__ == '__main__':# 对训练集的预处理train_data_transforms = transforms.Compose([transforms.RandomResizedCrop(224),  # 随机将图像裁剪为224*224transforms.RandomHorizontalFlip(),  # 随机水平翻转图像transforms.ToTensor(),  # 转化为张量并归一化至[0,1]transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 图像标准化处理])# 对测试集的预处理test_data_transforms = transforms.Compose([transforms.Resize(256),  # 将图像缩放为256*256transforms.CenterCrop(224),  # 将图像从中心裁剪为224*224transforms.ToTensor(),  # 转化为张量并归一化至[0,1]transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 图像标准化处理])# 读取图像train_data_dir = r'C:\Users\57158\data\10-monkey-species\training\training'train_data = datasets.ImageFolder(train_data_dir, transform=train_data_transforms)test_data_dir = r'C:\Users\57158\data\10-monkey-species\validation\validation'test_data = datasets.ImageFolder(test_data_dir, transform=test_data_transforms)

训练模型:

    model = MyVggModel()optimizer = Adam(model.parameters(), lr=0.0003)criterion = nn.CrossEntropyLoss()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')classify((train_data, test_data),model,optimizer,criterion,batch_size=32,epochs=10,device=device,)

分类效果:

三、TextCNN的搭建和应用

TextCNN是一种用于文本分类的卷积神经网络模型,它由 Yoon Kim 在 2014 年提出。TextCNN 通过利用卷积层和池化层来捕捉文本中的局部特征,从而实现高效的文本分类。其结构如图:

TextCNN的结构,图片来自 1510.03820 (arxiv.org)

TextCNN 的基本结构包括以下几个部分:

  1. 嵌入层(Embedding Layer):将输入的词索引转换为词向量。这些词向量可以是预训练的,也可以是随机初始化的。
  2. 卷积层(Convolutional Layer):使用多个不同大小的卷积核来提取不同长度的特征。每个卷积核会在输入的词向量序列上滑动,生成特征图。
  3. 池化层(Pooling Layer):通常使用最大池化(Max Pooling)来提取每个特征图中的最大值,从而减少特征维度并保留最重要的特征。
  4. 全连接层(Fully Connected Layer):将池化后的特征向量输入到全连接层中,进行分类。
  5. 输出层(Output Layer):通常使用 softmax 函数来输出每个类别的概率。

一个通用的TextCNN网络代码如下:

import re
from functools import partial
from typing import Iteratorimport pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
import torchtext;torchtext.disable_torchtext_deprecation_warning()
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import numericalize_tokens_from_iterator
from nltk.corpus import stopwordsfrom process import classify
# procss的代码见:https://blog.csdn.net/moyao_miao/article/details/141466047class TextCNN(nn.Module):"""TextCNN模型"""def __init__(self, vocab_size: int, embedding_dim: int, num_filters: int,filter_sizes: Iterator, num_classes: int, dropout: float = 0.5):"""初始化TextCNN模型:param vocab_size:词典大小:param embedding_dim:词向量维度:param num_filters:卷积核个数:param filter_sizes:卷积核尺寸:param num_classes:输出的维度:param dropout:Dropout概率"""super().__init__()# 嵌入层self.embedding = nn.Embedding(vocab_size, embedding_dim)# 卷积层self.convs = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embedding_dim)) for fs in filter_sizes])# 最大池化层self.pool = nn.AdaptiveMaxPool1d(1)# Dropout层self.dropout = nn.Dropout(dropout)# 全连接输出层self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)# 定义网络的向前传播路径def forward(self, text):# text:(batch_size,MAX_LENGTH)embedded = self.embedding(text).unsqueeze(1)  # embedded:(batch_size,1,MAX_LENGTH,embedding_dim)conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]  # conved[n]:(batch_size,num_filters,MAX_LENGTH-filter_sizes[n]+1)pooled = [self.pool(conv).squeeze(2) for conv in conved]  # pooled[n]:(batch_size,num_filters)cat = self.dropout(torch.cat(pooled, dim=1))  # cat:(batch_size,num_filters*len(filter_sizes))return self.fc(cat)

使用IMDB数据集来训练模型:IMDB Dataset of 50K Movie Reviews (kaggle.com)

数据预处理一,文本清洗:

punctuation_regex = re.compile(r'[!"#$%&\'()*+,-./:;<=>?@\[\\\]^_`{|}~]')
stopwords_regex = re.compile('\\b(' + '|'.join(stopwords.words('english')) + ')\\b')
clean_ops = [str.lower,  # 转化为小写partial(re.sub, '<br /><br />', ' '),  # 去除换行符partial(re.sub, '\d+', ''),  # 去除数字partial(punctuation_regex.sub, ''),  # 去除符号partial(stopwords_regex.sub, ''),  # 去除停用词str.strip,  # 去除两端空格
]
def clean_text(s: str, ops: Iterator) -> str:"""模块化的文本清洗函数:param s: 待清洗的文本:param ops: 清洗函数列表:return: 清洗后的文本"""for op in ops:s = op(s)return stokenizer = get_tokenizer('spacy')
def token_gen(texts):"""生成token迭代器:param texts: 文本列表:return: token迭代器"""for text in texts:yield tokenizer(text)if __name__ == '__main__':df = pd.read_csv(r'C:\Users\57158\data\IMDB Dataset.csv')df['review'] = df['review'].apply(clean_text, ops=clean_ops)  # 文本清洗df['sentiment'] = df['sentiment'].apply(lambda x: 1 if x == 'positive' else 0)  # 标签转换df.to_csv(r'IMDB_Dataset_clean.csv', index=False)

数据预处理二,构建数字化文本矩阵:

    VOCAB_SIZE = 20000MAX_LENGTH = 100df = pd.read_csv(r'IMDB_Dataset_clean.csv')vocab = build_vocab_from_iterator(token_gen(df['review']), specials=['<UNK>'], max_tokens=VOCAB_SIZE)  # 构建词典vocab.set_default_index(vocab['<UNK>'])  # 设置默认索引处理未知词sequence = numericalize_tokens_from_iterator(vocab=vocab, iterator=token_gen(df['review']))  # 数字化文本token_ids = [torch.tensor(list(x)) for x in sequence]  # 将数字化的文本转换为tensorpadded_text = pad_sequence(token_ids, batch_first=True, padding_value=0)[:, :MAX_LENGTH]  # 填充文本并截断

注意:新版本的torchtext接口较旧版本变化比较大,很多旧版本的用法已经失效了,24年以前的torchtext教程就不用再看了。

训练模型:

    model = TextCNN(vocab_size=len(vocab),embedding_dim=MAX_LENGTH,num_filters=100,filter_sizes=[3, 4, 5],num_classes=2,)model.embedding.weight.data[vocab['<UNK>']] = torch.zeros(MAX_LENGTH)model.embedding.weight.data[vocab['<PAD>']] = torch.zeros(MAX_LENGTH)optimizer = Adam(model.parameters())criterion = nn.CrossEntropyLoss()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')classify((padded_text, torch.tensor(df['sentiment'])),model,optimizer,criterion,batch_size=32,epochs=3,device=device,to_tensor=False,)

分类效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/410636.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R语言绘制可用于论文发表的生存曲线图|科研绘图·24-08-25

小罗碎碎念 有关于生存曲线的基本概念&#xff08;例如删失事件的定义&#xff09;和绘图的详细教程我已经在5月的推文中介绍过了&#xff0c;有需求的同学欢迎前去考古。 R语言绘制生存分析曲线从概念到实战的保姆级教程&#xff5c;2024-05-12 https://mp.weixin.qq.com/s/Z…

SQL进阶技巧:如何按任意时段分析时间区间问题? | 分区间讨论【左、中、右】

目录 0 场景描述 1 数据准备 2 问题分析 方法1:分情况讨论,找出重叠区间 方法2:暴力美学法。按区间展开成日期明细表 3 拓展案例 4小结 0 场景描述 现有用户还款计划表 user_repayment ,该表内的一条数据,表示用户在指定日期区间内 [date_start, date_end] ,每天…

秋招突击——8/21——知识补充——计算机网络——cookie、session和token

文章目录 引言正文Cookie——客户端存储和管理Session——服务端存储和管理Token补充签名和加密的区别常见的加密算法和签名算法 面试题1、HTTP用户后续的操作&#xff0c;服务端如何知道属于同一个用户&#xff1f;如果服务端是一个集群机器怎么办&#xff1f;2、如果禁用了Co…

【Python 千题 —— 基础篇】简易图书管理系统

Python 千题持续更新中 …… 脑图地址 👉:⭐https://twilight-fanyi.gitee.io/mind-map/Python千题.html⭐ 题目描述 题目描述 编写一个面向对象的程序,模拟一个图书管理系统。要求定义一个 Book 类,具有基本的书籍信息功能;然后,创建一个 Library 类,用于管理多个 B…

Vue3搜索框(InputSearch)

效果如下图&#xff1a;在线预览 APIs InputSearch 参数说明类型默认值width搜索框宽度&#xff0c;单位 pxstring | number‘100%’icon搜索图标boolean | slottruesearch搜索按钮&#xff0c;默认时为搜索图标string | slotundefinedsearchProps设置搜索按钮的属性&#xf…

【Qt】容器类控件GroupBox

容器类控件GroupBox 使用QGroupBox实现一个带有标题的分组框&#xff0c;可以把其他的控件放在里面里面作为一组&#xff0c;这些内部的控件的父元素也就不是this了。 其目的只是为了让界面看起来更加好看&#xff0c;例如当一个界面比较复杂的时候&#xff0c;包含了很多的控…

APP封装安装配置参考说明

APP封装安装配置参考说明 一, 环境准备 宝塔环境 nginx php5.6 mysql5.6 java-openjdk1.8 apktool 1,安装 nginx,php,mysql自行安装 java-openjdk1.8 安装 推荐使用命令行安装 1.1 yum install java-1.8.0-openjdk1.2 yum install -y java-1.8.0-openjdk-devel1.3 设置…

Unity | 性能标准分析工具图形API简介

目录 一、相关术语 1.物理页 2.PSS内存 3.Reserved Total 二、耗时推荐值 三、内存推荐值 四、分析工具 1.Profiler &#xff08;1&#xff09;Profiler各平台对比 &#xff08;2&#xff09;构建到目标平台 &#xff08;3&#xff09;Frame数量修改 &#xff08;4…

天宝TBCTrimble Business Center中文版本下载安装使用介绍

天宝TBC&#xff1a;测绘之道&#xff0c;尽在其中 引言 昔日杜甫&#xff0c;忧国忧民&#xff0c;今朝我辈&#xff0c;测绘天下。天宝TBC&#xff0c;乃测绘之利器&#xff0c;助我等行走于山川河流之间&#xff0c;绘制天地之图。此文将以杜甫之笔&#xff0c;述说TBC之妙…

【数据结构】栈(stack)

目录 栈的概念 栈的方法 栈的实现 数组实现 push方法 压栈 pop方法 出栈 peek方法 获取栈顶元素 size方法 获取有效元素个数 链表实现 结尾 完整代码 数组实现栈代码 双向链表实现栈代码 栈的概念 栈是一种特殊的线性表&#xff0c;只允许在 固定的一段 进行插入…

kafka发送消息-生产者发送消息的分区策略(消息发送到哪个分区中?是什么策略)

生产者发送消息的分区策略&#xff08;消息发送到哪个分区中&#xff1f;是什么策略&#xff09; 1、默认策略&#xff0c;程序自动计算并指定分区1.1、指定key&#xff0c;不指定分区1.2、不指定key&#xff0c;不指定分区 2、轮询分配策略RoundRobinPartitioner2.1、创建配置…

使用idea快速创建springbootWeb项目(springboot+springWeb+mybatis-Plus)

idea快速创建springbootWeb项目 详细步骤如下 1&#xff09;创建项目 2&#xff09;选择springboot版本 3&#xff09;添加web依赖 4&#xff09;添加Thymeleaf 5&#xff09;添加lombok依赖 然后点击create进入下一步 双击pom.xml文件 6&#xff09;添加mybatis-plus依赖 …

【系统分析师】-案例篇-数据库

1、分布式数据库 1&#xff09;请用300字以内的文字简述分布式数据库跟集中式数据库相比的优点。 &#xff08;1&#xff09;坚固性好。由于分布式数据库系统在个别结点或个别通信链路发生故障的情况下&#xff0c;它仍然可以降低级别继续工作&#xff0c;系统的坚固性好&…

Ubuntu搭建FTP服务器

目录 1.ftp简介 2.vsftpd 2.1.介绍 2.2.安装与卸载 2.3.综合案例 - 本地用户模式 2.4.1.创建FTP用户 2.4.2.配置vsftpd 2.4.3.配置防火墙 1.ftp简介 一般来讲&#xff0c;人们将计算机联网的首要目的就是获取资料&#xff0c;而文件传输是一种非常重要的获取资料的方…

Docker 修改镜像源

由于docker hub 被禁&#xff0c;导致 docker 拉取镜像失败&#xff0c;解决办法就是使用国内的镜像源&#xff0c;目前国内的镜像源还是很多的&#xff0c;例如阿里云、腾讯云、华为云等等&#xff0c;下面演示一个更换成阿里云的步骤。 1. 阿里云获取加速地址 1.1 首先登录阿…

Git —— 1、Windows下安装配置git

Git简介 Git 是一个免费的开源分布式版本控制系统&#xff0c;旨在处理从小型到 快速高效的超大型项目。 Git 易于学习&#xff0c;占用空间小&#xff0c;性能快如闪电。 它超越了 Subversion、CVS、Perforce 和 ClearCase 等 SCM 工具 具有 cheap local branching、 方便的暂…

HIVE 数据仓库工具之第一部分(讲解部署)

HIVE 数据仓库工具 一、Hive 概述1.1 Hive 是什么1.2 Hive 产生的背景1.3 Hive 优缺点1.3.1 Hive的优点1.3.2 Hive 的缺点 1.4 Hive在Hadoop生态系统中的位置1.5 Hive 和 Hadoop的关心 二、Hive 原理及架构2.1 Hive 的设计原理2.2 Hive 特点2.3 Hive的体现结构2.4 Hive的运行机…

Linux 配置wireshark 分析thread 使用nRF-Sniffer dongle

Linux 配置wireshark nRF-Sniffer-for-802.15.4 1.下载固件和配置文件 https://github.com/NordicSemiconductor/nRF-Sniffer-for-802.15.4 2.烧写固件 使用nRF Connect for Desktop 中的 programmer 4.3烧写 https://www.nordicsemi.com/Products/Development-tools/nrf-conne…

【layUI】点击导出按钮,导出excel文件

要实现的功能如下&#xff1a;根据执行状态判断是否可以导出。如果可以导出&#xff0c;点击导出&#xff0c;在浏览器里下载对应的文件。 代码实现 html里&#xff1a; <table class"layui-hide" id"studentTable" lay-filter"studentTable&…

Dubbo3框架概述

1 什么是分布式系统? 《分布式系统原理与范型》定义: “分布式系统是若干独立计算机的集合,这些计算机对于用户来说就像单个相关系统” 分布式系统(distributed system)是建立在网络之上的软件系统。 简单来说:多个(不同职责)人共同来完成一件事! 任何一台服务器都无法…