19.神经网络 - 线性层及其他层介绍

神经网络 - 线性层及其他层介绍

1.批标准化层–归一化层(不难,自学看官方文档)

Normalization Layers

torch.nn — PyTorch 1.10 documentation

BatchNorm2d — PyTorch 1.10 documentation

对输入采用Batch Normalization,可以加快神经网络的训练速度

CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
# num_features C-输入的channel

image-20240808153449423

# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)  # 正则化层num_feature等于channel,即100
input = torch.randn(20, 100, 35, 45)   #batch_size=20,100个channel,35x45的输入
output = m(input)

image-20240808153318952

2.Recurrent Layers(特定网络中使用,自学)

RNN、LSTM等,用于文字识别中,特定的网络结构

torch.nn — PyTorch 1.13 documentation

image-20240808151920684

3.Transformer Layers(特定网络中使用,自学)

特定网络结构

torch.nn — PyTorch 1.13 documentation

image-20240808151936485

4.Linear Layers–线性层(本节讲解)–使用较多

网站地址:Linear — PyTorch 1.10 documentation

img

d代表特征数,L代表神经元个数 K和b在训练过程中神经网络会自行调整,以达到比较合理的预测

image-20240808152000232

image-20240808152017760

下面以一个简单的网络结果VGG16模型为例

5.代码实例 vgg16 model

img

flatten 摊平

torch.flatten — PyTorch 1.10 documentation

# Example
>>> t = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])   #3个中括号,所以是3维的
>>> torch.flatten(t)  #摊平
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)  #变为1行
tensor([[1, 2, 3, 4],[5, 6, 7, 8]])
  • reshape():可以指定尺寸进行变换
  • flatten():变成1行,摊平
output = torch.flatten(imgs)
# 等价于
output = torch.reshape(imgs,(1,1,1,-1))for data in dataloader:imgs,targets = dataprint(imgs.shape)  #torch.Size([64, 3, 32, 32])output = torch.reshape(imgs,(1,1,1,-1))  # 想把图片展平print(output.shape)  # torch.Size([1, 1, 1, 196608])output = tudui(output)print(output.shape)  # torch.Size([1, 1, 1, 10])for data in dataloader:imgs,targets = dataprint(imgs.shape)  #torch.Size([64, 3, 32, 32])output = torch.flatten(imgs)   #摊平print(output.shape)   #torch.Size([196608])output = tudui(output)print(output.shape)   #torch.Size([10])

我们想实现下面这个:

image-20240808152109358

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.linear1 = Linear(196608, 10)def forward(self, input):output = self.linear1(input)return outputtudui = Tudui()
writer = SummaryWriter("logs")
step = 0for data in dataloader:imgs, targets = dataprint(imgs.shape)  # torch.Size([64, 3, 32, 32])writer.add_images("input", imgs, step)output = torch.reshape(imgs,(1,1,1,-1))  # 方法一:用reshape把图片拉平,另一种办法直接用torch.flatten(imgs)摊平# print(output.shape)  # torch.Size([1, 1, 1, 196608])# output = tudui(output)# print(output.shape)  # torch.Size([1, 1, 1, 10])#output = torch.flatten(imgs)  #方法二 摊平print(output.shape)  # torch.Size([196608])output = tudui(output)print(output.shape)  # torch.Size([10])writer.add_images("output", output, step)step = step + 1

image-20240808161700359

运行后在 terminal 里输入:

tensorboard --logdir=logs

运行结果如下:

image-20240808161529487

6.Dropout Layers(不难,自学)

Dropout — PyTorch 1.10 documentation

在训练过程中,随机把一些 input(输入的tensor数据类型)中的一些元素变为0,变为0的概率为p

目的:防止过拟合

image-20240808152145889

7.Sparse Layers(特定网络中使用,自学)

Embedding

Embedding — PyTorch 1.10 documentation

用于自然语言处理

8.Distance Functions

计算两个值之间的误差

torch.nn — PyTorch 1.13 documentation

image-20240808152343889

9. Loss Functions

loss 误差大小

torch.nn — PyTorch 1.13 documentation

image-20240808152404986

  1. pytorch提供的一些网络模型

    图片相关:torchvision torchvision.models — Torchvision 0.11.0 documentation
    分类、语义分割、目标检测、实例分割、人体关键节点识别(姿态估计)等等

    文本相关:torchtext 无
    语音相关:torchaudio torchaudio.models — Torchaudio 0.10.0 documentation

下一节:Container ——> Sequential(序列)

hvision 0.11.0 documentation
分类、语义分割、目标检测、实例分割、人体关键节点识别(姿态估计)等等

文本相关:torchtext   无
语音相关:torchaudio  torchaudio.models — Torchaudio 0.10.0 documentation

下一节:Container ——> Sequential(序列)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411346.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美发店会员系统设计解读之规格选择-SAAS本地化及未来之窗行业应用跨平台架构

一、请求产品信息 $.ajax({type:"get", //请求方式async:true, //是否异步url:"服务器",dataType:"json", //跨域json请求一定是jsonpjsonp: "cwpd_showData_dy_spec", //跨域请求的参数名,默认是callback//js…

从学习到工作,2024年不可或缺的翻译助手精选

翻译工具利用先进的机器学习和自然语言处理技术,能够迅速将一种语言的文档转换为另一种语言,极大地促进了信息的无障碍流通。接下来,我们将介绍几款功能强大、操作简便的类似deepl翻译的工具,帮助你轻松应对各种翻译需求。 第一款…

pymysql cursor使用教程

Python之PyMySQL的使用: 在python3.x中,可以使用pymysql来MySQL数据库的连接,并实现数据库的各种操作,本次博客主要介绍了pymysql的安装和使用方法。 PyMySQL的安装 一、.windows上的安装方法: 在python3.6中&…

基于SpringBoot的校园闲置物品交易管理系统

基于SpringBootVue的校园闲置物品交易管理系统【附源码文档】、前后端分离 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/Ecilpse、Navicat、Maven 系统展示 摘要 基于SpringBoot与Vue的校…

Linux驱动开发—创建总线,创建属性文件

文章目录 1.什么是BUS?1.1总线的主要概念1.2总线的操作1.3总线的实现 2.创建总线关键结构体解析2.1注册总线到系统2.2 struct bus_type *bus 解析 3.实验结果分析1. devices 目录2. drivers 目录3. drivers_autoprobe 文件4. drivers_probe 文件5. uevent 文件 4.在…

vscode远程连接服务器并根据项目配置setting.json

vscode连接好远程服务器,打开项目文件,按下快捷键:CtrlShiftP 搜索setting.json 这边可以看到不同范围的setting.json,这边以文件夹(项目)为单位,即在打开的文件夹内创建setting.json&#xff…

axure9树形元件节点的添加

树形元件 | AxureChina 在需要添加节点处右键添加->添加子节点

World of Warcraft [CLASSIC][80][Grandel] Call to Arms: Strand of the Ancients

Call to Arms: Strand of the Ancients - Quest - 魔兽世界怀旧服CTM4.34《大地的裂变》数据库_大灾变85级魔兽数据库_ctm数据库 Call to Arms: Strand of the Ancients 战斗的召唤:远古海滩 打掉最后一个门【古代圣物之厅】,人跳进去就赢了 拿【炸弹】…

SpringBoot集成kafka-监听器注解

SpringBoot集成kafka-监听器注解 1、application.yml2、生产者3、消费者4、测试类5、测试 1、application.yml #自定义配置 kafka:topic:name: helloTopicconsumer:group: helloGroup2、生产者 package com.power.producer;import com.power.model.User; import com.power.uti…

UnQLite:多语言支持的嵌入式NoSQL数据库深入解析

文章目录 1. 引言2. Key/Value 存储接口2.1 关键函数2.2 使用示例2.3 高级操作:批量文件存储 3. 游标的使用4. UnQLite-Python使用示例4. UnQLite数据库引擎架构5.1 Key/Value存储层5.2 文档存储层5.3 可插拔的存储引擎5.4 事务管理器与分页模块5.5 虚拟文件系统 6.…

游戏开发设计模式之模板方法模式

目录 模板方法模式在游戏开发中的具体应用案例是什么? 如何在不同类型的游戏(如角色扮演游戏、策略游戏等)中实现模板方法模式? 模板方法模式与其他设计模式(如观察者模式、状态模式等)相比,…

物联网平台与边缘计算平台,ThingsKit与AIoTedge

物联网平台和边缘计算平台是现代智能系统中不可或缺的组成部分,它们共同支撑着设备的连接、数据的收集和智能分析等功能。ThingsKit和AIoTedge是两个专注于不同层面的平台,它们各自具有独特的特点和优势。 ThingsKit是一个运行在云端的通用物联网平台&am…

深度学习项目实践——qq聊天机器人(transformer)(一)原理介绍

文章目录 首先第一步——QQ是如何实现实时聊天数据传输过程1. 用户发送消息的开始2. 数据封装与加密3. 建立连接:WebSocket协议的应用4. 消息的传输过程5. 接收者获取消息6. 双向通信与实时性保障7. 保持连接与断线重连 第二步——聊天机器人是如何来接管QQ账号的组…

论文阅读笔记:RepViT: Revisiting Mobile CNN From Vit Perspective

文章目录 RepViT: Revisiting Mobile CNN From Vit Perspective动机现状问题 贡献实现Block设置独立的token融合器和通道融合器减少膨胀并增加宽度 宏观设计stem的早期卷积简单分类器整体阶段比率 微观设计内核大小选择Squeeze-and-excitation层放置网络架构 实验ImageNet-1K上…

Jmeter(十四)Jmeter分布式部署测试

单个接口测试,我们使用谷歌的插件postman 多个接口测试,我们使用Jmeter进行测试 一、使用工具测试 1、使用Jmeter对接口测试 首先我们说一下为什么用Posman测试后我们还要用Jmeter做接口测试,在用posman测试时候会发现的是一个接口一个接…

存储架构模式之复制架构

存储类问题处理框架图 故障:机器挂掉 灾难:自然灾害 多活:技术复杂度高、成本高 高可用的关键指标 stag1是正常状态,系统和业务都是正常的 stag2是故障状态,系统和业务都是异常的 stag3是系统恢复正常&#xff0c…

docker maven 构建的找不到 ClassNotFoundException

Exception in thread "main" java.lang.ClassNotFoundException: com.baimeidashu.springbootdemo1.Springbootdemo1Application 我用idea 自带的 maven 构建的jiar包没,没问题, 但是用 docker 镜像 maven:3.6.0-jdk-8-alpine 构建的就出问…

Oracle发邮件时SMTP服务器配置方法与步骤?

Oracle发邮件功能如何配置?如何优化Oracle发信性能? 为了实现自动化报告和通知,Oracle发邮件功能变得尤为重要。通过配置SMTP服务器,Oracle可以轻松地发送电子邮件。AokSend将详细介绍如何配置Oracle发邮件时的SMTP服务器&#x…

收藏夹里的“小网站”被误报违规不让上怎么办?如何将Chrome和Edge安装到 D 盘(含用户数据),重装系统也不会丢失收藏夹和密码?

当你用国产浏览器访问网站的时候,有时候会显示这个: 如果确实是违规网站,不让访问也没什么,但是很多都是误报啊,你这样直接来个大红横幅,还让人活不? 那遇到这种误报应当怎么办呢?有…

爆火的《黑神话:悟空》对LabVIEW软件开发的启示

近期,《黑神话:悟空》在全球范围内爆火,引发了游戏行业和玩家群体的广泛关注。作为一款由中国开发团队Game Science历时多年打造的动作角色扮演游戏,它的成功不仅源于卓越的技术创新和对中国传统文化的深度挖掘,更在于…