机器学习复习(3)——分类神经网络与drop out

完整的神经网络

以分类任务为例,神经网络一般包括backbone和head(计算机视觉领域)

下面的BasicBlock不是一个标准的backbone,标准的应该是复杂的CNNs构成的

Classfier是一个标准的head,其中output_dim表示分类类别,一般写作num_classes

import torch  # 导入 torch 库
import torch.nn as nn  # 导入 torch 的神经网络模块
import torch.nn.functional as F  # 导入 torch 的函数式接口# 定义一个基础的神经网络模块
class BasicBlock(nn.Module):  # 继承自 torch 的 Module 类def __init__(self, input_dim, output_dim):super(BasicBlock, self).__init__()  # 初始化父类# 构建一个序列模块,包含一个线性层和一个 ReLU 激活函数self.block = nn.Sequential(
# 线性层,输入维度为 input_dim,输出维度为 output_dimnn.Linear(input_dim, output_dim),  nn.ReLU(),  # ReLU 激活函数)def forward(self, x):x = self.block(x)  # 将输入数据 x 通过定义的序列模块return x  # 返回模块的输出# 定义一个分类器神经网络
class Classifier(nn.Module):  # 继承自 torch 的 Module 类def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):super(Classifier, self).__init__()  # 初始化父类# 构建一个序列模块,包含若干个 BasicBlock 和一个线性输出层self.fc = nn.Sequential(
# 第一个 BasicBlock,将输入维度转换为隐藏层维度BasicBlock(input_dim, hidden_dim),  
# 根据 hidden_layers 数量添加多个 BasicBlock*[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],  
# 线性输出层,将隐藏层维度转换为输出维度nn.Linear(hidden_dim, output_dim)  )def forward(self, x):x = self.fc(x)  # 将输入数据 x 通过定义的序列模块return x  # 返回模块的输出

对 *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)]的一个补充解释,“*”代表解压列表,例如A=[a,b,c],那么f(*A)=f(a,b,c)

在这里的具体意义是“便于控制隐藏层数量”,而其中的"_"代表不希望在循环中使用变量,这是一种通用的惯例,表明循环的目的不是对每个元素进行操作,而是为了重复某个操作特定次数。如果hidden_layers=3,这里的等价含义就是BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),——连续出现三次

dropout

Dropout层在神经网络层当中是用来干什么的呢?它是一种可以用于减少神经网络过拟合的结构。

如上图我们定义的网络,一共有四个输入x_i,一个输出y。Dropout则是在每一个batch的训练当中随机减掉一些神经元,而作为编程者,我们可以设定每一层dropout(将神经元去除的的多少)的概率,在设定之后,就可以得到第一个batch进行训练的结果:  

从上图我们可以看到一些神经元之间断开了连接,因此它们被dropout了!dropout顾名思义就是被拿掉的意思,正因为我们在神经网络当中拿掉了一些神经元,所以才叫做dropout层。
在进行第一个batch的训练时,有以下步骤:

  • 设定每一个神经网络层进行dropout的概率
  • 根据相应的概率拿掉一部分的神经元,然后开始训练,更新没有被拿掉神经元以及权重的参数,将其保留
  • 参数全部更新之后,又重新根据相应的概率拿掉一部分神经元,然后开始训练,如果新用于训练的神经元已经在第一次当中训练过,那么我们继续更新它的参数。而第二次被剪掉的神经元,同时第一次已经更新过参数的,我们保留它的权重,不做修改,直到第n次batch进行dropout时没有将其删除。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/250504.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW传感器通用实验平台

LabVIEW传感器通用实验平台 介绍了基于LabVIEW的传感器实验平台的开发。该平台利用LabVIEW图形化编程语言和多参量数据采集卡,提供了一个交互性好、可扩充性强、使用灵活方便的传感器技术实验环境。 系统由硬件和软件两部分组成。硬件部分主要包括多通道数据采集卡…

go grpc高级用法

文章目录 错误处理常规用法进阶用法原理 多路复用元数据负载均衡压缩数据 错误处理 gRPC 一般不在 message 中定义错误。毕竟每个 gRPC 服务本身就带一个 error 的返回值,这是用来传输错误的专用通道。gRPC 中所有的错误返回都应该是 nil 或者 由 status.Status 产…

如何在 Golang 中使用 crypto/ed25519 进行数字签名和验证

如何在 Golang 中使用 crypto/ed25519 进行数字签名和验证 引言crypto/ed25519 算法简介环境搭建和准备工作生成密钥对进行数字签名 验证签名实际应用场景案例总结 引言 在当今数字化时代,网络安全显得尤为重要。无论是在网上进行交易、签署合同,还是发…

笔记---容斥原理

AcWing,890.能被整除的数 给定一个整数 n n n 和 m m m 个不同的质数 p 1 , p 2 , … , p m p_{1},p_{2},…,p_{m} p1​,p2​,…,pm​。 请你求出 1 ∼ n 1∼n 1∼n 中能被 p 1 , p 2 , … , p m p_{1},p_{2},…,p_{m} p1​,p2​,…,pm​ 中的至少一个数整除的整数有多少…

element-ui link 组件源码分享

link 组件的 api 涉及的内容不是很多,源码部分的内容也相对较简单,下面从以下这三个方面来讲解: 一、组件结构 1.1 组件结构如下图: 二、组件属性 2.1 组件主要有 type、underline、disabled、href、icon 这些属性,…

Golang `crypto/hmac` 实战指南:代码示例与最佳实践

Golang crypto/hmac 实战指南:代码示例与最佳实践 引言HMAC 的基础知识1. HMAC 的工作原理2. HMAC 的应用场景 Golang crypto/hmac 库概览1. 导入和基本用法2. HMAC 的生成和验证3. crypto/hmac 的特性 实战代码示例示例 1: 基本的 HMAC 生成示例 2: 验证消息完整性…

C语言:内存函数(memcpy memmove memset memcmp使用)

和黛玉学编程呀------------- 后续更新的节奏就快啦 memcpy使用和模拟实现 使用 void * memcpy ( void * destination, const void * source, size_t num ) 1.函数memcpy从source的位置开始向后复制num个字节的数据到destination指向的内存位置。 2.这个函数在遇到 \0 的时候…

STM32 有源蜂鸣器

模块介绍: 结构:有源蜂鸣器通常由一个振膜和一个驱动电路组成。振膜是负责产生声音的部分,而驱动电路则负责控制振荡频率和幅度。 工作原理:有源蜂鸣器的驱动电路会向振膜施加电压,使其振动产生声音。驱动电路可以根据输入信号的…

阿里云a10GPU,centos7,cuda11.2环境配置

Anaconda3-2022.05-Linux-x86_64.sh gcc升级 centos7升级gcc至8.2_centos7 yum gcc8.2.0-CSDN博客 paddlepaddle python -m pip install paddlepaddle-gpu2.5.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 报错 ImportError: libssl.so…

基于Springboot的高校心理教育辅导设计与实现(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的高校心理教育辅导设计与实现(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,…

flask基于django大数据的证券股票分析系统python可视化大屏

证券分析系统采用B/S架构,数据库是MySQL。网站的搭建与开发采用了先进的Python进行编写,使用了Django框架。该系统从两个对象:由管理员和用户来对系统进行设计构建。主要功能包括:个人信息修改,对股票信息、股票买入、…

深度学习——pycharm远程连接

目录 远程环境配置本地环境配置(注意看假设!!!这是很多博客里没写的)步骤1步骤2步骤2.1 配置Connection步骤2.2 配置Mappings 步骤3 配置本地项目的远程解释器技巧1 pycharm中远程终端连接技巧2 远程目录技巧3 上传代码文件技巧4 …

【无标题】yarn报错 “https://registry.npm.taobao.org/...: certificate has expired“如何处理

前言 今天在jenkins打包项目时yarn打包报错,查看log发现npm淘宝镜像报错 原因 在 1 月 22 日,淘宝原镜像域名(registry.npm.taobao.org)的 HTTPS 证书正式到期。如果想要继续使用,需要将 npm 源切换到新的源&#…

【LVGL环境搭建】

LVGL环境搭建 win模拟器环境搭建一.二.三.四.五. Ubuntu模拟器环境搭建一. 前置准备二. 下载LVGL Source code:三. 安装sdl2:四. 开启VScode执行五. 安装扩展套件六. 按F5执行七. 执行结果 win模拟器环境搭建 一. 二. 三. 四. 五. Ubuntu模拟器环境…

深入理解指针(3)

⽬录 1. 字符指针变量 2. 数组指针变量 3. ⼆维数组传参的本质 4. 函数指针变量 5. 函数指针数组 6. 转移表 1. 字符指针变量 在指针的类型中我们知道有⼀种指针类型为字符指针 char* ; ⼀般使⽤: int main() {char ch w;char *pc &ch;*pc w;return 0; } 还有…

Unity | YooAssetV2.1.0 + HybridCLR热更新

目录 一、项目更改 二、使用YooAsset热更 1.资源配置 2.资源构建 3.将两个文件夹下的资源上传CDN服务器 4.修改代码 5.运行效果 本文记录利用YooAssetHybridCLR来进行资源和dll的更新。YooAsset使用的是新版V2.1.0。相比于旧版,dll(原生文件)和资源要建两个p…

opencv0014 索贝尔(sobel)算子

前面学习的滤波器主要是用来模糊图像,今天一起来了解关于边缘识别的滤波吧!嘿嘿 边缘 边缘是像素值发生跃迁的位置,是图像的显著特征之一,在图像特征提取,对象检测,模式识别等方面都有重要的作用。 人眼如…

【课程作业_01】国科大2023模式识别与机器学习实践作业

国科大2023模式识别与机器学习实践作业 作业内容 从四类方法中选三类方法,从选定的每类方法中 ,各选一种具体的方法,从给定的数据集中选一 个数据集(MNIST,CIFAR-10,电信用户流失数据集 )对这…

SpringBoot+Redis如何实现用户输入错误密码后限制登录(含源码)

点击下载《SpringBootRedis如何实现用户输入错误密码后限制登录(含源码)》 1. 引言 在当今的网络环境中,保障用户账户的安全性是非常重要的。为了防止暴力破解和恶意攻击,我们需要在用户尝试登录失败一定次数后限制其登录。这不…

51单片机学习笔记 --步进电机驱动说明

文章目录 工作原理代码编写驱动方式全步进驱动半步进驱动微步进驱动 工作原理 工作原理简要说明,和单片机一起配合使用的步进电机多为28BYJ28 五线四相步进电机,配合ULN2003驱动板进行控制,如图所示,对于扭矩、精度要求较高的还有…