PyTorch使用教程(10)-torchinfo.summary网络结构可视化详细说明

1、基本介绍

torchinfo是一个为PyTorch用户量身定做的开源工具,其核心功能之一是summary函数。这个函数旨在简化模型的开发与调试流程,让模型架构一目了然。通过torchinfosummary函数,用户可以快速获取模型的详细结构和统计信息,如模型的层次结构、输入/输出维度、参数数量、多加操作(Mult-Adds)等关键信息。

2、安装

首先,你需要安装torchinfo库。可以通过pip进行安装:

pip install torchinfo

3、导入

安装完成后,需要在你的Python脚本中导入torchinfo模块:

from torchinfo import summary

4、函数原型定义

torchinfo的summary函数原型定义如下:

def summary(model: nn.Module, input_data: torch.Tensor | tuple[torch.Tensor, ...] | tuple[int, ...] | None = None, batch_dim: int = 0, col_widths: tuple[int, ...] | None = None, col_names: tuple[str, ...] | None = None, device: str | torch.device | None = None, dtypes: tuple[torch.dtype, ...] | None = None, verbose: int = 1, **kwargs)

参数说明

  • model: 要分析的PyTorch模型,必须是torch.nn.Module的实例。
  • input_data: 用于模型前向传播的输入数据。它可以是一个torch.Tensor对象,也可以是一个包含多个输入张量的元组。此外,还可以提供一个表示输入尺寸的元组,例如(batch_size, channels, height, width)。
  • batch_dim: 指定输入张量中哪个维度是批量大小(batch size)。默认为0。
  • col_widths: 指定输出列宽的元组。如果未指定,则自动计算列宽以适应输出。
  • col_names: 指定输出列名的元组。如果未指定,则使用默认列名。
  • device: 指定模型运行的设备(如’cpu’或’cuda’)。如果未指定,则自动选择。
  • dtypes: 指定输入张量的数据类型。如果未指定,则自动推断。
  • verbose: 控制输出信息的详细程度。默认为1,表示输出基本信息。设置为2或更高可以获得更详细的输出。
  • kwargs: 其他关键字参数,可以传递给模型的前向传播函数。

5、使用方法

下面通过几个示例来展示如何使用torchinfo的summary函数。
5.1 使用预定义模型
首先,我们使用PyTorch预定义的模型(如torchvision.models.resnet50)来展示如何使用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary
# 定义模型
model = models.resnet18(pretrained=False)# 使用summary函数打印模型概况
summary(model, input_size=(1, 3, 224, 224))

在这个示例中,我们加载了一个未预训练的ResNet50模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, channels, height, width)。
在这里插入图片描述

5.2 使用自定义模型
接下来,我们定义一个简单的自定义模型,并使用summary函数打印其概况。

import torch
import torch.nn as nn
from torchinfo import summary# 定义一个简单的两层全连接神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)self.relu = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 创建模型实例
model = SimpleModel()# 使用summary函数打印模型概况
summary(model, input_size=(100,))

在这个示例中,我们定义了一个简单的两层全连接神经网络模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, features)。由于我们的模型是一个全连接层,所以我们只指定了特征数量。
在这里插入图片描述

5.3 使用自定义输入数据

有时候,可能想要使用实际的输入数据来查看模型的概况。下面是一个示例,展示了如何使用自定义输入数据来调用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary# 定义模型
model = models.resnet50(pretrained=False)# 创建自定义输入数据
input_data = torch.randn(1, 3, 224, 224)  # batch_size=1, channels=3, height=224, width=224# 使用summary函数打印模型概况
summary(model, input_data=input_data)

在这个示例中,我们创建了一个形状为(1, 3, 224, 224)的随机张量作为输入数据,并使用summary函数打印了模型的概况。注意,这里我们使用input_data参数而不是input_size参数来指定输入数据。

5.4 调整输出格式
torchinfo允许通过col_widths和col_names参数来调整输出的格式。下面是一个示例,展示了如何自定义输出列宽和列名。

import torch
import torchvision.models as models
from torchinfo import summary# 定义模型
model = models.resnet50(pretrained=False)# 使用summary函数打印模型概况,并自定义输出列宽和列名
summary(model, input_size=(3, 224, 224), col_widths=(30, 30, 20, 20),col_names=('input_size', 'output_size', 'kernel_size', 'num_params'))

在这个示例中,我们自定义了输出列宽和列名。col_widths参数指定了每列的宽度(以字符为单位),而col_names参数指定了每列的列名。这样,就可以根据需要来调整输出的格式了。

6、小结

torchinfo的summary函数是一个强大的工具,可以方便地查看PyTorch模型的结构和参数数量。通过本文的介绍,应该已经掌握了如何使用summary函数来打印模型的概况。无论使用预定义模型还是自定义模型,无论是使用输入尺寸还是自定义输入数据,torchinfo都能提供详细而清晰的输出信息。希望这篇文章能对你有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/4972.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【22】Word:小李-高新技术企业政策❗

目录 题目​ NO1.2 NO3 NO4 NO5.6 NO7.8 NO9.10 若文章中存在删除空白行等要求,可以到最后来完成。注意最后一定要检查此部分!注意:大多是和事例一样即可,不用一摸一样,但也不要差太多。 题目 NO1.2 F12Fn&a…

TDengine 做 Apache SuperSet 数据源

‌Apache Superset‌ 是一个现代的企业级商业智能(BI)Web 应用程序,主要用于数据探索和可视化。它由 Apache 软件基金会支持,是一个开源项目,它拥有活跃的社区和丰富的生态系统。Apache Superset 提供了直观的用户界面…

Python----Python高级(文件操作open,os模块对于文件操作,shutil模块 )

一、文件处理 1.1、文件操作的重要性和应用场景 1.1.1、重要性 数据持久化: 文件是存储数据的一种非常基本且重要的方式。通过文件,我们可 以将程序运行时产生的数据永久保存下来,以便将来使用。 跨平台兼容性: 文件是一种通用…

STM32单片机:GPIO模式

GPIO有八种工作模式,分别是推挽输出、开漏输出、复合推挽输出、复合开漏输出、模拟输入、上拉输入、下拉输入、浮空输入。 在了解这些之前,我们先来看一下GPIO口内部的结构: I/O引脚一般工作电压为3.3V,在它边的两个二极管起到保…

[Qt]事件-鼠标事件、键盘事件、定时器事件、窗口改变事件、事件分发器与事件过滤器

目录 前言:Qt与操作系统的关系 一、Qt事件 1.事件介绍 2.事件的表现形式 常见的Qt事件: 常见的事件描述: 3.事件的处理方式 处理鼠标进入和离开事件案例 控件添加到对象树底层原理 二、鼠标事件 1.鼠标按下和释放事件(单击&#x…

Linux下MySQL的简单使用

Linux下MySQL的简单使用 导语MySQL安装与配置 MySQL安装密码设置 MySQL管理 命令 myisamchkmysql其他 常见操作 C语言访问MYSQL 连接例程错误处理使用SQL 总结参考文献 导语 这一章是MySQL的使用,一些常用的MySQL语句属于本科阶段内容,然后是C语言和M…

ElasticSearch索引别名的应用

个人博客:无奈何杨(wnhyang) 个人语雀:wnhyang 共享语雀:在线知识共享 Github:wnhyang - Overview Elasticsearch 索引别名是一种极为灵活且强大的功能,它允许用户为一个或多个索引创建逻辑上…

火狐浏览器Firefox一些配置

没想到还会开这个…都是Ubuntu的错 一些个人习惯吧 标签页设置 常规-标签页 1.按最近使用顺序切换标签页 2.打开新标签而非新窗口(讨厌好多窗口) 3.打开新链接不直接切换过去(很打断思路诶) 4.关闭多个标签页时不向我确认 启动…

数据结构-队列

目录 前言一、队列及其抽象数据类型1.1 队列的基本概念1.2 队列的抽象数据类型 二、队列的实现2.1 顺序表示2.1.1 结构定义2.1.2 基本操作的实现 2.2 链式表示2.2.1 结构定义2.2.2 基本操作的实现 总结 前言 本篇文章介绍队列的基础知识,包括队列的抽象数据类型以及…

STM32-串口-UART-Asynchronous

一,发送数据 #include "stdio.h" uint8_t hello[]"Hello,blocking\r\n"; HAL_UART_Transmit(&huart1,hello,sizeof(hello),500); 二,MicroLIB-printf(" hello\r\n") #include "stdio.h" #ifdef __GNUC…

深度学习 DAY2:Transformer(一部分)

前言 Transformer是一种用于自然语言处理(NLP)和其他序列到序列(sequence-to-sequence)任务的深度学习模型架构,它在2017年由Vaswani等人首次提出。Transformer架构引入了自注意力机制(self-attention mech…

《目标检测数据集下载地址》

一、引言 在计算机视觉的广袤领域中,目标检测宛如一颗璀璨的明星,占据着举足轻重的地位。它宛如赋予计算机一双锐利的 “眼睛”,使其能够精准识别图像或视频中的各类目标,并确定其位置,以边界框的形式清晰呈现。这项技…

题解 CodeForces 1037D Valid BFS? 三种解法 C++

题目传送门 Problem - 1037D - Codeforceshttps://codeforces.com/problemset/problem/1037/Dhttps://codeforces.com/problemset/problem/1037/Dhttps://codeforces.com/problemset/problem/1037/Dhttps://codeforces.com/problemset/problem/1037/Dhttps://codeforces.com/p…

2024微短剧行业生态洞察报告汇总PDF洞察(附原数据表)

原文链接: https://tecdat.cn/?p39072 本报告合集洞察从多个维度全面解读微短剧行业。在行业发展层面,市场规模与用户规模双增长,创造大量高收入就业岗位并带动产业链升级。内容创作上,精品化、品牌化趋势凸显,题材走…

HTML<img>标签

例子 如何插入图片&#xff1a; <img src"img_girl.jpg" alt"Girl in a jacket" width"500" height"600"> 下面有更多“自己尝试”的示例。 定义和用法 该<img>标签用于在 HTML 页面中嵌入图像。 从技术上讲&#x…

故障诊断 | BWO白鲸算法优化KELM故障诊断(Matlab)

目录 效果一览文章概述BWO白鲸算法优化KELM故障诊断一、引言1.1、研究背景及意义1.2、故障诊断技术的现状1.3、研究目的与内容二、KELM基本理论2.1、KELM模型简介2.2、核函数的选择2.3、KELM在故障诊断中的应用三、BWO白鲸优化算法3.1、BWO算法基本原理3.2、BWO算法的特点3.3、…

apisix的authz-casbin

目录 1、apisix的auth-casbin官方介绍 2、casbin介绍和使用 2.1基本知识&#xff1a; 2.2使用例子 3、配置插件 4、postman调用 5、auth-casbin的坑 1、apisix的auth-casbin官方介绍 authz-casbin | Apache APISIX -- Cloud-Native API Gateway 2、casbin介绍和使用 c…

基于python+Django+mysql鲜花水果销售商城网站系统设计与实现

博主介绍&#xff1a;黄菊华老师《Vue.js入门与商城开发实战》《微信小程序商城开发》图书作者&#xff0c;CSDN博客专家&#xff0c;在线教育专家&#xff0c;CSDN钻石讲师&#xff1b;专注大学生毕业设计教育、辅导。 所有项目都配有从入门到精通的基础知识视频课程&#xff…

【大模型】ChatGPT 高效处理图片技巧使用详解

目录 一、前言 二、ChatGPT 4 图片处理介绍 2.1 ChatGPT 4 图片处理概述 2.1.1 图像识别与分类 2.1.2 图像搜索 2.1.3 图像生成 2.1.4 多模态理解 2.1.5 细粒度图像识别 2.1.6 生成式图像任务处理 2.1.7 图像与文本互动 2.2 ChatGPT 4 图片处理应用场景 三、文生图操…

【0x0052】HCI_Write_Extended_Inquiry_Response命令详解

目录 一、命令概述 二、命令格式及参数 2.1. HCI_Write_Extended_Inquiry_Response命令格式 2.2. FEC_Required 2.3. Extended_Inquiry_Response 三、生成事件及参数 3.1. HCI_Command_Complete 事件 3.2. Status 四、命令执行流程 4.1. 命令准备阶段(主机端) 4.2…