深度对抗神经网络(DANN)笔记


一 总体介绍

DANN是一种迁移学习方法,是对抗迁移学习方法的代表方法。基本结构由特征提取层f,分类器部分c和对抗部分d组成,其中f和c其实就是一个标准的分类模型,通过GAN(生成对抗网络)得到迁移对抗模型的灵感。但此时生成的不是假样本,而是假特征,一个足以让目标域和源域区分不开的假特征。

而领域判别器D其实是个标准的二分类分类器,0是源域,1是目标域。它本身的目标是区分源域和目标域,而我们想要的结果是使判别器越来越分不出数据特征来自源域还是目标域,感觉起来这很矛盾。但其实我们引入一个梯度反转层就可以完美避免这个问题。

引入梯度反转层(GRL),分类器c和判别器d朝着优化分类器效果的方向反向传播优化梯度。有了梯度反转层,简单的说就是判别器d反向传播时,梯度更新前引入了一个“ - ”。这样就可以同时满足判别器和我们需求的一致性。

下面是DANN的基本网络图。


二 UDTL代码库中的DANN网络

对抗网络部分代码。

from torch import nn
import numpy as npdef calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)
### 如果self.trade_off_adversarial == 'Step',则调用此函数得到coeff的值,不然self.trade_off_adversarial == 'Cons',则coeff是个固定的值
###coeff——————coeff = self.lam_adversarial 其中trade_off_adverial充当域分类器部分的学习率随着迭代过程会逐渐递减——————学习率def grl_hook(coeff):#补充连接,因为是引入块,需要连接到model层的梯度:  grad.clone()def fun1(grad):return -coeff * grad.clone()return fun1class AdversarialNet(nn.Module):def __init__(self, in_feature, hidden_size,max_iter=10000.0, trade_off_adversarial='Step', lam_adversarial=1.0):super(AdversarialNet, self).__init__()self.ad_layer1 = nn.Sequential(nn.Linear(in_feature, hidden_size),nn.ReLU(inplace=True),nn.Dropout(),)self.ad_layer2 = nn.Sequential(nn.Linear(hidden_size, hidden_size),nn.ReLU(inplace=True),nn.Dropout(),)self.ad_layer3 = nn.Linear(hidden_size, 1)self.sigmoid = nn.Sigmoid()# parametersself.iter_num = 0self.alpha = 10self.low = 0.0self.high = 1.0self.max_iter = max_iterself.trade_off_adversarial = trade_off_adversarialself.lam_adversarial = lam_adversarialself.__in_features = 1def forward(self, x):if self.training:self.iter_num += 1if self.trade_off_adversarial == 'Cons':coeff = self.lam_adversarialelif self.trade_off_adversarial == 'Step':coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter) #学习率else:raise Exception("loss not implement")x = x * 1.0x.register_hook(grl_hook(coeff))#register_hook的作用:即对x求导时,对x的导数进行操作,并且register_hook的参数只能以函数的形式传过去,#grl_hook(coeff)则返回的是梯度  *  “——”梯度反转层作用##register_hook的作用:对x求导,并将梯度保存下来,这样可以作为参数通过优化器通过反向传播过程进行更新优化,实现DANN所需效果x = self.ad_layer1(x)x = self.ad_layer2(x)y = self.ad_layer3(x)y = self.sigmoid(y)return ydef output_num(self):return self.__in_features#输出通道是1 代表域判别值0:源域 1:目标域

以上是赵志斌老师UDTL代码中的对抗网络部分的介绍,用于故障诊断数据。

这里要声明的是对于AdversarialNet网络而言,一维的数据和二维数据都可以拿来直接使用,实质上它仅仅是多出了一个二分类判别器和一个梯度反转层而已。
ZhaoZhibin/UDTL: Source codes for the paper "Applications of Unsupervised Deep Transfer Learning to Intelligent Fault Diagnosis: A Survey and Comparative Study" published in TIM (github.com)https://github.com/ZhaoZhibin/UDTL


 三 网络的其他写法

这里博主还找到了另外一种的对抗网络写法。

import torch.nn as nn
from functions import ReverseLayerF#从functions中导入梯度反转层这一类class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.feature = nn.Sequential()self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))#这里是因为数据是mnist数据所以输入通道为3self.feature.add_module('f_bn1', nn.BatchNorm2d(64))self.feature.add_module('f_pool1', nn.MaxPool2d(2))self.feature.add_module('f_relu1', nn.ReLU(True))self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))self.feature.add_module('f_bn2', nn.BatchNorm2d(50))self.feature.add_module('f_drop1', nn.Dropout2d())self.feature.add_module('f_pool2', nn.MaxPool2d(2))self.feature.add_module('f_relu2', nn.ReLU(True))
#上面是backbone部分也是网络的特征提取部分self.class_classifier = nn.Sequential()self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))self.class_classifier.add_module('c_relu1', nn.ReLU(True))self.class_classifier.add_module('c_drop1', nn.Dropout())self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))self.class_classifier.add_module('c_relu2', nn.ReLU(True))self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))
#上面是源域的分类器部分,只要是要对源域数据进行有效的分类self.domain_classifier = nn.Sequential()self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))self.domain_classifier.add_module('d_relu1', nn.ReLU(True))self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))
#上面是领域判别器部分,主要任务是要区分出源域和目标域def forward(self, input_data, alpha):input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)feature = self.feature(input_data)feature = feature.view(-1, 50 * 4 * 4)reverse_feature = ReverseLayerF.apply(feature, alpha)
#前向网络中注意到,reverse_feature是通过ReverseLayerF.apply将feature进行反向的梯度计算。class_output = self.class_classifier(feature)domain_output = self.domain_classifier(reverse_feature)
#并将处理过的reverse_feature特征给domain_classifer进行域判别。return class_output, domain_output

对应的ReverseLayerF部分代码: 

from torch.autograd import Functionclass ReverseLayerF(Function):@staticmethoddef forward(ctx, x, alpha):ctx.alpha = alphareturn x.view_as(x)@staticmethoddef backward(ctx, grad_output):output = grad_output.neg() * ctx.alpha
####grad_output.neg()梯度取负操作,反向内容的核心。return output, None

代码地址:https://github.com/fungtion/DANN_py3

通过介绍以上两种不同写法的对抗网络模型,相信你也可以看到对抗网络的核心其实很简单。

仅仅是多出了一个领域判别器和一个梯度反转层。

但采用对抗网络作为迁移网络方法又能很好的解决很多域迁移领域的问题,特别是在域之间的差异较大的情况时,往往要比以MMD(最大均值差异)为代表的度量学习方法效果要好。

以上是我学习过程中对DANN进行的一些总结工作,欢迎评论区讨论交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/60840.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习:BP神经网络,CNN卷积神经网络,GAN生成对抗网络

1,基础知识 1.1,概述 机器学习:概念_燕双嘤-CSDN博客1,机器学习概述1.1,机器学习概念机器学习即Machine Learning,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。目的是让计算机模拟或实…

生成对抗网络(GAN)简单梳理

作者:xg123321123 - 时光杂货店 出处:http://blog.csdn.net/xg123321123/article/details/78034859 声明:版权所有,转载请联系作者并注明出处 网上已经贴满了关于GAN的博客,写这篇帖子只是梳理下思路,以便以…

生成对抗网络(GAN)简介以及Python实现

本篇博客简单介绍了生成对抗网络(Generative Adversarial Networks,GAN),并基于Keras实现深度卷积生成对抗网络(DCGAN)。 以往的生成模型都是预先假设生成样本服从某一分布族,然后用深度网络学习分布族的参数,最后从学习到的分布中采样生成新…

生成对抗网络(GAN)教程 - 多图详解

一.生成对抗网络简介 1.生成对抗网络模型主要包括两部分:生成模型和判别模型。 生成模型是指我们可以根据任务、通过模型训练由输入的数据生成文字、图像、视频等数据。 [1]比如RNN部分讲的用于生成奥巴马演讲稿的RNN模型,通过输入开头词就能生成下来。…

对抗神经网络学习和实现(GAN)

一,GAN的原理介绍 \quad GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是…

生成对抗网络(GAN)

1 GAN基本概念 1.1 如何通俗理解GAN? ​ 生成对抗网络(GAN, Generative adversarial network)自从2014年被Ian Goodfellow提出以来,掀起来了一股研究热潮。GAN由生成器和判别器组成,生成器负责生成样本,判别器负责判断生成器生成…

基于图神经网络的对抗攻击 Nettack: Adversarial Attacks on Neural Networks for Graph Data

研究意义 随着GNN的应用越来越广,在安全非常重要的应用中应用GNN,存在漏洞可能是非常严重的。 比如说金融系统和风险管理,在信用评分系统中,欺诈者可以伪造与几个高信用客户的联系,以逃避欺诈检测模型;或者…

生成对抗网络(Generative Adversial Network,GAN)原理简介

生成对抗网络(GAN)是深度学习中一类比较大的家族,主要功能是实现图像、音乐或文本等生成(或者说是创作),生成对抗网络的主要思想是:通过生成器(generator)与判别器(discriminator)不断对抗进行训练。最终使得判别器难以分辨生成器生成的数据(…

快讯|莫言用 ChatGPT 写《颁奖辞》;特斯拉人形机器人集体出街!已与FSD算法打通

一分钟速览新闻点 言用 ChatGPT 写《颁奖辞》孙其君研究员团队 Adv. Funct. Mater.:多功能离子凝胶纤维膜用于能量离电皮肤微软CEO反驳马斯克:我们没有控制OpenAI特斯拉人形机器人集体出街!已与FSD算法打通微软CEO称小型公司仍可在人工智能领…

从供应链到价值链:人形机器人产业链深入研究

原创 | 文 BFT机器人 01 人形机器人产业进展:AI赋能,人形机器人迭代有望加速 目前人形机器人产业所处从“0”到“1”的萌芽期,从现在到未来的时间里,人形机器人以其仿人外形、身体构成及其智能大脑,能极大解放生产力、…

人形机器人火出圈!OpenAI领投挪威人形机器人公司“1X”

文|牛逼的AI 编|猫猫咪子 源|AI源起 目前已经实现了对接GPT4技术的机器人Ameca,它拥有逼真的外观和丰富的表情。 随着人形机器人技术的飞速发展,未来可能不再适用阿西莫夫所提出的“机器人三定律”,因为超级…

chatgpt赋能python:Python登录界面制作指南

Python登录界面制作指南 介绍 登录界面是许多应用程序的关键组成部分之一。Python作为一种优秀的编程语言,拥有着强大的界面开发框架,能够帮助开发人员更轻松地创作出完美的登录界面。 在本文中,我们将向您介绍使用Python如何制作一个简单…

The Journal of Neuroscience: 珠心算训练有助于提高儿童的视觉空间工作记忆

《本文同步发布于“脑之说”微信公众号,欢迎搜索关注~~》 珠心算是指个体在熟练进行珠算操作后,可摆脱实际算盘,借助大脑中虚拟算盘进行数字计算的方式(图1)。早期行为学研究表明,珠心算个体的数字计算能力…

php珠心算源码,深度解析珠心算的“开智”功能

编者按:本文来自李绵军校长在廊坊智慧特训营演讲。李绵军校长通过十几年来对珠心算的钻研练习,详细解读了珠心算的开智功能,以及“一门深入”的作用。 珠心算的开智价值是在哪里?大家都说开发智力,我在这讲开发智力不是…

php珠心算源码,NOIP201401珠心算测验

珠心算测验 问题描述】 珠心算是一种通过在脑中模拟算盘变化来完成快速运算的一种计算技术。珠心算训练,既能够开发智力,又能够为日常生活带来很多便利,因而在很多学校得到普及。 某学校的珠心算老师采用一种快速考察珠心算加法能力的测验方…

雨课堂提交作业步骤 10步帮你弄好

1 2 3 4 5 6 7 中间计数记错了… 8 9 10 弹出对话框,点击确认即可 提交成功的截图:

2022李宏毅作业hw4 - 挫败感十足的一次作业。

系列文章: 2022李宏毅作业hw1—新冠阳性人员数量预测。_亮子李的博客-CSDN博客_李宏毅hw1 hw-2 李宏毅2022年作业2 phoneme识别 单strong-hmm详细解释。_亮子李的博客-CSDN博客_李宏毅hw2 2021李宏毅作业hw3 --食物分类。对比出来的80准确率。_亮子李的博客-CSDN博客…

php老师的一个作业展示

1.在ScanCode.php中 在judgeTrayCodeEnableIntoWarehouse方法中: 2在CommitTrayCodes.php中 访问时,直接访问CommitTrayCodes.php,这个CommitTrayCodes.php是要建在controller下,建议导一下,老师的数据库,以防止命名…

HCIA网络课程第七周作业

(1)请用自己的语言描述基本ACL和高级ACL的区别 (2)AAA支持的认证、授权和计费方式分别有哪几种? AAA支持的认证方式有不认证 本地认证 远端认证AAA支持的授权方式为不授权 本地授权 远端授权AAA支持计费方式为不计费…

如何获取抖音和快手直播间的直播流地址

如下是通过python代码脚本获取的方法: import requests import re def get_real_url(rid): try: if ‘v.douyin.com‘ in rid: room_id re.findall(r‘(\d{19})‘, requests.get(urlrid).url)[0] else: room_id rid room_url ‘https://webcast-hl.amemv.com/…