4.训练篇2-毕设篇

resnet

# 1. 从 torchvision 中加载预训练的 ResNet18 模型
# pretrained=True 表示使用在 ImageNet 上预训练过的参数,学习效果更好
base_model_resnet18 = models.resnet18(pretrained=True)# 2. 获取 ResNet18 模型中全连接层(fc)的输入特征数
# 这是为了方便替换成我们自己任务的输出类别数
num_ftrs = base_model_resnet18.fc.in_features# 3. 替换原来的全连接层
# 原本的 fc 层是用来预测 1000 类(ImageNet),现在我们改成自己项目的 num_classes 类
# 比如 ASL 手势识别是 29 类,就写 nn.Linear(num_ftrs, 29)
base_model_resnet18.fc = nn.Linear(num_ftrs, num_classes)# 4. 把模型移动到 GPU 或 CPU 上进行训练
# device 变量一般是提前设置好的,比如:device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_resnet18 = base_model_resnet18.to(device)# 5. 使用模型对一批图像做预测
# 假设 b_img_rgb 是一个 batch 的图像张量(例如大小为 [64, 3, 224, 224]),表示 64 张 RGB 图
# .to(device) 表示把图像也放到和模型相同的设备上(GPU/CPU)
# 调用模型,相当于做前向传播,输出每张图片在 29 个类别上的得分(logits)
base_model_resnet18(b_img_rgb.to(device)).shape

输出torch.Size([64, 29])
表示:模型为每张图像输出了一个 长度为 29 的向量,每个数字代表这张图在某一类上的预测得分(不是概率,还没 softmax)。

base_model_resnet18 = models.resnet18(pretrained=True)

创建一个基于 ResNet18 的图像分类模型,并加载预训练参数,把它存到变量 base_model_resnet18中

loss_fn = nn.CrossEntropyLoss()

# 定义一个“交叉熵损失函数”(Cross Entropy Loss)
# 这个函数专门用于“分类问题”(比如手势识别有 29 个类别)
# 它会比较:
#   - 模型输出的预测结果(如:[0.1, 0.2, ..., 0.05])
#   - 和真实的标签(如:第 3 类)
# 然后计算两者差距,差距越小越好,模型就越准确。
 

optimizer = torch.optim.SGD(base_model_resnet18.parameters(), lr=1e-3)
# 定义一个优化器,用来更新模型参数,让 loss 更小
# 使用的是 “随机梯度下降(SGD)” 优化方法
# 参数解释:
#   base_model_resnet18.parameters():告诉优化器要优化哪些参数(就是模型的全部参数)
#   lr=1e-3:学习率(learning rate),表示每次更新的步子有多大,这里是 0.001

loss_fn = nn.CrossEntropyLoss()定义分类任务用的损失函数,用来衡量“模型预测”和“真实标签”的差距
optimizer = torch.optim.SGD(...)定义优化器,训练过程中帮你更新模型参数,让模型学得更好

# 设置训练轮数(epoch 表示:把整个训练集过一遍)
epochs = 25  # 一共训练 25 轮# 创建空列表,用来保存每一轮的训练/测试损失和准确率(后面可以画图)
train_loss_list = []  # 存每一轮训练集的 loss
train_acc_list = []   # 存每一轮训练集的准确率
test_loss_list = []   # 存每一轮测试集的 loss
test_acc_list = []    # 存每一轮测试集的准确率# 开始训练循环,共执行 epochs 次
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")  # 打印当前是第几轮训练# ---------- 训练模型 ----------# 调用你自定义的 train() 函数,执行一轮训练# 它会对 base_model_resnet18 模型进行训练,使用指定的 loss 函数和优化器train(train_dataloader, base_model_resnet18, loss_fn, optimizer)# ---------- 评估模型 ----------# 在训练集上测试模型效果,获取当前的 loss 和 正确率train_loss, train_correct = test(train_dataloader, base_model_resnet18, loss_fn)# 在测试集(验证集)上测试模型效果,获取 loss 和 正确率test_loss, test_correct = test(test_dataloader, base_model_resnet18, loss_fn)# ---------- 保存数据 ----------# 把每一轮的损失和准确率保存到列表中,后面可以画图分析训练效果train_loss_list.append(train_loss)train_acc_list.append(train_correct)test_loss_list.append(test_loss)test_acc_list.append(test_correct)# 所有训练轮次完成
print("Done!")

# 把训练集准确率记录保存为 acc
acc = train_acc_list# 把测试集准确率记录保存为 val_acc(val 表示 validation)
val_acc = test_acc_list# 把训练集损失记录保存为 loss
loss = train_loss_list# 把测试集损失记录保存为 val_loss
val_loss = test_loss_list# 创建一个迭代次数(epoch)的范围,比如 range(25) 表示从 0 到 24
epochs_range = range(epochs)# 设置画布大小为 8x8 英寸
plt.figure(figsize=(8, 8))# 画第一个子图(1行2列的第1个图):准确率曲线
plt.subplot(1, 2, 1)  # 行数=1,列数=2,这是第1个图
plt.plot(epochs_range, acc, label='Training Accuracy')       # 训练集准确率折线图
plt.plot(epochs_range, val_acc, label='Validation Accuracy') # 验证集准确率折线图
plt.legend(loc='lower right')  # 设置图例显示在右下角
plt.title('Training and Validation Accuracy')  # 设置标题# 画第二个子图(1行2列的第2个图):损失曲线
plt.subplot(1, 2, 2)  # 行数=1,列数=2,这是第2个图
plt.plot(epochs_range, loss, label='Training Loss')       # 训练集损失折线图
plt.plot(epochs_range, val_loss, label='Validation Loss') # 验证集损失折线图
plt.legend(loc='upper right')  # 图例显示在右上角
plt.title('Training and Validation Loss')  # 设置标题# 显示整个图像
plt.show()

# 创建空列表:用于保存最终的预测标签、真实标签、预测概率
predict_list = []        # 保存预测标签(整数类编号)
label_list = []          # 保存真实标签(整数类编号)
predict_pro_list = []    # 保存预测概率(softmax 后的概率)# 创建 softmax 层,将模型的输出 logits 转换为概率分布(每一类的可能性)
m_softmax = nn.Softmax(dim=1)  # dim=1 表示在每一行上做 softmax(对每张图片的输出做 softmax)# 遍历测试数据集中的每一个 batch(图像+真实标签)
for (img_rgb, y) in test_dataloader:# 把图像和标签送到和模型一样的设备上(CPU 或 GPU)img_rgb = img_rgb.to(device)y = y.to(device)# 模型对图像进行预测,输出的是“原始得分”(logits)predict_score = base_model_resnet18(img_rgb)# 将原始得分用 softmax 转换为概率predict_pro = m_softmax(predict_score)  # 每张图会得到一个 shape=[num_classes] 的概率向量# 使用 numpy 的 argmax,取概率最大值对应的类别编号作为“预测标签”predict_label = np.argmax(predict_score.detach().cpu().numpy(), axis=1)# 把每个 batch 的 softmax 概率保存到列表中predict_pro_list.append(predict_pro.detach().cpu().numpy())# 把预测标签保存到列表中predict_list.append(predict_label)# 把真实标签也保存到列表中(用于后面比较准确率等)label_list.append(y.detach().cpu().numpy())# 将预测的概率拼接成一个大矩阵(np.vstack 是垂直拼接)
# 然后取第 2 列([:,1]),表示预测为“第2类(index=1)”的概率 —— 适用于二分类
predict_pro_array = np.vstack(predict_pro_list)[:, 1]# 将预测标签列表拼接成一维数组(从多个 batch 拼起来)
predict_array = np.hstack(predict_list)# 将真实标签列表也拼接成一维数组
label_array = np.hstack(label_list)# 打印前 5 个预测概率、预测标签、真实标签,看看模型表现
predict_pro_array[:5], predict_array[:5], label_array[:5]

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/43350.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电磁兼容EMC概述

最近重新学了下电磁兼容,对这个东西更清晰了一些,就重新写了一篇,有不足的地方欢迎的大家在评论区里和我交流。 电磁兼容 电磁兼容指的是什么呢?指的是设备在其电磁环境中性能不受降级地正常运行并不对其他设备造成无法承受的电…

坚持“大客户战略”,昂瑞微深耕全球射频市场

北京昂瑞微电子技术股份有限公司(简称“昂瑞微”)是一家聚焦射频与模拟芯片设计的高新技术企业。随着5G时代的全面到来,智能手机、智能汽车等终端设备对射频前端器件在通信频率、多频段支持、信道带宽及载波聚合等方面提出了更高需求&#xf…

AI赋能职教革新:生成式人工智能(GAI)认证重构技能人才培养新范式

在数字化浪潮的推动下,职业教育正经历着前所未有的变革。面对快速变化的市场需求和技术发展,如何培养具备高技能、高素质的人才成为了职业教育的重要课题。而在这个过程中,人工智能(AI)技术的融入,无疑为职…

Python:日志管理器配置

日志模块组件: 日志器logger:提供应用程序调用的接口 处理器handler:将日志发送到指定的位置 过滤器filter:过滤日志信息 格式器formatter:格式化输出日志 如何配置日志管理器: #导入模块 import log…

城电科技|零碳园区光伏太阳花绽放零碳绿色未来

近日,珠海城电科技自主研发生产的三轴跟踪光伏太阳花在长沙某智慧零碳园区完成安装调试,正式投入运营。作为集“科技能源艺术”于一体的新能源太阳能光伏发电设备,这一创新艺术光伏景观不仅为园区注入绿色动能,更凭借独特的科技美…

c++ - 右击一个cpp文件,但是编译菜单项是灰的

文章目录 c - 右击一个cpp文件,但是编译菜单项是灰的概述END c - 右击一个cpp文件,但是编译菜单项是灰的 概述 VS2019, 整理工程,在编译,工程报错,说有个函数的实现没找到。 有实现part_opt.cpp,头文件也…

29_项目

目录 http.js 1、先注册账号 register.html 2、再登录 login.html 3、首页 index.html 4 详情 details.html cart.html css index.css register.css details.css 演示 进阶 http.js let baseURL "http://localhost:8888"; let resgiterApi baseURL &…

vmware 创建win10 系统,虚拟机NAT网络设置

虚拟机设置: 物理机本机创建桥接: 如何创建桥接,请自行脑补~

API 请求需要证书认证? 如何在 Postman 中正确配置和使用?

本文来介绍 Postman 提供的管理证书功能如何配置,要了解更多相关的知识,可访问 Postman 证书 模块。 管理客户端证书,点击对应的按钮,首先选择 SETTINGS ,然后选择 Certificate 选项卡,如图所示&#xff1…

强大的AI网站推荐(第四集)—— Gamma

网站:Gamma 号称:展示创意的新媒介 博主评价:快速展示创意,重点是展示,在几秒钟内快速生成幻灯片、网站、文档等内容 推荐指数:🌟🌟🌟🌟🌟&#x…

信息学奥赛一本通 1609:【例 4】Cats Transport | 洛谷 CF311B Cats Transport

【题目链接】 ybt 1609:【例 4】Cats Transport 洛谷 CF311B Cats Transport 【题目考点】 1. 动态规划:斜率优化动规 【解题思路】 解法1:设a点的前缀和 输入的 d d d序列是从 d 2 d_2 d2​到 d n d_n dn​,共n-1个数字。人…

从24GHz到71GHz:Sivers半导体的广泛频率范围5G毫米波产品解析

在5G技术的浪潮中,Sivers半导体推出了创新的毫米波无线产品,为通信行业带来高效、可靠的解决方案。这些产品支持从24GHz到71GHz的频率,覆盖许可与非许可频段,适应高速、低延迟的通信场景。 5G通信频段的一点事儿及Sivers毫米波射频…

LocalDateTime序列化总结

版权说明: 本文由CSDN博主keep丶原创,转载请保留此块内容在文首。 原文地址: https://blog.csdn.net/qq_38688267/article/details/146703276 文章目录 1.背景2.序列化介绍常见场景关键问题 3.总体方案4.各场景实现方式WEB接口EasyExcelMybat…

分享一个Pyside6实现web数据展示界面的效果图

今天又是有问题直接找DS的一天,每日一问,今天我的问题是“怎么将pyside6生成的界面转成web界面,使用python语言实现web界面”,等了一会,DS给我提供了两种方案,方案如下: 然后,让我们…

GAMMA数据处理(十)

今天向别人请教了一个问题,刚无意中搜索到了一模一样的问题 不知道这个怎么解决... ok 解决了 有一个GAMMA的命令可转换 但是很奇怪 完全对不上 转换出来的行列号 不知道为啥 再试试 是因为经纬度坐标的小数点位数 de as

[从零开始学习JAVA ] 深入多线程

前言: 当今软件开发领域中,多线程编程已成为一项至关重要的技能。然而,要编写出高效、可靠的多线程程序并不容易。多线程编程面临着许多挑战,如线程安全性、资源共享、死锁等问题。因此,对于初学者来说,深入…

【Python NetworkX】图结构 图绘制

【Python NetworkX】图结构 & 图绘制 1. 简介 & 安装1.1 简介1.2 安装1.3 导入 2. 图2.1 无向图2.2 有向图2.3 重边无向图2.4 重边有向图2.5 图属性 3. 节点3.1 添加节点3.2 移除节点3.3 节点属性3.4 检查节点状态 4. 边4.1 添加边4.2 移除边4.3 边属性4.4 检查边状态 …

Kubernetes》k8s》Containerd 、ctr 、cri、crictl

containerd ctr crictl ctr 是 containerd 的一个客户端工具。 crictl 是 CRI 兼容的容器运行时命令行接口,可以使用它来检查和调试 k8s 节点上的容器运行时和应用程序。 ctr -v 输出的是 containerd 的版本, crictl -v 输出的是当前 k8s 的版本&#x…

【湖北工业大学2025年ACM校赛(同步赛)】题解

比赛链接 A. 蚂蚁上树 题目大意 给定一棵 n n n 个结点的树,根结点为 1 1 1。每个 叶结点 都有一只蚂蚁,每过 1 1 1 秒钟,你可以选一些蚂蚁往其 父结点 走一步,但是要求任意两只蚂蚁都不能在同一个 非根结点 上。 问至少要…

CS2 DEMO导入blender(慢慢更新咯)

流程:cs2-sourcefilmmaker-blender 工具:cs2tools,cs2manager,blender,blender插件sourceio,source2viewer 导入sfm 工具界面 选择这个 sourceio插件 sourceIO其中新版本导入相机路径不见了&#xff0c…