pytorch训练和使用resnet

pytorch训练和使用resnet

使用 CIFAR-10数据集

训练 resnet

resnet-train.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim# 在CIFAR-10数据集中
# 训练集:包含50000张图像,用于训练模型。
# 测试集:包含10000张图像,用于评估模型的性能。
TRAIN_SIZE=50000
TEST_SIZE=10000# 批量大小
BATCH_SIZE=128# 数据预处理
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 使用预训练的ResNet模型 , 不从默认url下载预训练的模型
model = torchvision.models.resnet18(weights=None)
# 从当前路径加载预训练权重
model_path = './model/resnet18-f37072fd.pth'
model.load_state_dict(torch.load(model_path))# 修改最后一层以适应CIFAR-10的10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)# 将模型移到GPU(如果有)
if torch.cuda.is_available() :print('Using GPU')device = torch.device("cuda:0")
else :print('Using CPU')device = torch.device("cpu")   model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)# 训练网络
num_epochs = 50print('start Training')for epoch in range(num_epochs):model.train()running_loss = 0.0#总迭代次数 = 训练集大小 / 批量大小 =  向上取整(TRAIN_SIZE=50000 / BATCH_SIZE=128) = 391 次循环for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播 + 向后传播 + 优化outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印统计信息running_loss += loss.item()if i % 100 == 99:    # 每100个小批量打印一次print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')running_loss = 0.0# 更新学习率scheduler.step()print('Finished Training')# 测试网络
model.eval()
correct = 0
total = 0
with torch.no_grad():# 总迭代次数 = 测试集 / 批量大小 向上取整(TEST_SIZE=10000/BATCH_SIZE=128) = 79 次循环for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy_test = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy_test:.2f}%')# [Epoch 50, Batch 300] loss: 0.142
# Finished Training
# Accuracy of the network on the 10000 test images: 84.53%# 准确率>0.8保存模型
if(accuracy_test > 0.8):print("Accuracy  > 0.8 ,save model")model_path = './model/trained_resnet18_cifar10.pth'torch.save(model.state_dict(), model_path)print(f'Model saved to {model_path}')

使用训练后的 resnet

评估数据
1.jpeg :

请添加图片描述

2.jpeg:

请添加图片描述

restnet-eval.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image# 模型路径
model_path = './model/trained_resnet18_cifar10.pth'# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整图像大小为32x32transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])# 加载预训练的ResNet模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load(model_path))
model.eval()  # 设置模型为评估模式# 将模型移到GPU(如果有)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)def predict_image(image_path):# 加载并预处理图像image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0)  # 添加批次维度image = image.to(device)# 进行预测with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs.data, 1)# 输出预测结果predicted_class = classes[predicted.item()]print(f'Predicted class: {predicted_class}')# img is in classes
predict_image('./data/1.jpeg')# img is not in classes
predict_image('./data/2.jpeg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/453741.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电影院订票选座小程序ssm+论文源码调试讲解

第2章 开发环境与技术 电影院订票选座小程序的编码实现需要搭建一定的环境和使用相应的技术,接下来的内容就是对电影院订票选座小程序用到的技术和工具进行介绍。 2.1 MYSQL数据库 本课题所开发的应用程序在数据操作方面是不可预知的,是经常变动的&…

处理文件上传和进度条的显示(进度条随文件上传进度值变化)

成品效果图&#xff1a; 解决问题&#xff1a;上传文件过大时&#xff0c;等待时间过长&#xff0c;但是进度条却不会动&#xff0c;只会在上传完成之后才会显示上传完成 上传文件的upload.component.html <nz-modal [(nzVisible)]"isVisible" [nzTitle]"文…

python包以及异常、模块、包的综合案例(较难)

1.自定义包 python中模块是一个文件&#xff0c;而包就是一个文件夹 有这个_init_.py就是python包&#xff0c;没有就是简单的文件夹 包的作用&#xff1a;当我们的模块越来越多时&#xff0c;包可以帮助我们管理这些模块&#xff0c;包的作用就是包含多个模块&#xff0c;但包…

【命令操作】信创终端系统上timedatectl命令详解 _ 统信 _ 麒麟 _ 方德

原文链接&#xff1a;【命令操作】信创终端系统上timedatectl命令详解 | 统信 | 麒麟 | 方德 Hello&#xff0c;大家好啊&#xff01;今天给大家带来一篇关于如何在信创终端系统上使用timedatectl命令的详细介绍。timedatectl 是Linux系统中非常实用的时间管理工具&#xff0c;…

学习写作--polyGCL.md

POLYGCL: GRAPH CONTRASTIVE LEARNING VIA LEARNABLE SPECTRAL POLYNOMIAL filters 这篇工作的摘要和引言写的特别好&#xff08;不愧是ICLR spotlight&#xff09; 摘要 第一步&#xff0c;设定背景 Recently, Graph Contrastive Learning (GCL) has achieved significantl…

使用Flask实现本机的模型部署

前言 模型部署是指将大模型运行在专属的计算资源上&#xff0c;使模型在独立的运行环境中高效、可靠地运行&#xff0c;并为业务应用提供推理服务。其目标是将机器学习模型应用于实际业务中&#xff0c;使最终用户或系统能够利用模型的输出&#xff0c;从而发挥其作用。 一、设…

12 django管理系统 - 注册与登录 - 登录

为了演示方便&#xff0c;我就直接使用models里的Admin来演示&#xff0c;不再创建用户模型了。 ok&#xff0c;先做基础配置 首先是在base.html中&#xff0c;新增登录和注册的入口 <ul class"nav navbar-nav navbar-right"><li><a href"/ac…

黑马软件测试第一篇_Linux

Linux 操作系统 说明: 所有硬件设备组装完成后的第⼀一层软件, 能够使⽤用户使⽤用硬件设备的软件 即为操作系统 常见分类 桌⾯面操作系统: Windows/macOS/Linux移动端操作系统: Android(安卓)/iOS(苹果)服务器器操作系统: Linux/Windows Server嵌⼊入式操作系统: Android(底…

linux线程 | 同步与互斥 | 线程池以及知识点补充

前言&#xff1a;本节内容是linux的线程的相关知识。本篇首先会实现一个简易的线程池&#xff0c; 然后再将线程池利用单例的懒汉模式改编一下。 然后再谈一些小的知识点&#xff0c;比如自旋锁&#xff0c; 读者写者问题等等。 那么&#xff0c; 现在开始我们的学习吧。 ps:本…

吴恩达深度学习笔记(6)

正交化 为了提高算法准确率&#xff0c;我们想到的方法 收集更多的训练数据增强样本多样性使用梯度下降将算法使算法训练时间更长换一种优化算法更复杂或者更简单的神经网络利用dropout 或者L2正则化改变网络框架更换激活函数改变隐藏单元个数 为了使有监督机制的学习系统良…

ansible playbooks

文章目录 一&#xff0c;ansible剧本二&#xff0c;ansible playbooks主要特性三&#xff0c;yaml基本语法规则四&#xff0c;剧本playbooks的组成结构五&#xff0c;yaml编写1.示例2.运行playbook2.1 运行2.2 检查yaml文件的语法是否正确2.3 检查tasks任务2.3 检查生效的主机2…

maven创建父子项目

创建父类 创建子模块 添加文件夹 配置tomcat 参考 然后启动项目即可 参考 https://blog.csdn.net/gjtao1130/article/details/115000022

Linux——shell 编程基础

基本介绍 shell 变量 环境变量&#xff08;也叫全局变量&#xff09; 位置参数变量 预定义变量 运算符 条件判断 流程控制 if 单分支&多分支 case 语句 for循环 while 循环 read 读取控制台输入 函数 系统函数 basename 获取文件名 dirname 获取目录路径 自定义函数 综…

DataWhale10月动手实践——Bot应用开发task03学习笔记

一、工作流 1. 工作流的定义 工作流由多个节点组成&#xff0c;这些节点可以包括大语言模型&#xff08;LLM&#xff09;、代码模块、逻辑判断工具、插件等。每个节点需要不同的信息来执行其功能。工作流的核心含义是&#xff1a;对工作流程及其操作步骤之间的业务规则进行抽…

中国信通院联合中国电促会开展电力行业企业开源典型实践案例征集

自2021年被首次写入国家“十四五”规划以来&#xff0c;开源技术发展凭借其平等、开放、协作、共享的优秀创作模式&#xff0c;正持续成为推动数字技术创新、优化软件生产模式、赋能传统行业转型升级、助力企业降本增效的重要引擎。电力是国民经济的重要基础性产业&#xff0c;…

开源神器!CodeFormer:一键去除马赛克,高清修复照片视频

❤️ 如果你也关注大模型与 AI 的发展现状&#xff0c;且对大模型应用开发非常感兴趣&#xff0c;我会快速跟你分享最新的感兴趣的 AI 应用和热点信息&#xff0c;也会不定期分享自己的想法和开源实例&#xff0c;欢迎关注我哦&#xff01; 微信公众号&#xff5c;搜一搜&…

Docker安装Mysql数据库

不同的应用程序可能依赖于不同版本的 MySQL 或具有不同的配置需求。通过 Docker&#xff0c;每个 MySQL 实例都可以运行在独立的容器中&#xff0c;与宿主机以及其他容器的环境相互隔离。这有效避免了因不同应用对 MySQL 版本、依赖库等方面的差异而导致的冲突。例如&#xff0…

盛元广通数字化实验动物中心LIMS综合管理系统

盛元广通数字化实验动物中心LIMS综合管理系统通过集成各种功能&#xff0c;从实验申请、伦理审批、笼位预约、动物采购到开展动物实验、数据归档等全流程智能化管理&#xff0c;保证了实验信息随时可查&#xff0c;管理可视化、流程简单化。实验动物中心采用电脑端、APP和微信小…

LangSplat和3D language fields简略介绍

LangSplat: 3D Language Gaussian Splatting 相关技术拆分解释&#xff1a; 3dgs&#xff1a;伟大无需多言SAM&#xff1a;The Segment Anything Model&#xff0c;是图像分割领域的foundational model&#xff0c;已经用在很多视觉任务上&#xff08;如图像修复、物体追踪、图…

Linux目录

一、虚拟机环境配置 1.安装虚拟机 安装步骤 新建虚拟机-->典型安装-->选择稍后安装操作系统-->选择系统类型和版本&#xff08;这里安装的是CentOS7 64位&#xff09;-->选择虚拟机文件路径&#xff08;建议每台虚拟机单独存放并且路径不要有中文&#xff09;--&…