Pytorch 猫狗识别案例

猫狗识别数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88483483?spm=1001.2014.3001.5501

训练集图片路径

测试集图片路径

训练代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time
from torch.optim.lr_scheduler import StepLRif __name__ == '__main__':torch.autograd.set_detect_anomaly(True)mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 32# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)),  # 调整图像大小为 224x224torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.RandomRotation(45),torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化])dataset = torchvision.datasets.ImageFolder('./cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)# x,y = next(iter(val_dataset))# x = x.permute(1, 2, 0)  # 将通道维度调整到最后# x = (x - x.min()) / (x.max() - x.min())  # 反归一化操作# plt.imshow(x)  # 将通道维度调整到最后# plt.axis('off')  # 关闭坐标轴# plt.show()model = models.resnet34(weights=None)num_classes = 2model.fc = nn.Sequential(nn.Dropout(p=0.2),# nn.BatchNorm4d(model.fc.in_features),nn.Linear(model.fc.in_features, num_classes),nn.Sigmoid(),)lambda_L1 = 0.001lambda_L2 = 0.0001regularization_loss_L1 = 0regularization_loss_L2 = 0for name,param in model.named_parameters():param.requires_grad = Trueif 'bias' not in name:regularization_loss_L1 += torch.norm(param, p=1).detach()regularization_loss_L2 += torch.norm(param, p=2).detach()optimizer = optim.Adam(model.parameters(), lr=0.01)scheduler = StepLR(optimizer, step_size=5, gamma=0.9)criterion = nn.BCELoss().to(device)model.to(device)# print(model)loadfilename = "recognize_cats_and_dogs.pt"savefilename = "recognize_cats_and_dogs3.pt"checkpoint = torch.load(loadfilename)model.load_state_dict(checkpoint['model_state_dict'])def save_checkpoint(epoch, model, optimizer, filename, train_loss=0., val_loss=0.):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'train_loss': train_loss,'val_loss': val_loss,}torch.save(checkpoint, filename)num_epochs = 100train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)one_hot = nn.functional.one_hot(labels, num_classes).float()# 计算损失和梯度loss = criterion(outputs, one_hot) + lambda_L1 * regularization_loss_L1 + lambda_L2 * regularization_loss_L2loss.backward()if ((i + 1) % 2 == 0) or (i + 1 == len(train_dataset)):# 更新模型参数optimizer.step()optimizer.zero_grad()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)one_hot = nn.functional.one_hot(labels, num_classes).float()loss = criterion(outputs, one_hot)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timetain_loss = running_loss / len(train_dataset)val_loss = running_loss_test / len(val_dataset)print("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs, epoch_time, tain_loss,accuracy_train, val_loss, accuracy_test))save_checkpoint(epoch, model, optimizer, savefilename, tain_loss, val_loss)scheduler.step()# plt.plot(train_loss, label='Train Loss')# # 添加图例和标签# plt.legend()# plt.xlabel('Epochs')# plt.ylabel('Loss')# plt.title('Training Loss')## # 显示图形# plt.show()

测试代码如下

import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 32transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化])dataset = torchvision.datasets.ImageFolder('./cats_and_dogs_test',transform=transform)test_dataset = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet34()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features,num_classes),nn.LogSoftmax(dim=1))model.to(device)# print(model)filename = "recognize_cats_and_dogs.pt"checkpoint = torch.load(filename)model.load_state_dict(checkpoint['model_state_dict'])class_name = ['cat','dog']# 在测试集上计算准确率with torch.no_grad():for inputs, labels in test_dataset:inputs, labels = inputs.to(device), labels.to(device)output = model(inputs)_, predicted = torch.max(output.data, 1)for x,y,z in zip(inputs,labels,predicted):x = (x - x.min()) / (x.max() - x.min())plt.imshow(x.cpu().permute(1,2,0))plt.axis('off')plt.title('predicted: {0}'.format(class_name[z]))plt.show()

部分测试结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/175827.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于静电放电算法的无人机航迹规划-附代码

基于静电放电算法的无人机航迹规划 文章目录 基于静电放电算法的无人机航迹规划1.静电放电搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要:本文主要介绍利用静电放电算法来优化无人机航迹规划。 …

0基础学习PyFlink——用户自定义函数之UDF

大纲 标量函数入参并非表中一行(Row)入参是表中一行(Row)alias PyFlink中关于用户定义方法有: UDF:用户自定义函数。UDTF:用户自定义表值函数。UDAF:用户自定义聚合函数。UDTAF&…

1400*C. Team(模拟构造)

Problem - 401C - Codeforces 解析&#xff1a; 因为0不能相邻&#xff0c;所以0之间最少 n-1 个位置&#xff0c;最多 n1 个位置&#xff0c;如果 m<n-1显然不符题意。 并且1最多连续两个&#xff0c;所以 m>2*n2 同样不符题意。 其余情况构造即可 #include<bits/st…

【嵌入式】【GIT】如何迁移老的GIF到新的仓库时使用LFS功能并保持LOG不变

一、正常迁移流程 假设有仓库 ssh://old/buildroot-201902 需要迁移到新的仓库 ssh://old/buildroot-201902时,我们可以使用以下命令来完成: # 下载老的仓库 git clone ssh://old/buildroot-201902 # 向新的仓库上传所有的tags git push ssh://new/buildroot-201902 --tag…

【网络安全】Seeker内网穿透追踪定位

Seeker追踪定位对方精确位置 前言一、kali安装二、seeker定位1、ngrok平台注册2、获取一次性邮箱地址3、ngrok平台登录4、ngrok下载5、ngrok令牌授权6、seeker下载7、运行seeker定位8、运行隧道开启监听9、伪装链接10、用户点击&#xff08;获取定位成功&#xff09;11、利用经…

Rust-虽然9天过去了,结果是没有结果(Docker容器的端口映射问题)

​ 这篇文章收录于Rust 实战专栏。这个专栏中的相关代码来自于我开发的笔记系统。它启动于是2023年的9月14日。相关技术栈目前包括&#xff1a;Rust&#xff0c;Javascript。关注我&#xff0c;我会通过这个项目的开发给大家带来相关实战技术的分享。 前言 上上周了吧&#xf…

机器学习(六)构建机器学习模型

1.9构建机器学习模型 我们使用机器学习预测模型的工作流程讲解机器学习系统整套处理过程。 整个过程包括了数据预处理、模型学习、模型验证及模型预测。其中数据预处理包含了对数据的基本处理&#xff0c;包括特征抽取及缩放、特征选择、特征降维和特征抽样&#xff1b;我们将…

Linux的简介和环境搭建

简介 Linux是一套免费使用和自由传播的类Unix操作系统&#xff0c;是一个基于POSIX和Unix的多用户、多任务、支持多线程和多CPU的操作系统。它能运行主要的Unix工具软件、应用程序和网络协议。它支持32位和64位硬件。Linux继承了Unix以网络为核心的设计思想&#xff0c;是一个…

Python构造代理IP池提高访问量

目录 前言 一、代理IP是什么 二、代理IP池是什么 三、如何构建代理 IP 池 1. 从网上获取代理 IP 地址 2. 对 IP 地址进行筛选 3. 使用筛选出来的 IP 地址进行数据的爬取 四、总结 前言 爬虫程序是批量获取互联网上的信息的重要工具&#xff0c;在访问目标网站时需要频…

#stm32整理(一)flash读写

以这篇未开始我将进行stm32学习整理为期一个月左右完成stm32知识学习整理内容顺序没有一定之规写到哪想到哪想到哪写到哪&#xff0c;主要是扫除自己知识上的盲区完成一些基本外设操作。 以stm32f07为例子进行flash读写操作 stm32flash简介 参考资料正点原子和野火开发手册 …

如何选择传感器输出模式——电流输出还是电压输出?

一 背景及挑战 传感器在汽车测试系统中发挥着采集和传输信息的作用&#xff0c;可称之为汽车的“神经元”。 按照功能可将传感器分为压力传感器、流量传感器、温湿度传感器和电流传感器等。传感器的主要指标是精度、测量范围和响应时间等。在满足指标的情况下&#xff0c;通常…

笔记软件Notability mac中文版软件功能

Notability mac是一款帮助用户备注文件的得力工具&#xff0c;Notability Mac版可用于注释文稿、草拟想法、录制演讲、记录备注等。它将键入、手写、录音和照片结合在一起&#xff0c;便于您根据需要创建相应的备注。 Mac Notability mac中文版软件功能 将手写&#xff0c;照片…

springboot和flask整合nacos,使用openfeign实现服务调用,使用gateway实现网关的搭建(附带jwt续约的实现)

环境准备&#xff1a; 插件版本jdk21springboot 3.0.11 springcloud 2022.0.4 springcloudalibaba 2022.0.0.0 nacos2.2.3&#xff08;稳定版&#xff09;python3.8 nacos部署&#xff08;docker&#xff09; 先创建目录&#xff0c;分别创建config&#xff0c;logs&#xf…

APISpace 全国快递物流地图轨迹查询API接口案例代码

1.全国快递物流地图轨迹查询接口详解 1.1 接口请求 请求方式&#xff1a;POST请求地址&#xff1a;https://eolink.o.apispace.com/wldtgj1/paidtobuy_api/trace_map请求头&#xff1a; 标签必填说明X-APISpace-Token是鉴权私钥&#xff0c;登陆 APISpace 后在管理后台的[访…

证照之星XE专业版下载专业证件照制作工具

值得肯定的是智能背景替换功能&#xff0c;轻松解决背景处理这一世界难题。不得不提及的是新增打印字体设置&#xff0c;包含字体选择、字号大小、字体颜色等。不同领域的应用证明了万能制作&#xff0c;系统支持自定义证照规格&#xff0c;并预设了17种常用的证件照规格。人所…

紧急:发现NGINX Ingress Controller for Kubernetes中的新安全漏洞

导语 大家好&#xff0c;今天我要向大家紧急报告一则消息&#xff1a;我们在NGINX Ingress Controller for Kubernetes中发现了三个新的安全漏洞&#xff01;这些漏洞可能被黑客利用&#xff0c;从集群中窃取机密凭据。在本文中&#xff0c;我们将详细介绍这些漏洞的细节&#…

ROCKCHIP ~ Camera 闪光灯

一、闪光灯基本原理 工作模式 Camera flash led分flash和torch两种模式。 flash&#xff1a; 拍照时上光灯瞬间亮一下&#xff0c;电流比较大&#xff0c;目前是1000mA&#xff0c;最大电流不能超过led最大承受能力 torch&#xff1a; 只用于录video或者拿led当手电筒的情况&…

python多环境并存

1. 现况简介 1.1 本人windows所存Python版本 Python 2.7 Python 3.6 Python 3.7 1.2 Python 各版本路径如下 Python 2.7Python 3.6Python 3.7C:\Server\Python27C:\Server\Python36C:\Server\Python37 1.3 系统环境变量配置如下 2. 解决方案 2.1 进入目录 cd C:\Server…

qt5工程打包成可执行exe程序

一、编译生成.exe 1.1、在release模式下编译生成.exe 1.2、建一个空白文件夹package&#xff0c;再将在release模式下生成的.exe文件复制到新建的文件夹中package。 1.3、打开QT5的命令行 1.4、用命令行进入新建文件夹package&#xff0c;使用windeployqt对生成的exe文件进行动…

84.在排序数组中查找元素的第一个和最后一个位置(力扣)

目录 问题描述 代码解决以及思想 知识点 问题描述 代码解决以及思想 class Solution { public:vector<int> searchRange(vector<int>& nums, int target) {int left 0; // 定义左边界int right nums.size() - 1; // 定义右…