pytorch实现半监督学习

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:

1. 数据准备

  • 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
  • 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
  • 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。

2. 选择半监督学习方法

常见的半监督学习方法包括:

  • 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
  • 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
  • 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
  • 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。

3. 训练初始模型

  • 仅使用有标签数据训练一个初始模型。
  • 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
  • 训练过程中可以使用数据增强、正则化等优化策略。

4. 利用无标签数据增强训练

  • 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
  • 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
  • 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。

5. 模型迭代更新

  • 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
  • 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。

6. 评估和测试

  • 使用测试集(通常是有标签的数据)评估模型性能。
  • 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。

7. 调优和部署

  • 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
  • 结合业务需求,将最终模型部署到实际应用中。

关键步骤:

  1. 初始化模型:首先使用有标签数据训练模型。
  2. 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。
  3. 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。
  4. 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt# 简化的神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 缩小卷积层的输出通道self.fc1 = nn.Linear(8 * 26 * 26, 10)  # 调整全连接层的输入和输出尺寸def forward(self, x):x = F.relu(self.conv1(x))x = x.view(x.size(0), -1)  # 展平x = self.fc1(x)return x# 自定义数据集
class CustomDataset(Dataset):def __init__(self, data, labels=None):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):if self.labels is not None:return self.data[idx], self.labels[idx]else:return self.data[idx], -1  # 无标签数据# 半监督训练函数
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):model.train()labeled_loss_value = 0pseudo_loss_value = 0for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)unlabeled_data = unlabeled_data.to(device)# 1. 有标签数据训练optimizer.zero_grad()labeled_output = model(labeled_data)labeled_loss = F.cross_entropy(labeled_output, labeled_labels)labeled_loss.backward()# 2. 无标签数据伪标签生成unlabeled_output = model(unlabeled_data)probs = F.softmax(unlabeled_output, dim=1)max_probs, pseudo_labels = torch.max(probs, dim=1)# 伪标签置信度筛选pseudo_mask = max_probs > threshold  # 置信度大于阈值的数据作为伪标签if pseudo_mask.sum() > 0:pseudo_labels = pseudo_labels[pseudo_mask]unlabeled_data_pseudo = unlabeled_data[pseudo_mask]# 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)optimizer.zero_grad()  # 清除之前的梯度pseudo_output = model(unlabeled_data_pseudo)pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)pseudo_loss.backward()  # 计算反向梯度optimizer.step()  # 更新模型参数# 累加损失用于展示labeled_loss_value += labeled_loss.item()if pseudo_mask.sum() > 0:pseudo_loss_value += pseudo_loss.item()return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)# 模拟数据
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28)  # 28x28 灰度图像
num_classes = 10labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小# 模型、优化器和设备设置
device = torch.device("cpu")  # 临时使用 CPU
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程并记录损失
num_epochs = 10
labeled_losses = []
pseudo_losses = []for epoch in range(num_epochs):labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)labeled_losses.append(labeled_loss)pseudo_losses.append(pseudo_loss)print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")# 绘制损失曲线
plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()# 展示伪标签生成效果(可视化一些样本的伪标签预测结果)
model.eval()
with torch.no_grad():sample_unlabeled_data = unlabeled_data[:10].to(device)output = model(sample_unlabeled_data)probs = F.softmax(output, dim=1)_, predicted_labels = torch.max(probs, dim=1)# 展示预测的标签print("Generated Pseudo Labels for Samples:")print(predicted_labels)# 假设这些是伪标签预测的图片fig, axes = plt.subplots(2, 5, figsize=(12, 5))for i, ax in enumerate(axes.flat):# 将tensor转换为NumPy数组img = sample_unlabeled_data[i].cpu().numpy().squeeze()  # 转为NumPy数组ax.imshow(img, cmap='gray')  # 使用灰度显示图像ax.set_title(f"Pred: {predicted_labels[i].item()}")ax.axis('off')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/9827.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

白嫖DeepSeek:一分钟完成本地部署AI

1. 必备软件 LM-Studio 大模型客户端DeepSeek-R1 模型文件 LM-Studio 是一个支持众多流行模型的AI客户端,DeepSeek是最新流行的堪比GPT-o1的开源AI大模型。 2. 下载软件和模型文件 2.1 下载LM-Studio 官方网址:https://lmstudio.ai 打开官网&#x…

冲刺蓝桥杯之速通vector!!!!!

文章目录 知识点创建增删查改 习题1习题2习题3习题4:习题5: 知识点 C的STL提供已经封装好的容器vector,也可叫做可变长的数组,vector底层就是自动扩容的顺序表,其中的增删查改已经封装好 创建 const int N30; vecto…

mysql_init和mysql_real_connect的形象化认识

解析总结 1. mysql_init 的作用 mysql_init 用于初始化一个 MYSQL 结构体,为后续数据库连接和操作做准备。该结构体存储连接配置及状态信息,是 MySQL C API 的核心句柄。 示例: MYSQL *conn mysql_init(NULL); // 初始化连接句柄2. mysql_…

C++中常用的排序方法之——冒泡排序

成长路上不孤单😊😊😊😊😊😊 【14后😊///计算机爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于C中常用的排序方法之——冒泡排序的…

ARM嵌入式学习--第十天(UART)

--UART介绍 UART(Universal Asynchonous Receiver and Transmitter)通用异步接收器,是一种通用串行数据总线,用于异步通信。该总线双向通信,可以实现全双工传输和接收。在嵌入式设计中,UART用来与PC进行通信,包括与监控…

解锁微服务:五大进阶业务场景深度剖析

目录 医疗行业:智能诊疗的加速引擎 电商领域:数据依赖的破局之道 金融行业:运维可观测性的提升之路 物流行业:智慧物流的创新架构 综合业务:服务依赖的优化策略 医疗行业:智能诊疗的加速引擎 在医疗行业迈…

基于Flask的旅游系统的设计与实现

【Flask】基于Flask的旅游系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统采用Python作为后端开发语言,结合前端Bootstrap框架,为用户提供了丰富…

《HelloGitHub》第 106 期

兴趣是最好的老师,HelloGitHub 让你对编程感兴趣! 简介 HelloGitHub 分享 GitHub 上有趣、入门级的开源项目。 github.com/521xueweihan/HelloGitHub 这里有实战项目、入门教程、黑科技、开源书籍、大厂开源项目等,涵盖多种编程语言 Python、…

一文讲解Java中的BIO、NIO、AIO之间的区别

BIO、NIO、AIO是Java中常见的三种IO模型 BIO:采用阻塞式I/O模型,线程在执行I/O操作时被阻塞,无法处理其他任务,适用于连接数比较少的场景;NIO:采用非阻塞 I/O 模型,线程在等待 I/O 时可执行其…

Linux——网络(tcp)

文章目录 目录 文章目录 前言 一、TCP逻辑 1. 面向连接 三次握手(建立连接) 四次挥手(关闭连接) 2. 可靠性 3. 流量控制 4. 拥塞控制 5. 基于字节流 6. 全双工通信 7. 状态机 8. TCP头部结构 9. TCP的应用场景 二、编写tcp代码函数…

Flutter使用Flavor实现切换环境和多渠道打包

在Android开发中通常我们使用flavor进行多渠道打包,flutter开发中同样有这种方式,不过需要在原生中配置 具体方案其实flutter官网个了相关示例(https://docs.flutter.dev/deployment/flavors),我这里记录一下自己的操作 Android …

MySQL备忘录

MySQL 的一些基础知识记录,包括一些配置文件,cmd命令等 前言 这里使用的MySQL版本是8.0.25 MySQL安装,包括相关配置文件文本内容,相关cmd命令 通过安装包配置环境变量使用cmd管理员权限通过命令安装MySQL 8.0.25 一、安装配置 …

Prompt提示词完整案例:让chatGPT成为“书单推荐”的高手

大家好,我是老六哥,我正在共享使用AI提高工作效率的技巧。欢迎关注我,共同提高使用AI的技能,让AI成功你的个人助理。 许多人可能会跟老六哥一样,有过这样的体验:当我们遇到一个能力出众或对事物有独到见解的…

Maui学习笔记- SQLite简单使用案例02添加详情页

我们继续上一个案例,实现一个可以修改当前用户信息功能。 当用户点击某个信息时,跳转到信息详情页,然后可以点击编辑按钮导航到编辑页面。 创建项目 我们首先在ViewModels目录下创建UserDetailViewModel。 实现从详情信息页面导航到编辑页面…

Linux文件原生操作

Linux 中一切皆文件,那么 Linux 文件是什么? 在 Linux 中的文件 可以是:传统意义上的有序数据集合,即:文件系统中的物理文件 也可以是:设备,管道,内存。。。(Linux 管理的一切对象…

HttpClient学习

目录 一、概述 二、HttpClient依赖介绍 1.导入HttpClient4依赖 2.或者导入HttpClient5依赖 3.二者区别 三、HttpClient发送Get请求和Post请求测试 (一)通过HttpClient发送Get请求 (二)通过HttpClient发送Post请求 一、概述 HttpClient是 Apache 软件基金会提供的一…

【重生之我在学习C语言指针详解】

目录 ​编辑 --------------------------------------begin---------------------------------------- 引言 一、指针基础 1.1 内存地址 1.2 指针变量 1.3 指针声明 1.4 取地址运算符 & 1.5 解引用运算符 *** 二、指针运算 2.1 指针加减运算 2.2 指针关系运算 三…

< OS 有关> BaiduPCS-Go 程序的 菜单脚本 Script: BaiduPCS-Go.Menu.sh (bdgo.sh)

目标: 使用 日本阿里云的 VPM 传输文件。 暂时方案: 使用 主机JPN 下载 https://huggingface.co/ 上模型从 JPN 放到 度狗上在家里从狗度下载 为了减少编程,尽量使用现在软件 ,就找到 GitHub - qjfoidnh/BaiduPCS-Go: iikira…

98.1 AI量化开发:长文本AI金融智能体(Qwen-Long)对金融研报大批量处理与智能分析的实战应用

目录 0. 承前1. 简介1.1 通义千问(Qwen-Long)的长文本处理能力 2. 基础功能实现2.1 文件上传2.2 单文件分析2.3 多文件分析 3. 汇总代码&运行3.1 封装的工具函数3.2 主要功能特点3.3 使用示例3.4 首次运行3.5 运行结果展示 4. 注意事项4.1 文件要求4.2 错误处理机制4.3 最佳…

Linux环境基础开发工具的使用(apt, vim, gcc, g++, gbd, make/Makefile)

目录 什么是软件包 Linux 软件包管理器 apt 认识apt 查找软件包 安装软件 如何实现本地机器和云服务器之间的文件互传 卸载软件 Linux编辑器 - vim vim的基本概念 vim下各模式的切换 vim命令模式下各指令汇总 vim底行模式个指令汇总 Linux编译器 - gcc/g gcc/g的作…