✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。
我是Srlua小谢,在这里我会分享我的知识和经验。🎥
希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮
记得先点赞👍后阅读哦~ 👏👏
📘📚 所属专栏:传知代码论文复现
欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙
目录
概述
Swin Transformer模型原理
1. Swin Transformer模型框架
2. W-MSA详解
3. SW-MSA详解
环境配置
安装必要的Python依赖:
数据准备
部分核心代码
训练过程
测试和评估
混淆矩阵
参考论文:
本文所有资源均可在该地址处获取。
概述
在计算机视觉领域,卷积神经网络(CNN)一直是构建模型的主流选择。自从AlexNet在ImageNet竞赛中取得了突破性的成绩后,CNN的结构不断演进,变得更庞大、更深入、更多样化。与此同时,自然语言处理领域的网络架构发展则呈现不同的轨迹,目前最流行的是Transformer模型。这种模型专为处理序列数据和转换任务而设计,以其能够捕捉数据中的长距离依赖关系而著称。Transformer在语言处理方面的显著成就激发了研究者探索其在计算机视觉领域的应用潜力,近期的研究表明,它在图像分类、目标检测、图像分割等任务上已经取得了令人鼓舞的成果。
实验得到该模型在图像分类、图像检测、目标检测有很好的效果。
上表列出了从 224^2 到 384^2 不同输入图像大小的 Swin Transformer 的性能。通常,输入分辨率越大,top-1 精度越高,但推理速度越慢。
Swin Transformer模型原理
1. Swin Transformer模型框架
首先,我们将图像送入一个称为Patch Partition的模块,该模块负责将图像分割成小块。然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样。
最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。
2. W-MSA详解
引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,对于feature map中的每个像素在Self-Attention计算过程中需要和所有的像素去计算。在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。
3. SW-MSA详解
采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了SW-MSA模块,即进行偏移的W-MSA。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了⌊ M/2 ⌋ 个像素)。比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。
环境配置
复现Swin Transformer需要首先准备pytorch环境。
安装必要的Python依赖:
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
数据准备
下载好数据集,代码中默认使用的是花分类数据集。当然也可以使用自定义的图像数据集,只要更改分类的数目和参数即可。需要确保数据集目录结构正确,以便Swin Transformer能正确读取数。
以下推荐的数据集文件目录:
├── flower_photos
│ ├── daisy
│ ├── sunflowers
│ └── tulips
├── weights
│ ├── model-0.pth
│ ├── model-1.pth
│ └── model-2.pth
├── pre_weights
│ ├── swin_large_patch4_window7_224_22k.pth
│ └── swin_tiny_patch4_window7_224.pth
├── labels
│ ├── train2017
│ └── val2017
├── class_indices.json
├── record.txt
└── requeirments.txt
部分核心代码
def __init__(self, dim, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioassert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
训练过程
此处可调整分类图像任务的种类数目,训练轮数,batch_size,训练图像,预训练模型等参数。
num_classes = 5
epochs = 10
batch_size = 8
lr = 0.0001
data_path = "flower_photos" # 修改为你的数据集路径
weights = './pre_weights/swin_tiny_patch4_window7_224.pth'
freeze_layers = False
通过8个线程进行模型训练,训练10轮因为数据集较大,耗时比较长有2个小时。查看结果发现只进行了几轮图像分类准确率在90%以上,效果较好:
本人用cpu跑的,最好用cuda跑。
输出的结果在weights中。
挑选准确最高,损失最小的模型model-x.pth进行消融实验即可。
测试和评估
采用model-9.pth模型进行蒲公英的图像分类预测,结果如下所示
这里是用花卉的数据集进行模型训练,可以自定义选择图像数据集进行训练。
混淆矩阵
查看图像分类的混淆矩阵,可以看出效果还是不错的:
参考论文:
- Swin-transformer 链接
- 官方代码 链接
希望对你有帮助!加油!
若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!