实验13 使用预训练resnet18实现CIFAR-10分类

1.数据预处理

首先利用函数transforms.Compose定义了一个预处理函数transform,里面定义了两种操作,一个是将图像转换为Tensor,一个是对图像进行标准化。然后利用函数torchvision.datasets.CIFAR10下载数据集,这个函数有四个常见的初始化参数:root为数据存储的路径,如果数据已经下载,会直接从这个路径加载数据。train如果为True,表示加载训练集,train如果为False,加载测试集。download如果设置为True,表示如果本地不存在数据集,会自动从互联网上下载。transform指定一个转换函数,对数据进行预处理和数据增强等操作。所以下载训练集train_full时,train赋值为True,下载测试集时,train赋值为False。之后对下载的训练集train_full进行划分,先规定指定的大小,然后利用random_split进行划分,最后就是创建Dataloader,batch_size设为64,得到train_loader,val_loader,test_loader。

代码:

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 数据预处理和增强
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 图像标准化
])# 下载 CIFAR-10 数据集
train_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 划分训练集(40,000)和验证集(10,000)
train_size = int(0.8 * len(train_full))  # 80% 用于训练
val_size = len(train_full) - train_size  # 剩余 20% 用于验证
train_data, val_data = random_split(train_full, [train_size, val_size])# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test, batch_size=64, shuffle=False)

2.模型构建

模型构建就比较简单,直接使用使用pytorch定义的库函数,只有一行代码:

model = models.resnet18(pretrained=False),pretrained=False表示不使用在Imagenet上预训练的权重,pretrained=True表示使用在Imagenet上预训练的权重。因为这个模型是训练Imagenet构建的模型,要想让这个模型适应新任务,需要获取最后一层的输入特征数,然后利用一个全连接层将输出改为10。

代码:

# 初始化 ResNet-18 模型
model = models.resnet18(pretrained=True)
# 修改最后一层(全连接层),适应新的任务
num_ftrs = model.fc.in_features  # 获取最后一层的输入特征数
model.fc = torch.nn.Linear(num_ftrs, 10)  # 将输出改为 10 个类别(例如 CIFAR-10)

3.模型训练

创建Runner类,管理训练、评估、测试和预测过程。还是之前的一套东西,首先是一个init函数,用于初始化数据集、损失函数、优化器等。train函数用于计算在训练集上的loss,并反向传播更新参数。evaluate函数用于计算在验证集上的损失,不用反向传播更新模型的参数,同时根据evaluate函数得到的损失判断是否保存最优模型,利用state_dict函数保存最优模型。test函数首先加载最优模型,然后在测试集计算最优模型的准确率。predict函数预测某个图像属于某个类别的概率,虽然resnet最后一层没有softmax,但是也可以根据最后一层得到的10个logits(未经过归一化的原始输出)取最大来判断图像属于某一类(因为这10个值也是有大小关系的,softmax函数不会修改这10个值的大小关系)。

定义学习率=0.01、批次大小=30、损失函数为交叉熵损失nn.CrossEntropyLoss()、优化器为Adam。

实例化Runner,调用train函数,开始训练。

代码:

class Runner:def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, device):self.model = model.to(device)  # 将模型移到GPUself.train_loader = train_loaderself.val_loader = val_loaderself.test_loader = test_loaderself.criterion = criterionself.optimizer = optimizerself.device = deviceself.best_model = Noneself.best_val_loss = float('inf')self.train_losses = []  # 存储训练损失self.val_losses = []  # 存储验证损失def train(self, epochs=10):for epoch in range(epochs):self.model.train()running_loss = 0.0for inputs, labels in self.train_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs)loss = self.criterion(outputs, labels)loss.backward()self.optimizer.step()running_loss += loss.item()# 计算平均训练损失train_loss = running_loss / len(self.train_loader)self.train_losses.append(train_loss)# 计算验证集上的损失val_loss = self.evaluate()self.val_losses.append(val_loss)print(f'Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')# 如果验证集上的损失最小,保存模型if val_loss < self.best_val_loss:self.best_val_loss = val_lossself.best_model = self.model.state_dict()def evaluate(self):self.model.eval()val_loss = 0.0with torch.no_grad():for inputs, labels in self.val_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)outputs = self.model(inputs)loss = self.criterion(outputs, labels)val_loss += loss.item()return val_loss / len(self.val_loader)def test(self):self.model.load_state_dict(self.best_model)self.model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in self.test_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)outputs = self.model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_accuracy = correct / totalprint(f'Test Accuracy: {test_accuracy:.4f}')def predict(self, image):self.model.eval()image = image.to(self.device)  # 将图像移到GPUwith torch.no_grad():output = self.model(image)_, predicted = torch.max(output, 1)return predicted.item()def visualize_and_predict(self, index):"""针对训练集中的某一张图片进行预测,并可视化图片。:param index: 训练集中的图片索引"""# 获取训练集中的第 index 张图片image, label = self.train_loader.dataset[index]# 将图像移到GPU(如果需要)image = image.unsqueeze(0).to(self.device)  # 增加一个维度作为batch size# 可视化图像plt.imshow(image.cpu().squeeze().numpy(), cmap='gray')  # 假设是灰度图,若是彩色图像要调整plt.title(f"True Label: {label}")plt.show()# 预测该图片的类别predicted_label = self.predict(image)print(f"Predicted Label: {predicted_label}")
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 实例化Runner类
runner = Runner(model, train_loader, val_loader, test_loader, criterion, optimizer, device)# 训练模型
runner.train(epochs=30)
# 绘制损失曲线
plt.figure(figsize=(10, 6))
plt.plot(runner.train_losses, label='Train Loss')
plt.plot(runner.val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.grid()
plt.show()

4.模型评价

调用test函数,计算在测试集上的准确率。

代码:

# 在最优模型上评估测试集准确率
runner.test()

5.模型预测

在训练集任意选取一个图像,获取图像的image和标签label,因为图像已经经过了transform的变换,所以这个图像不需要transform,只需要添加一个维度1作为batch_size,可视化图像和真实标签,然后调用predict函数进行预测,输出真实类别。

代码:

# CIFAR-10 是 RGB 图像,确保正确显示
# 将 Tensor 转换为 numpy 数组并调整维度顺序为 HWC (Height, Width, Channels)
image_np = image.numpy().transpose((1, 2, 0))  # 从 CHW 转为 HWC# 可视化图像
plt.imshow(image_np)
plt.title(f"True Label: {label}")
plt.show()# 直接将图像传递给预测函数,不再需要 transform
# 但是要确保图像传入时是正确的 batch size 形状,即增加一个 batch 维度
image_transformed = image.unsqueeze(0).to(device)  # 增加一个维度作为 batch size# 预测该图片的类别
predicted_label = runner.predict(image_transformed)
print(f"Predicted Label: {predicted_label}")

6.实验结果与分析

不使用预训练权重的损失变化、准确率和预测结果

使用预训练权重的损失变化、准确率和预测结果

通过观察损失变化,我们发现两个模型在训练集上的loss一直在减小,说明模型的参数一直在更新。但是在验证集上的损失一开始是下降的,但是后来不断增大,我觉得是因为模型过拟合了。但是可以发现在没有预训练权重上的最优验证损失是比有预训练权重的模型上的最优验证损失大的。通过保存最优模型,在最优模型上计算准确率,发现在没有预训练权重的模型得到的准确率是0.7332,在使用预训练权重的模型得到的准确率是0.7431。

结论:通过对比在验证集上的最优验证损失和在测试集上的准确率,得到结论使用了预训练的模型效果要更好。

7.总结与心得体会

总结:

1.预训练模型:

预训练模型是指在一个大规模数据集上(如 ImageNet、COCO 等)经过训练的模型。这个模型已经学习到了一些通用的特征,比如图像中的边缘、纹理、颜色、形状等,或者文本中的语法、词汇关系等。这些特征是从数据中自动学习的,并且在很多不同的任务中都有用。

例子:

在图像分类任务中,ResNet、VGG、Inception 等深度神经网络在 ImageNet 上经过训练后,它们可以识别成千上万种不同的物体。由于这些物体特征具有广泛的普适性,我们可以将这些模型用于其他图像分类任务(例如 Cifar-10、Cifar-100),而无需从头开始训练。

在自然语言处理(NLP)中,像 BERT、GPT 等预训练语言模型已经在大量的文本数据上训练过,学习了丰富的语言知识。因此,我们可以将这些模型应用于文本分类、情感分析、问答等任务。

预训练模型的优势:

节省计算资源:训练深度神经网络需要大量的计算资源和时间,尤其是在大规模数据集上。通过使用预训练模型,用户可以避免从零开始训练,直接利用现成的知识。

提高效果:预训练模型已经学习到了一些通用的特征,可以加速学习过程,并且通常能够取得比从头开始训练更好的效果。
2. 迁移学习(Transfer Learning)

迁移学习是一种利用在一个任务上学到的知识,来帮助在另一个相关任务上进行学习的技术。换句话说,它将一个任务中的学习成果迁移到另一个任务中,特别是在目标任务的数据较少时。

迁移学习的核心思想是:如果一个模型在某个任务上已经学到了一些有用的特征,那么这些特征可以迁移到另一个任务上,帮助模型更好地学习。

迁移学习的典型流程:

模型加载:加载一个在大数据集上预训练的模型(如 ResNet、VGG、BERT 等)。

模型微调:对模型的部分层进行微调,或者只训练新添加的层(如分类层)。

应用于新任务:将经过微调的模型应用于新的、可能较小的数据集。

迁移学习的类型

迁移学习有多种不同的方式,常见的有以下几种:

微调(Fine-Tuning):使用预训练模型的权重,并对某些层或整个模型进行微调,以适应新的任务和数据。

通常会冻结前几层(因为它们学习的是通用特征),只训练后几层(专门针对当前任务)。

特征提取(Feature Extraction):使用预训练模型的特征提取能力,将前几层的权重固定,不更新,仅训练新加的全连接层或输出层。

零-shot 学习:在一些任务中,预训练模型被直接应用到目标任务,而不进行微调,特别是当目标任务的标注数据非常少时。

迁移学习的应用:

计算机视觉:在一个大规模的数据集(如 ImageNet)上训练的模型可以用于许多不同的图像分类任务,例如识别猫、狗、车、飞机等物体,或者在医疗影像、无人驾驶等领域中应用。

自然语言处理(NLP):例如,BERT 和 GPT 等模型可以在情感分析、命名实体识别、机器翻译等任务上进行迁移学习。

3. 预训练模型和迁移学习的关系

预训练模型和迁移学习是紧密相关的。迁移学习通常依赖于预训练模型,使用在一个任务中学到的知识来帮助另一个任务。在迁移学习中,预训练模型提供了一个良好的起点,减少了从头开始训练的难度和所需的数据量。

预训练模型与迁移学习的关系:

预训练模型是迁移学习的基础,因为迁移学习的一个关键步骤是使用已经在其他任务上训练好的模型。

迁移学习则是使用这些预训练模型的技术,它通过微调或特征提取等方式,将预训练模型的知识应用到新任务中。

使用torchvision.datasets的常见参数:

root:数据存储的路径。如果数据已经下载,它会直接从该路径加载数据。

train:如果设置为 True,加载训练集;如果设置为 False,加载测试集。

download:如果设置为 True,如果本地不存在数据集,它会自动从互联网上下载。

transform:指定一个转换函数,对数据进行预处理和数据增强等操作。

transforms.Compose 是 torchvision.transforms 模块中的一个函数,用于将多个图像预处理操作组合成一个复合操作。在神经网络训练中,常常需要对输入图像进行多种预处理,例如将图像转换为张量(Tensor)、标准化、数据增强等。transforms.Compose 允许你将这些操作按顺序组合在一起,并一次性应用于输入图像。

心得体会:

这个实验直接调用预训练的resnet18进行CIFAR-10数据集的分类,因为这个模型是在Imagenet数据集上训练得到的,所以适用于新的任务需要微调模型。通过对比没有预训练权重的模型和有预训练权重的模型的训练效果,发现还是有预训练权重得到的结果比较好,因为预训练模型已经学习到了一些通用的特征,可以加速学习过程,通常能够取得比从头开始训练更好的效果。在实际应用中在理解模型内部实现的基础上,直接调用高层API是一个不错的选择,可以减少代码量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/483239.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【AI系统】代数简化

代数简化 代数简化&#xff08;Algebraic Reduced&#xff09;是一种从数学上来指导我们优化计算图的方法。其目的是利用交换率、结合律等规律调整图中算子的执行顺序&#xff0c;或者删除不必要的算子&#xff0c;以提高图整体的计算效率。 代数化简可以通过子图替换的方式完…

多人聊天室项目 BIO模型实现

BIO模型聊天室项目大体设计 BIO编程模型 Acceptor是服务器端负责监听具体端口的Socket每有一个客户端Client连接到服务器端&#xff0c;Acceptor就创建一个新的线程Handler来处理客户端发送的消息每一个客户端都有一个唯一的Handler来对应处理其事务为保证线程安全&#xff0c…

腾讯云平台 - Stable Diffusion WebUI 下载模型

1&#xff09;进入控制台&#xff0c;点击算力连接 》 JupyterLab 2&#xff09;进入模型目录&#xff08;双击&#xff09; 3&#xff09;上传模型 例如&#xff1a;我要上传大模型

夜神模拟器+Charles+postern+Mgisk+TrustMeAlready实现抓包

[实测有用]夜神模拟器CharlesposternMgiskTrustMeAlready实现抓包 PS:此贴仅做为技术交流,禁止非法用途。 1.初始化条件 A.安装MUMU模拟器安卓12版本 B.按图示选择&#xff0c;设置好代理端口8889 C.查看本机IP地址 D.导出证书&#xff0c;安装配置&#xff0c;暂时保存…

【closerAI ComfyUI】物体转移术之图案转移,Flux三重控制万物一致性生图,实现LOGO和图案的精准迁移

更多AI前沿科技资讯,请关注我们:closerAI-一个深入探索前沿人工智能与AIGC领域的资讯平台 closerAIGCcloserAI,一个深入探索前沿人工智能与AIGC领域的资讯平台,我们旨在让AIGC渗入我们的工作与生活中,让我们一起探索AIGC的无限可能性! 【closerAI ComfyUI】物体转移术之图…

新质驱动·科东软件受邀出席2024智能网联+低空经济暨第二届湾区汽车T9+N闭门会议

为推进广东省加快发展新质生产力&#xff0c;贯彻落实“百县千镇万村高质量发展工程”&#xff0c;推动韶关市新丰县智能网联新能源汽车、低空经济与数字技术的创新与发展&#xff0c;充分发挥湾区汽车产业链头部企业的带动作用。韶关市指导、珠三角湾区智能网联新能源汽车产业…

vue+mars3d给影像底图叠加炫酷效果

话不多说&#xff0c;直接上效果图&#xff1a; 在这里墙体其实是有一个不太明显的流动效果 实现方式&#xff1a;这里我使用了PolylineEntityWallPrimitive&#xff0c;开始我用的是polygonEntity但是发现实现效果并不好&#xff0c;所有直接改用了线 map.vue文件&#xff1…

【模电】常见电路参数计算

1.恒流源输出电阻 2.射极电压跟随器输出电阻 3.差分放大电路 3.1差模特性 3.1.1差模输入电阻Rid 3.1.2差模输出电阻Ro 3.1.3差模电压增益Avd 3.2共模特性 3.2.1共模输入电阻Ric 3.2.2共模电压增益Avc 4.组合放大电路 4.1单级放大器 4.1.1微变等效电路 4.1.1.1共射级 4.1.…

Linux-虚拟环境

文章目录 一. 虚拟机二. 虚拟化软件三. VMware WorkStation四. 安装CentOS操作系统五. 在VMware中导入CentOS虚拟机六. 远程连接Linux系统1. Finalshell安装2. 虚拟机网络配置3. 连接到Linux系统 七. 虚拟机快照 一. 虚拟机 借助虚拟化技术&#xff0c;我们可以在系统中&#…

Kafka如何保证消息可靠?

大家好&#xff0c;我是锋哥。今天分享关于【Kafka如何保证消息可靠&#xff1f;】面试题。希望对大家有帮助&#xff1b; Kafka如何保证消息可靠&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Kafka通过多种机制来确保消息的可靠性&#xff0c;主要包…

ONVIF协议网络摄像机客户端使用gsoap获取RTSP流地址GStreamer拉流播放

什么是ONVIF协议 ONVIF&#xff08;开放式网络视频接口论坛&#xff09;是一个全球性的开放式行业论坛&#xff0c;旨在促进开发和使用基于物理IP的安全产品接口的全球开放标准。 ONVIF规范的目标是建立一个网络视频框架协议&#xff0c;使不同厂商生产的网络视频产品完全互通。…

javaweb_Day05

1.请求响应 1.1 概述 1.2 请求 1.2.1 请求参数 1.2.2 响应 2.分层解耦 2.1 三层架构 &#xff08;1&#xff09;代码分层 2.2 分层解耦 2.3 IOC&DI入门 &#xff08;1&#xff09;控制反转IOC &#xff08;2&#xff09;依赖注入DI &#xff08;3&#xff09;汇总 …

Stable Diffusion 3详解

&#x1f33a;系列文章推荐&#x1f33a; 扩散模型系列文章正在持续的更新&#xff0c;更新节奏如下&#xff0c;先更新SD模型讲解&#xff0c;再更新相关的微调方法文章&#xff0c;敬请期待&#xff01;&#xff01;&#xff01;&#xff08;本文及其之前的文章均已更新&…

[VUE]框架网页开发02-如何打包Vue.js框架网页并在服务器中通过Tomcat启动

在现代Web开发中&#xff0c;Vue.js已经成为前端开发的热门选择之一。然而&#xff0c;将Vue.js项目打包并部署到生产环境可能会让一些开发者感到困惑。本文将详细介绍如何将Vue.js项目打包&#xff0c;并通过Tomcat服务器启动运行。 1. 准备工作 确保你的项目能够正常运行,项…

网络分层模型( OSI、TCP/IP、五层协议)

1、网络分层模型 计算机网络是一个极其复杂的系统。想象一下最简单的情况&#xff1a;两台连接在网络上的计算机需要相互传输文件。不仅需要确保存在一条传输数据的通路&#xff0c;还需要完成以下几项工作&#xff1a; 发起通信的计算机必须激活数据通路&#xff0c;这包括发…

采药 刷题笔记 (动态规划)0/1背包

P1048 [NOIP2005 普及组] 采药 - 洛谷 | 计算机科学教育新生态 动态规划 0/1背包 的本质在于继承 一行一行更新 上一行是考虑前i个物品的最优情况 当前行是考虑第i1个物品的情况 当前行的最优解 来自上一行和前i个物品的最优解进行比较 如果当前装了当前物品&#xff…

汽车操作系统详解

目录 1. 车控汽车操作系统 2. 车载汽车操作系统 3. OEM定制操作系统 刚开始工作的时候&#xff0c;接触的是汽车控制相关的开发工作&#xff0c;天真地以为汽车操作系统就是指实时操作系统&#xff0c;例如FreeRTOS、OSEK OS、AUTOSAR OS等等&#xff1b;然而&#xff0c;随…

Shire 1.1 发布:更强大的交互支持,升级 AI 智能体与 IDE 的整合体验

在经过多个项目上的试用后&#xff0c;我们进入了持续的修修补补&#xff0c;以及功能的增强阶段。终于&#xff0c;我们发布了 Shire 1.1 版本&#xff0c;这个版本带来了更强大的交互支持&#xff0c; 多功能升级 AI 与 IDE 的整合体验。 交互&#xff1a;丰富与大量 IDE 插件…

Springboot(四十九)SpringBoot3整合jetcache缓存

上文中我们学习了springboot中缓存的基本使用。缓存分为本地caffeine缓存和远程redis缓存。现在有一个小小的问题,我想使用本地caffeine缓存和远程redis缓存组成二级缓存。还想保证他们的一致性,这个事情该怎么办呢? Jetcache框架为我们解决了这个问题。 ‌JetCache‌是一个…

学习笔记052——Spring Boot 自定义 Starter

文章目录 Spring Boot 自定义 Starter1、自定义一个要装载的项目2、创建属性读取类 ServiceProperties3、创建 Service4、创建自动配置类 AutoConfigration5、创建 spring 工程文件6、将项目打成 jar 包7、jar 打包到本地仓库8、配置application.yml Spring Boot 自定义 Starte…