Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method--论文笔记

论文笔记

资料

1.代码地址

https://github.com/iBelieveCJM/pseudo_label-pytorch

2.论文地址

3.数据集地址

论文摘要的翻译

本文提出了一种简单有效的深度神经网络半监督学习方法。基本上,所提出的网络是以有监督的方式同时使用标记数据和未标记数据来训练的。对于未标记的数据,只要选取具有最大预测概率的类别,就可以使用伪标签,就好像它们是真标签一样。这实际上等同于熵正则化。它支持类之间的低密度分离,这是半监督学习通常假设的先验条件。在MNIST手写数字数据集上,利用去噪自动编码器和丢弃,这种简单的方法在标签数据非常少的情况下优于传统的半监督学习方法。

1背景

所有训练深度神经网络的成功方法都有一个共同点:它们都依赖于无监督学习算法。大多数工作分两个主要阶段进行。在第一阶段,无监督预训练,所有层的权重通过这种分层的无监督训练来初始化。在第二阶段,微调,在有监督的方式下,使用反向传播算法用标签全局地训练权值。所有这些方法也都以半监督的方式工作。我们只需要使用额外的未标记数据来进行无监督的预训练。
我们提出了一种更简单的半监督方式训练神经网络的方法。基本上,所提出的网络是以有监督的方式同时使用标记数据和未标记数据来训练的。对于未标记的数据,只需选取每次权重更新具有最大预测概率的类,就像使用真标签一样使用伪标签。该方法原则上可以结合几乎所有的神经网络模型和训练方法。
这种方法实际上等同于熵正则化(Granvalet等人,2006年)。类概率的条件熵可用于类重叠的度量。通过最小化未标记数据的熵,可以减少类概率分布的重叠性。它支持类之间的低密度分离,这是半监督学习的常见先验假设。

2论文的创新点

3 论文方法的概述

3.1 思路

伪标签是未标记数据的目标类,就好像它们是真标签一样。我们只选取对每个未标记样本具有最大预测概率的类别。
y i ′ = { 1 if  i = argmax i ′ f i ′ ( x ) 0 otherwise y_i^{\prime}=\begin{cases}1&\text{if }i=\text{argmax}_{i'}f_{i'}(x)\\0&\text{otherwise}\end{cases} yi={10if i=argmaxifi(x)otherwise我们在Dropout的微调阶段使用伪标签。用标记和未标记的数据同时以有监督的方式训练预先训练的网络。对于未标记的数据,每次权值更新重新计算的伪标签被用于相同的监督学习任务的损失函数。
由于有标签数据和无标签数据的总数有很大不同,并且它们之间的训练平衡对网络性能非常重要,因此总体损失函数为 L = 1 n ∑ m = 1 n ∑ i = 1 C L ( y i m , f i m ) + α ( t ) 1 n ′ ∑ m = 1 n ′ ∑ i = 1 C L ( y i ′ m , f i ′ m L=\frac{1}{n}\sum_{m=1}^{n}\sum_{i=1}^{C}L(y_{i}^{m},f_{i}^{m})+\alpha(t)\frac{1}{n'}\sum_{m=1}^{n'}\sum_{i=1}^{C}L(y_{i}^{\prime m},f_{i}^{\prime m} L=n1m=1ni=1CL(yim,fim)+α(t)n1m=1ni=1CL(yim,fim
其中n是SGD的已标记数据中的批次数, n ′ n\prime n用于未标记数据, f i m f^m_i fim是已标记数据中 m m m个样本的输出单位, y i m y^m_i yim是标签, f i ′ m f^{\prime m}_{i} fim用于未标记数据, y i ′ m y^{\prime m}_{i} yim是未标记数据的伪标签, α ( t ) \alpha(t) α(t)是平衡它们的系数。
α ( t ) \alpha(t) α(t)的合理调度对网络性能非常重要。如果 α ( t ) \alpha(t) α(t)太高,即使对于已标记的数据,也会干扰训练。而如果 α ( t ) \alpha(t) α(t)太小了,我们就不能利用未标记数据的好处。此外, α ( t ) \alpha(t) α(t)缓慢增加的确定性退火过程有望帮助优化过程避免较差的局部极小值,从而使未标记数据的伪标签尽可能类似于真实标签。 α ( t ) = { 0 t < T 1 t − T 1 T 2 − T 1 α f T 1 ≤ t < T 2 α f T 2 ≤ t \alpha(t)=\begin{cases}0&t<T_1\\\frac{t-T_1}{T_2-T_1}\alpha_f&T_1\leq t<T_2\\\alpha_f&T_2\leq t\end{cases} α(t)= 0T2T1tT1αfαft<T1T1t<T2T2t α f {\alpha}_f αf=3、 T 1 T_1 T1=100、 T 2 T_2 T2=600的情况下,不进行预训练;在DAE的情况下, T 1 T_1 T1=200、 T 2 T_2 T2=800。

3.2 Pseudo-Label为什么有效?

半监督学习的目标是利用未标记的数据来提高泛化性能。聚集学习假设指出,决策边界应位于低密度区域,以提高泛化性能。
最近提出的使用流形学习训练神经网络的方法,如半监督嵌入和流形切线分类器,都利用了这一假设。半监督嵌入使用基于嵌入的正则化来提高深度神经网络的泛化性能。由于数据样本的邻居通过嵌入惩罚项与样本具有相似的激活,因此高密度区域的数据样本更有可能具有相同的标签。流形切线分类器鼓励网络输出对低维流形方向的变化不敏感。因此,同样的目的也达到了。

3.3 Entropy Regularization

在最大后验估计的框架下,熵正则化是一种从未标记数据中获益的方法。该方案通过最小化未标记数据的类概率的条件熵来支持类之间的低密度分离,而不需要对密度进行任何建模。 H ( y ∣ x ′ ) = − 1 n ′ ∑ m = 1 n ′ ∑ i = 1 C P ( y i m = 1 ∣ x ′ m ) log ⁡ P ( y i m = 1 ) H(y|x')=-\frac{1}{n'}\sum_{m=1}^{n'}\sum_{i=1}^{C}P(y_{i}^{m}=1|x'^{m})\operatorname{log}P(y_{i}^{m}=1) H(yx)=n1m=1ni=1CP(yim=1∣xm)logP(yim=1)
其中 n ′ n^\prime n是未标记数据的数目, C C C是类数, y i m y^m_i yim是第 m m m个未标记样本的未知标记, x ′ m x^{\prime m} xm是第m个未标记样本的输入向量,熵是类重叠的一种度量。随着类重叠的减少,决策边界上的数据点密度变得更低。
MAP估计被定义为后验分布的最大值: C ( θ , λ ) = ∑ m = 1 n log ⁡ P ( y m ∣ x m ; θ ) − λ H ( y ∣ x ′ ; θ ) C(\theta,\lambda)=\sum_{m=1}^n\log P(y^m|x^m;\theta)-\lambda H(y|x';\theta) C(θ,λ)=m=1nlogP(ymxm;θ)λH(yx;θ)
其中n是标记数据的数目, x m x^m xm是第 m m m个标记样本, λ λ λ是平衡两项的系数。通过最大化已标记数据(第一项)的条件对数似然和最小化未标记数据(第二项)的熵,可以获得更好的泛化性能。
图1示出了t-SNE 在MNIST测试数据(未包括在未标记数据中)的网络输出的2D嵌入结果。神经网络用600个已标记数据以及60000个未标记数据和伪标签进行训练。虽然在两种情况下训练误差为零,但通过使用未标记数据和伪标签进行训练,测试数据的网络输出更接近于1-OFK码,换言之,(17)的熵被最小化。
在这里插入图片描述
表2显示了(17)的估计熵。虽然两种情况下已标记数据的熵都接近于零,但通过伪标签训练,未标记数据的熵变低,另外,测试数据的熵也随之降低。这使得分类问题变得更容易,甚至对于测试数据也是如此,并且使得决策边界处的数据点密度更低。根据聚类假设,我们可以得到更好的泛化性能。
在这里插入图片描述

4 实验

4.1 手写数字识别(MNIST)

MNIST是深度学习文献中最著名的数据集之一。为了进行比较,我们使用了相同的半监督设置。我们将标记的训练集的大小减少到100、600、1000和3000。训练集在每个标签上具有相同数量的样本。对于验证集,我们拿到了1000个标签验证。我们使用验证集来确定一些超参数。其余数据用于未标记的数据。由于我们不能得到相同的数据集分割,所以使用相同的网络和参数进行了10个随机分割实验。在100个标记数据的情况下,结果在很大程度上依赖于数据分割,因此进行了30个实验。对于100个标记数据,95%的可信区间约为±1∼1.5%,对于600个标记数据,约为±0.1∼0.我们使用具有1个隐含层的神经网络。激活函数使用Relu函数,输出单元采用Sigmoid单元。隐藏单元的数量为5000个。15%,对于1000和3000个标记数据,小于±0.1%。为了进行优化,我们使用了带Dropout1的小批次随机梯度下降算法,初始学习率为1.5,有标签数据的小批次数为32,未标签数据的小批次数为256。这些参数是使用验证集确定的。表2将我们的方法与其他方法的结果进行了比较。我们的方法虽然简单,但在处理小标签数据时表现优于传统的方法。该训练方案比流形切线分类器简单,并且不使用半监督嵌入中使用的样本之间的计算代价高昂的相似性矩阵。在这里插入图片描述

5总结

在这项工作中,我们展示了一种简单有效的神经网络半监督学习方法。该方法在不需要复杂的训练方案和计算代价较高的相似矩阵的情况下,具有最好的性能。

6复现

def train(net, trainloader, criterion, n_labeled, labeled_bs, device='cuda'):'''训练网络n_labeled: 标记数据的数量labeled_bs: 标记数据的批量大小'''learning_rate = 1e-3  # 学习率epochs = 200  # 训练的轮数optimizer = optim.Adam(net.parameters(), lr=learning_rate)  # Adam优化器T1 = 100  # 第一个阈值T2 = 600  # 第二个阈值alpha = 0  # 初始alpha值af = 0.3  # alpha的最大值for epoch in range(epochs):net.train()  # 设置网络为训练模式running_loss = 0.0  # 累积的损失running_acc = 0.0  # 累积的准确度for img, label in trainloader:img = img.to(device)  # 将图片移动到指定设备label = label.to(device)  # 将标签移动到指定设备# 将数据分成标记数据和未标记数据img_labeled = img[:labeled_bs]  # 获取标记数据label_labeled = label[:labeled_bs]  # 获取标记数据的标签img_unlabeled = img[labeled_bs:]  # 获取未标记数据# 标记数据的前向传播output_labeled = net(img_labeled)  # 计算标记数据的输出# 未标记数据的前向传播if epoch > T1:alpha = (epoch - T1) / (T2 - T1) * af  # 计算alpha值if epoch > T2:alpha = af  # 设置alpha为最大值# 为未标记数据生成伪标签output_unlabeled = net(img_unlabeled)  # 计算未标记数据的输出_, pseudo_labels = torch.max(output_unlabeled, 1)  # 获取伪标签# 半监督损失loss = criterion(output_labeled, label_labeled) + alpha * criterion(output_unlabeled, pseudo_labels)  # 计算总损失# 反向传播并更新网络optimizer.zero_grad()  # 清空梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数# 计算训练损失和训练准确度running_loss += loss.item()  # 累积损失_, predicted = torch.max(output_labeled, 1)  # 获取预测结果correct_num = (predicted == label_labeled).sum().item()  # 计算正确预测的数量running_acc += correct_num  # 累积准确度running_loss /= n_labeled  # 计算平均损失running_acc /= n_labeled  # 计算平均准确度print('Train[%d / %d] loss: %.5f, Acc: %.2f%%' % (epoch + 1, epochs, running_loss, 100 * running_acc))  # 打印当前轮次的损失和准确度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/370251.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ASCII码对照表(Matplotlib颜色对照表)

文章目录 1、简介1.1 颜色代码 2、Matplotlib库简介2.1 简介2.2 安装2.3 后端2.4 入门例子 3、Matplotlib库颜色3.1 概述3.2 颜色图的分类3.3 颜色格式表示3.4 内置颜色映射3.5 xkcd 颜色映射3.6 颜色命名表 4、Colorcet库5、颜色对照表结语 1、简介 1.1 颜色代码 颜色代码是…

2024亚太杯数学建模竞赛(B题)的全面解析

你是否在寻找数学建模比赛的突破点&#xff1f;数学建模进阶思路&#xff01; 作为经验丰富的数学建模团队&#xff0c;我们将为你带来2024亚太杯数学建模竞赛&#xff08;B题&#xff09;的全面解析。这个解决方案包不仅包括完整的代码实现&#xff0c;还有详尽的建模过程和解…

数据库-MySQL 实战项目——书店图书进销存管理系统数据库设计与实现(附源码)

一、前言 该项目非常适合MySQL入门学习的小伙伴&#xff0c;博主提供了源码、数据和一些查询语句&#xff0c;供大家学习和参考&#xff0c;代码和表设计有什么不恰当还请各位大佬多多指点。 所需环境 MySQL可视化工具&#xff1a;navicat&#xff1b; 数据库&#xff1a;MySq…

MySQL:如何在已经使用的数据表中增加一个自动递增的字段

目录 一、需求 二、实现步骤 &#xff08;一&#xff09;数据表students &#xff08;二&#xff09;添加整型字段 &#xff08;三&#xff09;更新SID字段的值 1、使用用户定义的变量和JOIN操作 2、用SET语句和rownum变量 &#xff08;1&#xff09;操作方法 &#x…

使用antd的<Form/>组件获取富文本编辑器输入的数据

前端开发中&#xff0c;嵌入富文本编辑器时&#xff0c;可以通过富文本编辑器自身的事件处理函数将数据传输给后端。有时候&#xff0c;场景稍微复杂点&#xff0c;比如一个输入页面除了要保存富文本编辑器的内容到后端&#xff0c;可能还有一些其他输入组件获取到的数据也一并…

MySQL篇三:数据类型

文章目录 前言1. 数值类型1.1 tinyint类型1.2 bit类型1.3 小数类型1.3.1 float1.3.2 decimal 2. 字符串类型2.1 char2.2 varchar2.3 char和varchar比较 3. 日期类型4. enum和set 前言 数据类型分类&#xff1a; 1. 数值类型 1.1 tinyint类型 在MySQL中&#xff0c;整型可以指…

如何第一次从零上传项目到GitLab

嗨&#xff0c;我是兰若&#xff0c;今天想给大家说下&#xff0c;如何上传一个完整的项目到与LDAP集成的GitLab&#xff0c;也就是说这个项目之前是不在git上面的&#xff0c;这是第一次上传&#xff0c;这样上传上去之后&#xff0c;其他小伙伴就可以根据你这个项目的git地址…

自动批量将阿里云盘文件发布成WordPress文章脚本源码(以RiPro主题为例含付费信息下载地址SEO等自动设置)源码

背景 很多资源下载站&#xff0c;付费资源下载站&#xff0c;付费内容查看等都可以用WordPress站点发布内容&#xff0c;这些站点一般会基于一个主题&#xff0c;付费信息作为文章附属的信息发布&#xff0c;底层存储在WP表里&#xff0c;比如日主题&#xff0c;子比主题等。 …

2-5 softmax 回归的简洁实现

我们发现通过深度学习框架的高级API能够使实现线性回归变得更加容易。 同样&#xff0c;通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节如在上节中一样&#xff0c; 继续使用Fashion-MNIST数据集&#xff0c;并保持批量大小为256。 import torch from torc…

【基础篇】第4章 Elasticsearch 查询与过滤

在Elasticsearch的世界里&#xff0c;高效地从海量数据中检索出所需信息是其核心价值所在。本章将深入解析查询与过滤的机制&#xff0c;从基础查询到复合查询&#xff0c;再到全文搜索与分析器的定制&#xff0c;为你揭开数据检索的神秘面纱。 4.1 基本查询 4.1.1 Match查询…

多语言版在线出租车预订完整源码+用户应用程序+管理员 Laravel 面板+ 司机应用程序最新版源码

源码带PHP后台客户端源码 Flutter 是 Google 开发的一款开源移动应用开发 SDK。它用于开发 Android 和 iOS 应用&#xff0c;也是为 Google Fuchsia 创建应用的主要方法。Flutter 小部件整合了所有关键的平台差异&#xff0c;例如滚动、导航、图标和字体&#xff0c;可在 iOS 和…

如何在前端网页实现live2d的动态效果

React如何在前端网页实现live2d的动态效果 业务需求&#xff1a; 因为公司需要做机器人相关的业务&#xff0c;主要是聊天形式的内容&#xff0c;所以需要一个虚拟的卡通形象。而且为了更直观的展示用户和机器人对话的状态&#xff0c;该live2d动画的嘴型需要根据播放的内容来…

昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练 文章目录 昇思25天学习打卡营第08天 | 模型训练超参数损失函数优化器优化过程 训练与评估总结打卡 模型训练一般遵循四个步骤&#xff1a; 构建数据集定义神经网络模型定义超参数、损失函数和优化器输入数据集进行训练和评估 构建数据集和…

【Excel、RStudio计算T检测的具体操作步骤】

目录 一、基础知识1.1 显著性检验1.2 等方差T检验、异方差T检验1.3 单尾p、双尾p1.3.1 检验目的不同1.3.2 用法不同1.3.3 如何选择 二、Excel2.1 统计分析工具2.1.1 添加统计分析工具2.1.2 数据分析 2.2 公式 -> 插入函数 -> T.TEST 三、RStudio 一、基础知识 参考: 1.…

无人机运营合格证及无人机驾驶员合格证(AOPA)技术详解

无人机运营合格证及无人机驾驶员合格证&#xff08;AOPA&#xff09;技术详解如下&#xff1a; 一、无人机运营合格证 无人机运营合格证是无人机运营企业或个人必须获得的证书&#xff0c;以确保无人机在运营过程中符合相关法规和标准。对于无人机运营合格证的具体要求和申请…

【React】React18 Hooks之useState

目录 useState案例1&#xff08;直接修改状态&#xff09;案例2&#xff08;函数式更新&#xff09;案例3&#xff08;受控表单绑定&#xff09;注意事项1&#xff1a;set函数不会改变正在运行的代码的状态注意事项2&#xff1a;set函数自动批量处理注意事项3&#xff1a;在下次…

【反悔堆 优先队列 临项交换 决策包容性】630. 课程表 III

本文涉及知识点 贪心 反悔堆 优先队列 临项交换 Leetcode630. 课程表 III 这里有 n 门不同的在线课程&#xff0c;按从 1 到 n 编号。给你一个数组 courses &#xff0c;其中 courses[i] [durationi, lastDayi] 表示第 i 门课将会 持续 上 durationi 天课&#xff0c;并且必…

E4.【C语言】练习:while和getchar的理解

#include <stdio.h> int main() {int ch 0;while ((ch getchar()) ! EOF){if (ch < 0 || ch>9)continue;putchar(ch);}return 0; } 理解上述代码 0-->48 9-->57 if行判断是否为数字&#xff0c;打印数字&#xff0c;不打印非数字

Selenium 切换 frame/iframe

环境&#xff1a; Python 3.8 selenium3.141.0 urllib31.26.19说明&#xff1a; driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …

k8s公网集群安装(1.23.0)

网上搜到的公网搭建k8s都不太一致, 要么说的太复杂, 要么镜像无法下载, 所以写了一个简洁版,小白也能一次搭建成功 使用的都是centos7,k8s版本为1.23.0 使用二台机器搭建的, 三台也是一样的思路1.所有节点分别设置对应主机名 hostnamectl set-hostname master hostnamectl set…