Interesting bug caused by getattr

题意:由 getattr 引起的有趣的 bug

问题背景:

I try to train 8 CNN models with the same structures simultaneously. After training a model on a batch, I need to synchronize the weights of the feature extraction layers in other 7 models.

我尝试同时训练8个具有相同结构的卷积神经网络(CNN)模型。在对一个批次的数据训练一个模型后,我需要同步其他7个模型中特征提取层的权重。

This is the model:        这是模型

class GNet(nn.Module):def __init__(self, dim_output, dropout=0.5):super(GNet, self).__init__()self.out_dim = dim_output# Load the pretrained AlexNet modelalexnet = models.alexnet(pretrained=True)self.pre_filtering = nn.Sequential(alexnet.features[:4])# Set requires_grad to False for all parameters in the pre_filtering networkfor param in self.pre_filtering.parameters():param.requires_grad = False# construct the feature extractor# every intermediate feature will be fed to the feature extractor# res: 25 x 25self.feat_ex1 = nn.Conv2d(192, 128, kernel_size=3, stride=1)# res: 25 x 25self.feat_ex2 = nn.Sequential(nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))# res: 25 x 25self.feat_ex3 = nn.Sequential(nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))# res: 13 x 13self.feat_ex4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))# res: 13 x 13self.feat_ex5 = nn.Sequential(nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))# res: 13 x 13self.feat_ex6 = nn.Sequential(nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))# res: 13 x 13self.feat_ex7 = nn.Sequential(nn.BatchNorm2d(128),nn.Dropout(p=dropout),nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1))# define the flexible pooling field of each layer# use a full convolution layer here to perform flexible poolingself.fpf13 = nn.Conv2d(in_channels=448, out_channels=448, kernel_size=13, groups=448)self.fpf25 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=25, groups=384)self.linears = {}for i in range(self.out_dim):self.linears[f'linear_{i+1}'] = nn.Linear(832, 1)self.LogTanh = LogTanh()self.flatten = nn.Flatten()

And this is the function to synchronize the weights:

这是同步权重的函数:

def sync_weights(models, current_sub, sync_seqs):for sub in range(1, 9):if sub != current_sub:# Synchronize the specified layerswith torch.no_grad():for seq_name in sync_seqs:reference_layer = getattr(models[current_sub], seq_name)[2]layer = getattr(models[sub], seq_name)[2]layer.weight.data = reference_layer.weight.data.clone()if layer.bias is not None:layer.bias.data = reference_layer.bias.data.clone()

then an error is raised: 然后出现了一个错误:

'Conv2d' object is not iterable

which means the getattr() returns a Conv2D object. But if I remove [2]:

意思是 getattr() 函数返回了一个 Conv2D 对象。但是,如果我移除了 [2]

def sync_weights(models, current_sub, sync_seqs):for sub in range(1, 9):if sub != current_sub:# Synchronize the specified layerswith torch.no_grad():for seq_name in sync_seqs:reference_layer = getattr(models[current_sub], seq_name)layer = getattr(models[sub], seq_name)layer.weight.data = reference_layer.weight.data.clone()if layer.bias is not None:layer.bias.data = reference_layer.bias.data.clone()

I get another error:        我得到了另一个错误

'Sequential' object has no attribute 'weight'

which means the getattr() returns a Sequential. But previously it returns a Conv2D object. Does anyone know anything about this? For your information, the sync_seqs parameter passed in sync_weights is:

意思是 getattr() 现在返回的是一个 Sequential 模型,但之前它返回的是一个 Conv2D 对象。有人知道这是怎么回事吗?为了提供更多信息,sync_weights 函数中传入的 sync_seqs 参数是:

sync_seqs = ['feat_ex1','feat_ex2','feat_ex3','feat_ex4','feat_ex5','feat_ex6','feat_ex7'
]

问题解决:

In both instances, getattr is returning a Sequential, which in turn contains a bunch of objects. In the second case, you're directly assigning that Sequential to a variable, so reference_layer ends up containing a Sequential.

在这两种情况下,getattr 都返回了一个 Sequential 对象,而这个 Sequential 对象又包含了一系列的其他对象。在第二种情况下,你直接将这个 Sequential 对象赋值给了一个变量,因此 reference_layer 最终包含了一个 Sequential 对象。

In the first case, however, you're not doing that direct assignemnt. You're taking the Sequential object and then indexing it with [2]. That means reference_layer contains the third item in the Sequential, which is a Conv2d object.

在第一种情况下,你没有进行直接的赋值。你是先获取了 Sequential 对象,然后使用 [2] 对其进行索引。这意味着 reference_layer 包含的是 Sequential 中的第三个项目,这个项目是一个 Conv2D 对象。

Take a more simple example. Suppose I had a ListContainer class that did nothing except hold a list. I could then recreate your example as follows, with test1 corresponding to your first test case and vice versa:

以一个更简单的例子来说明。假设我有一个 ListContainer 类,它唯一的作用就是持有一个列表。然后我可以按照以下方式重现你的例子,其中 test1 对应你的第一个测试用例,反之亦然:

class ListContainer:def __init__(self, list_items):self.list_items = list_itemsletters = ["a", "b", "c"]
container = ListContainer(letters)test1 = getattr(container, "list_items")[0]
test2 = getattr(container, "list_items")print(type(test1)) # <class 'str'>
print(type(test2)) # <class 'list'>

In both tests, getattr itself is returning a list - but in the second, we're doing something with that list after we get it, so test2 ends up being a string instead.

在两次测试中,getattr 本身都返回了一个列表——但在第二次测试中,我们在获取到这个列表之后对它进行了某种操作,所以 test2 最终变成了一个字符串而不是列表。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/384326.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WARNING: Ignoring invalid distribution -ip警告信息如何去掉?

查看已安装依赖列表的时候&#xff0c;出现了很多警告信息&#xff0c;如何去掉呢&#xff1f; 解决办法 打开这个路径&#xff1a;d:\software\python\python39\lib\site-packages 这种波浪线开头的&#xff0c;我们将它删除掉,就可以了。

Ubuntu设置网络

进入网络配置文件夹 cd /etc/netplan 使用 vim 打开下的配置文件 打开后的配置 配置说明&#xff1a; network:# 网络配置部分ethernets:# 配置名为ens33的以太网接口ens33:addresses:# 为ens33接口分配IP地址192.168.220.30&#xff0c;子网掩码为24位- 192.168.220.30/24n…

VS2019报错:找不到导入的项目,请确认import声明

解决办法 找到项目的.vcxproj文件 用记事本打开后使用ctrlF搜索import 发现import Project后面的.props文件路径不对&#xff0c;将路径改为相对路径 保存后重新加载项目&#xff0c;即可生成成功

AI发展下的伦理挑战:构建未来科技的道德框架

一、引言 随着人工智能&#xff08;AI&#xff09;技术的飞速发展&#xff0c;我们正处在一个前所未有的科技变革时代。AI不仅在医疗、教育、金融、交通等领域展现出巨大的应用潜力&#xff0c;也在日常生活中扮演着越来越重要的角色。然而&#xff0c;这一技术的迅猛进步也带来…

git实践汇总【配置+日常使用+问题解决】

**最初配置步骤&#xff1a;** git config --global user.name "yournemae" git config --global user.email "yourmail" git config -l ssh-keygen -t rsa -C “xxx.xxxx.EXTcccc.com” git config --global ssh.variant ssh $ git clone git仓库路径 git…

云盘高速检测的秘密:密封圈外观检测全解析!

密封圈是一种用于填塞、隔离或密封两个相互连接部件之间空隙的圆形密封装置。密封圈通常由橡胶、塑料、金属等材料制成&#xff0c;具有弹性并能在压力作用下填充间隙&#xff0c;防止液体、气体或固体物质泄漏。 密封圈可根据具体应用选择不同材料&#xff0c;如橡胶密封圈适…

「Unity3D」场景中的距离单位Unit与相关设置PixelsToUnits、PixelsPerUnit

GameObject在场景的位置Position&#xff0c;并没有明确是什么具体单位——如&#xff1a;Transform的x、y、z&#xff0c;或RectTransform的PosX、PosY、PosZ。而RectTransform在面板上显示的Width和Height&#xff0c;也没有具体单位&#xff0c;其实并不是像素。 事实上&am…

谷粒商城实战笔记-59-商品服务-API-品牌管理-使用逆向工程的前后端代码

文章目录 一&#xff0c; 使用逆向工程生成的代码二&#xff0c;生成品牌管理菜单三&#xff0c;几个小问题 在本次的技术实践中&#xff0c;我们利用逆向工程的方法成功地为后台管理系统增加了品牌管理功能。这种开发方式不仅能快速地构建起功能模块&#xff0c;还能在一定程度…

qemu上运行android-x86 (基于ubuntu)

安装qemu x86_64 下载qemu源代码 进入根目录&#xff0c;执行./configure --target-listx86_64-softmmu make -j4 sudo make install 期间可能遇到python相关问题&#xff0c;比如版本不对、库找不到&#xff0c;版本不对可自行安装或是修改环境变量&#xff0c;库找不到可以检…

如何检查代理IP地址是否被占用

使用代理IP时&#xff0c;有时候会发现IP仍然不可用&#xff0c;可能是因为已经被其他用户或者网络占用了。为了检测代理IP是否被占用&#xff0c;我们可以采用一些方法进行验证测试&#xff0c;以保证代理IP的有效性和稳定性。 1.ARP缓存方法 ARP缓存法是一种简单有效的检测代…

unity 实现图片的放大与缩小(根据鼠标位置拉伸放缩)

1创建UnityHelper.cs using UnityEngine.Events; using UnityEngine.EventSystems;public class UnityHelper {/// <summary>/// 简化向EventTrigger组件添加事件的操作。/// </summary>/// <param name"_eventTrigger">要添加事件监听的UI元素上…

将Android Library项目发布到JitPack仓库

将项目代码导入Github 1.将本地项目目录初始化为 Git 仓库。 默认情况下&#xff0c;初始分支称为 main; 如果使用 Git 2.28.0 或更高版本&#xff0c;则可以使用 -b 设置默认分支的名称。 git init -b main 如果使用 Git 2.27.1 或更低版本&#xff0c;则可以使用 git symbo…

ffmpeg更改视频的帧率

note 视频帧率调整 帧率(fps-frame per second) 例如&#xff1a;原来帧率为30&#xff0c;调整后为1 现象&#xff1a;原来是每秒有30张图像&#xff0c;调整后每秒1张图像&#xff0c;看着图像很慢 实现&#xff1a;在每秒的时间区间里&#xff0c;取一张图像…

Vue--解决error:0308010C:digital envelope routines::unsupported

原文网址&#xff1a;Vue--解决error:0308010C:digital envelope routines::unsupported_IT利刃出鞘的博客-CSDN博客 简介 本文介绍如何解决node.js在运行Vue项目时的报错&#xff1a;error:0308010C:digital envelope routines::unsupported。 问题描述 使用node.js运行Vu…

PyTorch深度学习实战——使用深度Q学习进行Pong游戏

PyTorch深度学习实战——使用深度Q学习进行Pong游戏 0. 前言1. 结合固定目标网络的深度 Q 学习模型1.1 模型输入1.2 模型策略2. 实现深度 Q 学习进行 Pong 游戏相关链接0. 前言 我们已经学习了如何利用深度 Q 学习来进行 Gym 中的 CartPole 游戏。在本节中,我们将研究更复杂的…

序列化与反序列化的本质

1. 将对象存储到本地 假如有一个student类&#xff0c;我们定义了好几个对象&#xff0c;想要把这些对象存储下来&#xff0c;该怎么办呢 from typing import List class Student:name: strage: intphones: List[str] s1 Student("xiaoming",10,["huawei&quo…

什么是护网?2024护网行动怎么参加?一文详解_护网具体是做啥的

前言 最近的全国护网可谓是正在火热的进行中&#xff0c;有很多网安小白以及准大一网安的同学在后台问我&#xff0c;到底什么是护网啊&#xff1f;怎么参加呢&#xff1f;有没有相关的学习资料呢&#xff1f;在下不才&#xff0c;连夜整理出来了这篇护网详解文章&#xff0c;希…

逆向案例二十九——某品威客登录,请求头参数加密,简单webpack

网址&#xff1a;登录- 一品威客网,创新型知识技能共享服务平台 抓到登陆包分析&#xff0c;发现请求头有参数加密&#xff0c;直接搜索 定位到加密位置&#xff0c;打上断点&#xff0c;很明显是对象f的a方法进行了加密。 往上找f&#xff0c;可以发现f被定义了&#xff0c;是…

【netty系列-05】深入理解直接内存与零拷贝

Netty系列整体栏目 内容链接地址【一】深入理解网络通信基本原理和tcp/ip协议https://zhenghuisheng.blog.csdn.net/article/details/136359640【二】深入理解Socket本质和BIOhttps://zhenghuisheng.blog.csdn.net/article/details/136549478【三】深入理解NIO的基本原理和底层…

Spring MVC 应用分层

1. 类名使⽤⼤驼峰⻛格&#xff0c;但以下情形例外&#xff1a;DO/BO/DTO/VO/AO 2. ⽅法名、参数名、成员变量、局部变量统⼀使⽤⼩驼峰⻛格 3. 包名统⼀使⽤⼩写&#xff0c;点分隔符之间有且仅有⼀个⾃然语义的英语单词. 常⻅命名命名⻛格介绍 ⼤驼峰: 所有单词⾸字⺟…