深度学习基础之参数量(3)

一般的CNN网络的参数量估计代码

class ResidualBlock(nn.Module):def __init__(self, in_planes, planes, norm_fn='group', stride=1):super(ResidualBlock, self).__init__()print(in_planes, planes, norm_fn, stride)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)self.relu = nn.ReLU(inplace=True)num_groups = planes // 8if norm_fn == 'group':self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)if not stride == 1:self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)elif norm_fn == 'batch':self.norm1 = nn.BatchNorm2d(planes)self.norm2 = nn.BatchNorm2d(planes)if not stride == 1:self.norm3 = nn.BatchNorm2d(planes)elif norm_fn == 'instance':self.norm1 = nn.InstanceNorm2d(planes)self.norm2 = nn.InstanceNorm2d(planes)if not stride == 1:self.norm3 = nn.InstanceNorm2d(planes)elif norm_fn == 'none':self.norm1 = nn.Sequential()self.norm2 = nn.Sequential()if not stride == 1:self.norm3 = nn.Sequential()if stride == 1:self.downsample = Noneelse:self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)def forward(self, x):print(x.shape)#exit()y = xy = self.relu(self.norm1(self.conv1(y)))y = self.relu(self.norm2(self.conv2(y)))if self.downsample is not None:x = self.downsample(x)return self.relu(x + y)R=ResidualBlock(384, 384, norm_fn='instance', stride=1)
summary(R.to("cuda" if torch.cuda.is_available() else "cpu"), (384, 32, 32))

transformer结构的参数量的估计结果

import torch
import torch.nn as nn
from thop import profile
from torchsummary import summary# 定义一个简单的Transformer模型
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.embedding = nn.Embedding(input_dim, hidden_dim)self.transformer_layers = nn.Transformer(d_model=hidden_dim,nhead=num_heads,num_encoder_layers=num_layers,num_decoder_layers=num_layers)self.fc = nn.Linear(hidden_dim, input_dim)def forward(self, src, tgt):src = self.embedding(src)tgt = self.embedding(tgt)output = self.transformer_layers(src, tgt)output = self.fc(output)return output# 创建Transformer模型实例
model2 = Transformer(input_dim=512, hidden_dim=512, num_heads=8, num_layers=6)# 使用thop进行FLOPS估算
flops, params = profile(model2, inputs=(torch.randint(0, 512, (128,)), torch.randint(0, 512, (64,))))
print(f"FLOPS: {flops / 1e9} G FLOPS")  # 打印FLOPS,以十亿FLOPS(GFLOPS)为单位# 计算参数量并打印
num_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
print(f"Total number of trainable parameters: {num_params}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/149199.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL——使用mysqldump备份与恢复数据

目录 1.mysqldump简介 2.mysqldump备份数据 2.1 备份所有数据库 2.2 备份一个/多个数据库 2.3 备份指定库中的指定表 3.mysqldump恢复数据 3.1 恢复数据库 3.2 恢复数据表 1.mysqldump简介 mysqldump命令可以将数据库中指定或所有的库、表导出为SQL脚本。表的结构和表中…

互联网Java工程师面试题·Elasticsearch 篇·第一弹

目录 1、elasticsearch 了解多少,说说你们公司 es 的集群架构,索引数据大小,分片有多少,以及一些调优手段 。 1.1 设计阶段调优 1.2 写入调优 1.3 查询调优 1.4 其他调优 2、elasticsearch 的倒排索引是什么 3、elastic…

使用Pytorch从零实现Vision Transformer

在这篇文章中,我们将基于Pytorch框架从头实现Vision Transformer模型,并附录完整代码。 Vision Transformer(ViT)是一种基于Transformer架构的深度学习模型,用于处理计算机视觉任务。它将图像分割成小的图像块(patches),然后使用Transformer编码器来处理这些图像块。V…

【单片机】16-LCD1602和12864和LCD9648显示器

1.LCD显示器相关背景 1.LCD简介 (1)显示器,常见显示器:电视,电脑 (2)LCD(Liquid Crystal Display),液晶显示器,原理介绍 (3&#xff…

十天学完基础数据结构-第九天(堆(Heap))

堆的基本概念 堆是一种特殊的树形数据结构,通常用于实现优先级队列。堆具有以下两个主要特点: 父节点的值始终大于或等于其子节点的值(最大堆),或者父节点的值始终小于或等于其子节点的值(最小堆&#xff…

【2023年11月第四版教材】第18章《项目绩效域》(合集篇)

第18章《项目绩效域》(合集篇) 1 章节内容2 干系人绩效域2.1 绩效要点2.2 执行效果检查2.3 与其他绩效域的相互作用 3 团队绩效域3.1 绩效要点3.2 与其他绩效域的相互作用3.3 执行效果检查3.4 开发方法和生命周期绩效域 4 绩效要点4.1 与其他绩效域的相互…

2023/10/4 QT实现TCP服务器客户端搭建

服务器端&#xff1a; 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTcpServer> #include <QTcpSocket> #include <QList> #include <QMessageBox> #include <QDebug>QT_BEGIN_NAMESPACE namespace Ui { cla…

十天学完基础数据结构-第八天(哈希表(Hash Table))

哈希表的基本概念 哈希表是一种数据结构&#xff0c;用于存储键值对。它的核心思想是将键通过哈希函数转化为索引&#xff0c;然后将值存储在该索引位置的数据结构中。 哈希函数的作用 哈希函数是哈希表的关键部分。它将输入&#xff08;键&#xff09;映射到哈希表的索引位…

Ubuntu使用cmake和vscode开发自己的项目,引用自己的头文件和openCV

创建文件夹 mkdir my_proj 继续创建include 和 src文件夹&#xff0c;形成如下的目录结构 用vscode打开项目 创建add.h #ifndef ADD_H #define ADD_Hint add(int numA, int numB);#endif add.cpp #include "add.h"int add(int numA, int numB) {return numA nu…

实战型开发2/3--架构设计

这里谈及在代码设计阶段以及重构阶段要考虑的架构方面问题&#xff0c;可以说是开发过程中的中层阶段&#xff1b; 主要是将 < the art of unix programming>< clean architecture>< the pragmatic programmer>< design patterns> 等几本书结合实践做…

[NSSRound#1 Basic]sql_by_sql - 二次注入+布尔盲注||sqlmap

进入注册界面后   假设sql&#xff1a;update user set password ‘’ where username ‘’ and password ‘’     此时如果我们注册的用户名是admin’–、admin’#、admin’–的话   update user set password ‘123’ where username ‘admin’#’ and passwor…

[架构之路-231]:计算机硬件与体系结构 - 性能评估汇总,性能优化加速比

目录 一、计算机体系结构 二、计算机性能评估 2.1 分类方法1 2.2 分类方法2 三、常见的专项性能测试工具 3.1 浮点运算性能&#xff08;FLOPS&#xff09; 3.2 综合理论性能法 3.3 历史基准测试&#xff08;跑分软件&#xff09;&#xff1a;通过运行典型的综合性的程序…

毕设-原创医疗预约挂号平台分享

医疗预约挂号平台 不是尚医通项目&#xff0c;先看项目质量&#xff08;有源码论文&#xff09; 项目链接&#xff1a;医疗预约挂号平台git地址 演示视频&#xff1a;医疗预约挂号平台 功能结构图 登录注册模块&#xff1a;该模块具体分为登录和注册两个功能&#xff0c;这些…

想要精通算法和SQL的成长之路 - 最长连续序列

想要精通算法和SQL的成长之路 - 最长连续序列 前言一. 最长连续序列1.1 并查集数据结构创建1.2 find 查找1.3 union 合并操作1.4 最终代码 前言 想要精通算法和SQL的成长之路 - 系列导航 并查集的运用 一. 最长连续序列 原题链接 这个题目&#xff0c;如何使用并查集是一个小难…

R语言教程课后习题答案(持续更新中~~)

R语言教程网址如下 https://www.math.pku.edu.cn/teachers/lidf/docs/Rbook/html/_Rbook/index.html 目录 source()函数可以运行保存在一个文本文件中的源程序 R向量下标和子集 数值型向量及其运算 日期功能 R因子类型 source()函数可以运行保存在一个文本文件中的源程序…

【C语言】动态通讯录(超详细)

通讯录是一个可以很好锻炼我们对结构体的使用&#xff0c;加深对结构体的理解&#xff0c;在为以后学习数据结构打下结实的基础 这里我们想设计一个有添加联系人&#xff0c;删除联系人&#xff0c;查找联系人&#xff0c;修改联系人&#xff0c;展示联系人&#xff0c;排序这几…

快速了解Spring Cache

SpringCache是一个框架&#xff0c;实现了基于注解的缓存功能&#xff0c;只需要简单的加一个注解&#xff0c;就可以实现缓存功能。 SpringCache提供了一层抽象&#xff0c;底层可以切换不同的缓存实现。例如&#xff1a; EHChche Redis Caffeine 常用注解&#xff1a; Enabl…

Vue中如何进行分布式路由配置与管理

Vue中的分布式路由配置与管理 随着现代Web应用程序的复杂性不断增加&#xff0c;分布式路由配置和管理成为了一个重要的主题。Vue.js作为一种流行的前端框架&#xff0c;提供了多种方法来管理Vue应用程序的路由。本文将深入探讨在Vue中如何进行分布式路由配置与管理&#xff0…

全志ARM926 Melis2.0系统的开发指引⑧

全志ARM926 Melis2.0系统的开发指引⑧ 编写目的12.5. 应用程序编写12.5.1. 简单应用编写12.5.1.1. 注册应用12.5.1.2. 创建管理窗口12.5.1.3. 实现管理窗口消息处理回调函数12.5.1.4. 创建图层12.5.1.5. 创建 framewin12.5.1.6. 实现 framewin 消息处理回调函数 -. 全志相关工具…

【BBC新闻文章分类】使用 TF 2.0和 LSTM 的文本分类

一、说明 NLP上的许多创新是如何将上下文添加到词向量中。常见的方法之一是使用递归神经网络