AI-基本概念-多层感知器模型/CNN/RNN/自注意力模型

1 需求

神经网络

……


深度学习

……


深度学习包含哪些神经网络:

  • 全连接神经网络
  • 卷积神经网络
  • 循环神经网络
  • 基于注意力机制的神经网络


2 接口


3 CNN

在这个示例中:

 
  • 首先定义了一个简单的卷积神经网络SimpleCNN,它包含两个卷积层、两个池化层和两个全连接层。
  • 然后通过torchvision库加载了 MNIST 数据集,并进行了数据预处理。
  • 接着使用交叉熵损失函数和随机梯度下降优化器对模型进行了 10 个周期的训练。
  • 最后在测试集上对模型进行了测试,计算了模型的准确率。这是一个基础的 PyTorch CNN 应用示例,你可以根据实际需求修改模型结构、数据和训练参数等。

第一步,定义卷积神经网络(CNN)模型

import torch
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 第一个卷积层,输入通道为1(灰度图像),输出通道为32,卷积核大小为3x3self.conv1 = nn.Conv2d(1, 32, kernel_size=3)# 第一个卷积层后的激活函数ReLUself.relu1 = nn.ReLU()# 第一个最大池化层,池化核大小为2x2self.pool1 = nn.MaxPool2d(kernel_size=2)# 第二个卷积层,输入通道为32,输出通道为64,卷积核大小为3x3self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2)# 全连接层,将卷积层输出的特征图展平后连接到该层,输入大小为64 * 6 * 6,输出大小为128self.fc1 = nn.Linear(64 * 6 * 6, 128)self.relu3 = nn.ReLU()# 最后一个全连接层,用于分类,输出大小为10(假设是10分类问题)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.pool1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool2(x)# 将特征图展平x = x.view(-1, 64 * 6 * 6)x = self.fc1(x)x = self.relu3(x)x = self.fc2(x)return x

第二步,准备数据(以 MNIST 数据集为例)

import torchvision
import torchvision.transforms as transforms# 定义数据转换,将图像转换为张量并进行归一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
# 下载并加载训练数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)
# 下载并加载测试数据集
testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=2)

第三步,训练模型

# 创建模型实例
model = SimpleCNN()
# 定义损失函数(交叉熵损失)和优化器(随机梯度下降)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练循环
for epoch in range(10):  # 进行10个训练周期running_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取输入数据和标签inputs, labels = data# 梯度清零optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 累计损失running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

第四步,测试模型

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = data# 模型预测outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the test set: {100 * correct / total}%')

4 参考资料

神经网络——最易懂最清晰的一篇文章-CSDN博客

  1. 多层感知机(Multilayer Perceptron,MLP)
    • 结构特点:是一种简单的前馈神经网络,由输入层、一个或多个隐藏层和输出层组成。神经元之间全连接,即每个神经元与相邻层的所有神经元都有连接。例如,在一个用于手写数字识别的简单 MLP 中,输入层接收图像像素值,经过隐藏层的非线性变换后,输出层输出各个数字类别对应的概率。
    • 应用场景:广泛应用于分类和回归问题,如简单的图像分类、数据预测等。在自然语言处理领域可用于文本分类,在金融领域用于股票价格预测等。
  2. 卷积神经网络(Convolutional Neural Network,CNN)
    • 结构特点:主要由卷积层、池化层和全连接层组成。卷积层通过卷积核提取数据的局部特征,池化层进行下采样以减少数据维度和计算量,全连接层用于分类或回归等任务。例如在人脸识别任务中,卷积层可以提取人脸五官轮廓等特征。
    • 应用场景:在计算机视觉领域占据主导地位,用于图像分类(如识别图片中的物体是猫还是狗)、目标检测(检测图像中物体的位置和类别)、语义分割(将图像中的每个像素分类到不同语义类别)等。也在音频处理等领域有应用,如语音识别中的声学模型。
  3. 循环神经网络(Recurrent Neural Network,RNN)
    • 结构特点:具有循环连接,能够处理序列数据。在每个时间步,神经元接收当前输入和上一个时间步的隐藏状态,经过处理后输出当前时间步的隐藏状态和预测结果。例如在机器翻译中,RNN 可以逐词处理输入句子和生成翻译后的句子。
    • 应用场景:自然语言处理领域的文本生成、机器翻译、情感分析等任务,以及时间序列预测,如股票走势预测、气象数据预测等。不过,传统 RNN 存在梯度消失和梯度爆炸问题。
  4. 长短期记忆网络(Long - Short Term Memory,LSTM)和门控循环单元(Gated Recurrent Unit,GRU)
    • 结构特点(以 LSTM 为例):是 RNN 的变体,通过特殊的门控机制(输入门、遗忘门和输出门)来控制信息的流动,能够有效解决 RNN 中的梯度消失和梯度爆炸问题,更好地处理长序列数据。例如在长篇小说生成任务中,LSTM 可以有效地利用前文信息生成后续内容。GRU 结构相对更简单,将遗忘门和输入门合并为一个更新门,在性能上和 LSTM 类似,并且计算效率更高。
    • 应用场景:和 RNN 类似,主要用于自然语言处理中的长文本处理、语音识别中的语音序列处理、时间序列分析等需要处理长序列数据的任务。
  5. 生成对抗网络(Generative Adversarial Network,GAN)
    • 结构特点:由生成器和判别器两个神经网络组成。生成器的任务是生成尽可能逼真的数据,判别器的任务是区分真实数据和生成器生成的数据。两者通过对抗训练的方式不断提高性能,最终生成器能够生成高质量的假数据。例如在图像生成任务中,生成器可以根据噪声生成看起来像真实照片的图像。
    • 应用场景:图像生成(如生成高分辨率的风景照片)、数据增强(为训练数据集生成新的样本)、风格迁移(将一种图像风格转换为另一种风格)等。
  6. 自编码器(Auto - Encoder)
    • 结构特点:由编码器和解码器组成。编码器将输入数据压缩成低维的表示(编码),解码器将这个编码还原为尽可能接近原始输入的数据。例如,在图像压缩任务中,编码器将高分辨率图像转换为低维向量,解码器再将这个向量还原为图像。
    • 应用场景:数据降维、图像去噪、特征提取等。例如,在医学影像处理中,可以利用自编码器提取有价值的特征用于疾病诊断。
  7. Transformer 架构
    • 结构特点:基于自注意力机制(Self - Attention),摒弃了传统的循环结构,能够并行计算,大大提高了训练和推理速度。在处理序列数据时,通过计算每个位置与其他位置的相关性来提取特征。例如在自然语言处理中的 BERT 模型,就是基于 Transformer 架构,能够有效捕捉句子中单词之间的语义关系。
    • 应用场景:自然语言处理领域的预训练语言模型(如 GPT 系列、BERT 系列)、机器翻译等任务。在计算机视觉领域也有基于 Transformer 的模型用于图像分类等任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/460898.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leaflet查询矢量瓦片偏移的问题

1、问题现象 使用Leaflet绘制工具查询出来的结果有偏移 2、问题排查 1)Leaflet中latLngToContainerPoint和latLngToLayerPoint的区别 2)使用Leaflet查询需要使用像素坐标 3)经排查发现,container获取的坐标是地图容器坐标&…

Vue生成名片二维码带logo并支持下载

一、需求 生成一张名片,名片上有用户信息以及二维码,名片支持下载功能(背景样式可更换,忽略本文章样图样式)。 二、参考文章 这不是我自己找官网自己摸索出来的,是借鉴各位前辈的,学以致用&am…

如何利用网站进行仿牌运营?

对于很多人来说,仿牌网站的运营是一项充满挑战的任务,很多初学者对如何开始感到困惑,甚至不清楚仿牌网站的运作机制。此外,搜索引擎对新网站的审核期也使得许多站长倍感压力。那么,如何才能在这一过程中有效地进行SEO优…

数字IC开发:布局布线

数字IC开发:布局布线 前端经过DFT,综合后输出网表文件给后端,由后端通过布局布线,将网表转换为GDSII文件;网表文件只包含单元器件及其连接等信息,GDS文件则包含其物理位置,具体的走线&#xff1…

传智杯 第六届-复赛-C

题目描述: 小红有一个数组,她每次可以选择数组的一个元素 xxx ,将这个元素分成两个元素 aaa 和 bbb ,使得 abxabxabx。请问小红最少需要操作多少次才可以使得数组的所有元素都相等。 输入描述: 第一行输入一个整数 n(1≤n≤10^5)…

华为配置 之 GVRP协议

目录 简介: 配置GVRP: 总结: 简介: GVRP(GARP VLAN Registration Protocol),称为VLAN注册协议,是用来维护交换机中的VLAN动态注册信息,并传播该信息到其他交换机中&…

外包干了7天,技术明显退步。。。。。

先说一下自己的情况,本科生,22年通过校招进入南京某软件公司,干了接近2年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了2年的功能测试&…

openGauss开源数据库实战十

文章目录 任务十 openGauss逻辑结构:数据库管理任务目标实施步骤一、登录到openGauss二、创建数据库三、查看数据库集群中有哪些数据库四、查看数据库默认表空间的信息五、查看数据库下有哪些模式六、查看数据库下有哪些表七、修改数据库的默认表空间八、重命名数据库九、删除数…

H3C OSPF配置

OSPF配置实验 实验拓扑图 实验需求 1.配置IP地址 2.分区域配置OSPF&#xff0c;实现全网互通 3.为了路由结构稳定&#xff0c;要求路由器使用环回口作为Router-id&#xff0c;ABR的环回口宣告进骨干区域 实验配置 1.配置IP地址 R1&#xff1a; <H3C>system-view …

飞桨首创 FlashMask :加速大模型灵活注意力掩码计算,长序列训练的利器

在 Transformer 类大模型训练任务中&#xff0c;注意力掩码&#xff08;Attention Mask&#xff09;一方面带来了大量的冗余计算&#xff0c;另一方面因其 O ( N 2 ) O(N^2) O(N2)巨大的存储占用导致难以实现长序列场景的高效训练&#xff08;其中 N N N为序列长度&#xff09;…

乘云而上,OceanBase再越山峰

一座山峰都是一个挑战&#xff0c;每一次攀登都是一次超越。 商业数据库时代&#xff0c;面对国外数据库巨头这座大山&#xff0c;实现市场突破一直都是中国数据库产业多年夙愿&#xff0c;而OceanBase在金融核心系统等领域的攻坚克难&#xff0c;为产业突破交出一副令人信服的…

为什么要使用Golang以及如何入门

什么是golang&#xff1f; Go是一种开放源代码的编程语言&#xff0c;于2009年首次发布&#xff0c;由Google的Rob Pike&#xff0c;Robert Griesemer和Ken Thompson开发。基于C的语法&#xff0c;它进行了一些更改和改进&#xff0c;以安全地管理内存使用&#xff0c;管理对象…

《文心一言插件设计与开发》赛题三等奖方案 | NoteTable

一年一度的 CCF大数据与计算智能大赛&#xff08;简称2024 CCF BDCI大赛&#xff09;又开始啦~~ 程序员们可冲一波嗷~ 大赛地址&#xff1a;http://go.datafountain.cn/6506 现在我们再次释放往届获奖方案&#xff0c; 为新一届大赛的同学们提供一些方案和灵感参考~ 大家借鉴借…

el-dialog支持全局拖拽功能

1.首先在全局的组件实现拖拽功能&#xff0c;结构如下 dialogDrag.vue的内容 <script>export default {mounted() {// 获取当前的dialog及其headerlet aimDialog this.$el.getElementsByClassName(el-dialog)[0];let aimHeader this.$el.getElementsByClassName(el-d…

XCode16中c++头文件找不到解决办法

XCode16中新建Framework&#xff0c;写完自己的c代码后&#xff0c;提示“<string> file not found”等诸如此类找不到c头文件的错误。 工程结构如下&#xff1a; App是测试应用&#xff0c;BoostMath是Framework。基本结构可以参考官方demo&#xff1a;Mix Swift and …

开源代码管理平台Gitlab如何本地化部署并实现公网环境远程访问私有仓库

文章目录 前言1. 下载Gitlab2. 安装Gitlab3. 启动Gitlab4. 安装cpolar5. 创建隧道配置访问地址6. 固定GitLab访问地址6.1 保留二级子域名6.2 配置二级子域名 7. 测试访问二级子域名 前言 本文主要介绍如何在Linux CentOS8 中搭建GitLab私有仓库并且结合内网穿透工具实现在公网…

JavaEE初阶---网络原理(四)--IP协议/DNS协议

文章目录 1.初识网络层&#xff08;了解即可&#xff09;2.地址管理2.1动态分配2.2网络地址转换2.3IP-v6最终解 3.网段划分4.以太网协议--数据链路层5.DNS应用层协议 1.初识网络层&#xff08;了解即可&#xff09; 网络层做的事情就是下面的两个&#xff1a; 1&#xff09;地…

4.2-6 使用Hadoop WebUI

文章目录 1. 查看HDFS集群状态1.1 端口号说明1.2 用主机名访问1.3 主节点状态1.4 用IP地址访问1.5 查看数据节点 2. 操作HDFS文件系统2.1 查看HDFS文件系统2.2 在HDFS上创建目录2.3 上传文件到HDFS2.4 删除HDFS文件和目录 3. 查看YARN集群状态4. 实战总结 1. 查看HDFS集群状态 …

EMS专题 | 5个必须知道的温度监测系统入门知识

在保护温度敏感资产方面&#xff0c;可靠的温度监测技术扮演着至关重要的角色。为了帮助您深入了解这一关键技术&#xff0c;我们特别推出了EMS&#xff08;环境监测系统&#xff09;专题文章系列。内容将由浅入深&#xff0c;从基础原理到实际应用&#xff0c;从行业标准到解决…

代码随想录-字符串-反转字符串中的单词

题目 题解 法一:纯粹为了做出本题&#xff0c;暴力解 没有技巧全是感情 class Solution {public String reverseWords(String s) {//首先去除首尾空格s s.trim();String[] strs s.split("\\s");StringBuilder sb new StringBuilder();//定义一个公共的字符反转…