pytorch实现胶囊网络(capsulenet)

胶囊网络在hinton刚提出来的时候小热过一段时间,之后热度并没有维持多久。vision transformer之后基本少有人问津了。不过这个模型思路挺独特的,值得研究一下。

这个模型的提出是为了解决CNN模型学习到的特征之间没有空间上的关系,从而对于各种变换不鲁棒的缺点。

模型的整体思路如下:

1,胶囊:

抛开论文里花哨的描述,胶囊其实就是特征图上比点更大的单元,本质上我觉得类似transformer的patch。当然也有一定的差别,因为后续要用动态路由更新胶囊,所以胶囊必须要是向量,而不是标量。

2,动态路由:

由于pooling会导致信息丢失,作者使用动态路由来连接两个胶囊层,并更新胶囊。

同时,动态路由也能建立不同层胶囊(特征)在空间上的相对关系。

由于胶囊其实是向量,动态路由算法会根据这些向量的相似性(点积)和一致性(加权)来决定信息传递的路径。

3,整体结构:

1)卷积层

2)PrimaryCaps层:这层的作用就是把卷积特征转变成胶囊的形式

3)DigitCaps层:用动态路由迭代生成高层的胶囊。

4)解码器

4,loss

胶囊网络的损失函数主要由两部分组成:间隔损失(Margin Loss)和重构损失。

在计算间隔损失时,会使用一个阈值(通常设置为0.9和0.1)来区分正样本和负样本。如果某一类的胶囊输出向量的模长大于阈值m+(正样本阈值,例如0.9),则认为该类存在,并将其视为正样本;反之,如果输出向量的模长小于阈值m-(负样本阈值,例如0.1),则认为该类不存在,将其视为负样本。

重构损失的计算通常基于原始输入数据与重构数据之间的差异,例如使用均方误差(MSE)来衡量这种差异。

如果站在2024年的如今再来看当初的设计,其实胶囊的思路还是很像后来的transformer的,有点殊途同归的感觉。


pytorch实现:

1,实现初始胶囊

首先是会用到的压缩函数,压缩函数的作用是将向量的长度压缩到0和1之间,同时保留向量的方向不变。

公式:

def squash(inputs, axis=-1):norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)scale = norm**2 / (1 + norm**2 + 1e-8) / (norm + 1e-8)return scale * inputs

初始胶囊,这一层的作用是将卷积特征转换为胶囊的形式。

class PrimaryCapsule(nn.Module):def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):super(PrimaryCapsule, self).__init__()self.dim_caps = dim_capsself.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x):outputs = self.conv2d(x)outputs = outputs.reshape(x.size(0), -1, self.dim_caps)return squash(outputs)

2,实现胶囊层

路由算法

这个伪代码初看起来挺乱的,我翻译成人话如下:

首先,每一次迭代由两层胶囊层做点积后再通过softmax计算出耦合系数c。

耦合系数和下层胶囊的预测计算加权和,这是个投票的过程。

再通过压缩函数,就得到了本层的胶囊v。

因为这是个迭代的过程,需要不断更新耦合系数C。

新的耦合系数由两层胶囊之间的相似度决定。


具体实现中,会对低层胶囊先做一个变换,也就是下面代码里的weight。这个权重矩阵代表的是对下层胶囊的变化,变换之后的结果Ui|j用论文里的话说叫做“prediction vectors”。

胶囊层代码:

class DenseCapsule(nn.Module):def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):super(DenseCapsule, self).__init__()self.in_num_caps = in_num_capsself.in_dim_caps = in_dim_capsself.out_num_caps = out_num_capsself.out_dim_caps = out_dim_capsself.routings = routings #路由的迭代次数#初始化self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))def forward(self, x):u_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)#从当前计算图中分离出x_hat,这样在后续的反向传播中不会计算其梯度 u_hat_detached = u_hat.detach()b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps).cuda()#路由算法for i in range(self.routings):c = F.softmax(b, dim=1)if i == self.routings - 1:v = squash(torch.sum(c[:, :, :, None] * u_hat, dim=-2, keepdim=True))else:v = squash(torch.sum(c[:, :, :, None] * u_hat_detached, dim=-2, keepdim=True))b = b + torch.sum(v * u_hat_detached, dim=-1)return torch.squeeze(v, dim=-2)

需要将的是u_hat_detached = u_hat.detach()这一步。将u_hat从计算图中分离出来的目的,是为了防止迭代过程中梯度不断累积,导致梯度过大。所以我们可以在后续的路由算法中看出,只有在最后一次计算路由时使用了u_hat,之前的迭代中都是使用的u_hat_detached。从而让整个路由过程中梯度只更新一次。

3,损失函数

def caps_loss(y_true, y_pred, x, x_recon, lambd=0.5):L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2L_margin = L.sum(dim=1).mean()L_recon = nn.MSELoss()(x_recon, x)return L_margin + lambd * L_recon

4,整体模型

模型返回两个值,一个是预测的概率,一个是重建的图像。这两个值会分别用来计算间隔损失和重构损失。

class CapsuleNet(nn.Module):def __init__(self, input_size, classes, routings):super(CapsuleNet, self).__init__()self.input_size = input_sizeself.classes = classesself.routings = routingsself.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,out_num_caps=classes, out_dim_caps=16, routings=routings)self.decoder = nn.Sequential(nn.Linear(16*classes, 512),nn.ReLU(inplace=True),nn.Linear(512, 1024),nn.ReLU(inplace=True),nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),nn.Sigmoid())self.relu = nn.ReLU()def forward(self, x, y=None):x = self.relu(self.conv1(x))x = self.primarycaps(x)x = self.digitcaps(x)length = x.norm(dim=-1)if y is None:index = length.max(dim=1)[1]y = torch.zeros(length.size()).scatter_(1, index.view(-1, 1), 1.)reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))return length, reconstruction.view(-1, *self.input_size)

5,注意事项:

1)one-hot

在重建过程中使用的标签y是one-hot形式的,因此在训练和测试时需要加上这行代码,转换一下

targets = F.one_hot(targets, num_classes=classes).to(device)

2) loss

训练和测试时的loss设置如下

loss = caps_loss(y_true=targets,y_pred=y_pred,x=imgs,x_recon=x_recon,lambd=0.5)loss = loss.to(device)

其中lambd这个系数决定的是重构损失所占的比例 loss=margin_loss+lambd*recon_loss

总结:

胶囊网络分类结果不算差,在我的一些任务中train from scratch的胶囊网络就超越了imagenet1k上预训练过再finetune的vit。也超过了无预训练的VGG和resnet。(但是不如预训练过的vgg和resnet)。

这样的表现放在2017年已经很能打了,没火的原因我感觉有3个:

首先,由于胶囊网络迭代过程需要多次完整的特征图点乘特征图,所以内存消耗和时间消耗都是巨大的。我跑256的图时,24g显存的4090也只能把batch设置成5,运行速度非常慢。放在2017年,只能用1080ti来跑这个模型,简直折磨。(我2018年时也试过这个模型,训练都是按周算的,这谁愿意用啊)

另外一个原因可能是它的改进潜力不大。例如vit的核心机制是自注意力,注意力大家都玩出花来了,各种改进思路都很好借鉴。虽然vit效果很一般,但是后续的改进模型一个比一个厉害。而胶囊网络的核心路由算法想要创新就比较难。

最后还有一点就是原作者没放出胶囊网络在imagenet上的预训练模型。这个对模型热度的影响其实挺大的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/305581.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源监控zabbix对接可视化工具grafana教程

今天要给大家介绍的是开源监控工具zabbix对接可视化工具grafana问题。 有一定运维经验的小伙伴大抵都或多或少使用过、至少也听说过开源监控工具zabbix,更进一步的小伙伴可能知道zabbix在数据呈现方面有着明显的短板,因此需要搭配使用第三方的可视化工具…

背 单 词 (考研词汇闪过)

单词: 买考研词汇闪过 研究艾宾浩斯遗忘曲线 https://www.bilibili.com/video/BV18Y4y1h7YR/?spm_id_from333.337.search-card.all.click&vd_source5cbefe6dd70d6d84830a5891ceab2bf9 单词方法 闪记背两排(5min)重复一遍(2mi…

4-云原生监控体系-Grafana-基本使用

1. 介绍 使用Grafana,您可以通过漂亮、灵活的仪表板创建、探索和共享所有数据。查询、可视化、提醒和理解您的数据,无论数据存储在何处。 图片出处: https://grafana.com/grafana/ 官方网站 2. 界面介绍 Connections 可以配置数据源&#x…

C语言 | Leetcode C语言题解之第21题合并两个有序链表

题目: 题解: /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/typedef struct ListNode ListNode; struct ListNode* mergeTwoLists(struct ListNode* list1, struct ListNode* list2) {/…

Python+Appium自动化测试(ios+Android)

一、软件安装 安装清单: JDKPythonnode.jsandroid-sdk(作者通过Android Studio安装)iOS-deploybrewlibimobiledevice依赖库ideviceinstallercarthage依赖库 appium-doctor(安装后可在命令行中通过命令:appium-doctor检查还少啥) WebDriverAg…

Spring入门——IoC控制反转

前言 本博客是博主用于学习Spring的博客,如果疏忽出现错误,还望各位指正。 Bean Bean的概念 Bean的详解,可以参考这篇文章。 【SpringBoot】Bean 是什么?_sptingboot bean 是什么-CSDN博客 Bean 作为 Spring框架面试中不可或…

微服务学习(黑马)

学习黑马的微服务课程的笔记 导学 微服务架构 认识微服务 SpringCloud spring.io/projects/spring-cloud/ 服务拆分和远程调用 根据订单id查询订单功能 存在的问题 硬编码 eureka注册中心 搭建eureka 服务注册 在order-service中完成服务拉取 Ribbon负载均衡 Nacos注册中心…

政安晨:【Keras机器学习实践要点】(二十五)—— 使用 EANet(外部注意力转换器)进行图像分类

目录 简介 设置 准备数据 配置超参数 使用数据增强 实施补丁提取和编码层 实施外部关注模块 实施 MLP 模块 执行变压器模块 实施 EANet 模式 培训 CIFAR-100 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机…

Flutter - flutter_gen 资源管理

引言: 在开发 Flutter 应用时,我们经常需要使用各种静态资源,如图片、字体和音频等。如何有效地管理和加载这些资源呢?本篇博客将以图片为例带你解密 Flutter 项目中是如何管理资源地。 assets 加载资源 具体文件名引入 在工程…

Leetcode 239. 滑动窗口最大值和Leetcode 347. 前 K 个高频元素

目录标题 Leetcode 239. 滑动窗口最大值题目描述C语言代码和题解解题思路 Leetcode 347. 前 K 个高频元素题目描述C语言题解和思路解题思路 Leetcode 239. 滑动窗口最大值 题目描述 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最…

[Kubernetes[K8S]集群:Slaver从节点初始化和Join]:添加到主节点集群内

文章目录 操作流程:上篇主节初始化地址:前置:Docker和K8S安装版本匹配查看0.1:安装指定docker版本 **[1 — 8] ** [ 这些步骤主从节点前置操作一样的 ]一:主节点操作 查看主机域名->编辑域名->域名配置二&#x…

STM32学习和实践笔记(9): 使用位带操作实现LED闪的实验

控制GPIO的那些寄存器,都在位带区。 根据上一篇讲的原理,要想每次只操作这些寄存的某一个bit而不影响别的bit,可以使用与这些bit相对应的位带别名区。 因此,在使用GPIO的位带操作之前,先要按上篇讲的原理&#xff0c…

通过前缀和来看golang的acm输入输出

前缀和 问题引入 package mainimport ("fmt" )func main() {var n, q, l, r intfmt.Scan(&n, &q)a : make([]int, n)ap : make([]int64, n 1)ap[0] 0for i : 0; i < n; i {fmt.Scan(&a[i])ap[i 1] ap[i] int64(a[i])}for j : 0; j < q; j {f…

远程桌面无法连接怎么办?

远程桌面无法连接是指在尝试使用远程桌面功能时出现连接失败的情况。这种问题可能会给工作和生活带来极大的不便&#xff0c;因此我们需要寻找解决办法。在讨论解决方案之前&#xff0c;我们先来了解一下【天联】组网的优势。 【天联】组网的优势有很多。它能够解决复杂网络环境…

[大模型]Qwen1.5-7B-Chat-GPTQ-Int4 部署环境

Qwen1.5-7B-Chat-GPTQ-Int4 部署环境 说明 Qwen1.5-72b 版本有BF16、INT8、INT4三个版本&#xff0c;三个版本性能接近。由于BF16版本需要144GB的显存&#xff0c;让普通用户忘却止步&#xff0c;而INT4版本只需要48GB即可推理&#xff0c;给普通用户本地化部署创造了机会。&…

MySQL数据库的详解(1)

DDL&#xff08;数据库操作&#xff09; 查询 查询所有数据库&#xff1a;show databases;当前数据库&#xff1a;select database(); 创建 创建数据库&#xff1a;create database [ if not exists] 数据库名 ; 使用 使用数据库&#xff1a;use 数据库名 ; 删除 删除数…

Deblurring 3D Gaussian Splatting去模糊3D高斯溅射

Abstract 摘要 Recent studies in Radiance Fields have paved the robust way for novel view synthesis with their photorealistic rendering quality. Nevertheless, they usually employ neural networks and volumetric rendering, which are costly to train and impede…

如何将普通maven项目转为maven-web项目

文件-项目结构&#xff08;File-->Project Structure &#xff09; 模块-->learn&#xff08;moudle-->learn&#xff09; 选中需要添加web的moudle&#xff0c;点击加号&#xff0c;我得是learn&#xff0c;单击选中后进行下如图操作&#xff1a; 编辑路径 结果如下…

微信小程序自定义关闭按钮在弹窗下面的效果

效果图: 我之前用vant 的popup的弹窗写&#xff0c;会出现close图标移动到弹窗内容外部不可见。 自定义代码&#xff1a; popup.JS/*** 生命周期函数--监听页面初次渲染完成*/onReady() {//自定义弹窗 动态获取屏幕高度var that this;wx.getSystemInfo({success: (result) &…

4.2.k8s的pod-标签管理、镜像拉取策略、容器重启策略、资源限制、优雅终止

一、标签管理 1.标签在k8s中极其重要&#xff0c;大多数资源的相互关联就需要使用标签&#xff1b;也就是说&#xff0c;资源的相互关联大多数时候&#xff0c;是使用标签进行关联的&#xff1b; 2.其他作用&#xff0c;在k8s集群中&#xff0c;node节点的一些操作比如污点及污…