基于Pytorch实现猫狗分类

文章目录

  • 一、环境配置
  • 二、数据集的准备
  • 三、猫狗分类的实例
  • 四、实现分类预测测试
  • 五、参考资料

一、环境配置

  1. 安装Anaconda
    具体安装过程,请自行百度
  2. 配置Pytorch
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchvision
    

二、数据集的准备

  1. 数据集的下载
    kaggle网站的数据集下载地址:
    https://www.kaggle.com/lizhensheng/-2000
    百度网盘下载
    链接:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA
    密码:dmp4
  2. 数据集的分类
    将下载的数据集进行解压操作,然后进行分类
    分类如下(每个文件夹下包括cats和dogs文件夹)
    在这里插入图片描述

三、猫狗分类的实例

  1. 导入相应的库
    # 导入库
    import torch.nn.functional as F
    import torch.optim as optim
    import torch
    import torch.nn as nn
    import torch.nn.parallelimport torch.optim
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    
  2. 设置超参数
    # 设置超参数
    #每次的个数
    BATCH_SIZE = 20
    #迭代次数
    EPOCHS = 10
    #采用cpu还是gpu进行计算
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
  3. 图像处理与图像增强
     # 数据预处理transform = transforms.Compose([transforms.Resize(100),transforms.RandomVerticalFlip(),transforms.RandomCrop(50),transforms.RandomResizedCrop(150),transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
  4. 读取数据集和导入数据
    # 读取数据dataset_train = datasets.ImageFolder('E:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small\\train', transform)print(dataset_train.imgs)# 对应文件夹的labelprint(dataset_train.class_to_idx)dataset_test = datasets.ImageFolder('E:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small\\validation', transform)# 对应文件夹的labelprint(dataset_test.class_to_idx)# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)
    
  5. 定义网络模型
    # 定义网络
    class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3)self.max_pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, 3) self.max_pool2 = nn.MaxPool2d(2) self.conv3 = nn.Conv2d(64, 64, 3) self.conv4 = nn.Conv2d(64, 64, 3) self.max_pool3 = nn.MaxPool2d(2) self.conv5 = nn.Conv2d(64, 128, 3) self.conv6 = nn.Conv2d(128, 128, 3) self.max_pool4 = nn.MaxPool2d(2) self.fc1 = nn.Linear(4608, 512) self.fc2 = nn.Linear(512, 1)def forward(self, x): in_size = x.size(0) x = self.conv1(x) x = F.relu(x) x = self.max_pool1(x) x = self.conv2(x) x = F.relu(x) x = self.max_pool2(x) x = self.conv3(x) x = F.relu(x) x = self.conv4(x) x = F.relu(x) x = self.max_pool3(x) x = self.conv5(x) x = F.relu(x) x = self.conv6(x) x = F.relu(x)x = self.max_pool4(x) # 展开x = x.view(in_size, -1)x = self.fc1(x)x = F.relu(x) x = self.fc2(x) x = torch.sigmoid(x) return xmodellr = 1e-4# 实例化模型并且移动到GPUmodel = ConvNet().to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model.parameters(), lr=modellr)
    
  6. 调整学习率
    def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""modellrnew = modellr * (0.1 ** (epoch // 5)) print("lr:",modellrnew) for param_group in optimizer.param_groups: param_group['lr'] = modellrnew
    
  7. 定义训练过程
    # 定义训练过程
    def train(model, device, train_loader, optimizer, epoch):model.train() for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device).float().unsqueeze(1)optimizer.zero_grad()output = model(data)# print(output)loss = F.binary_cross_entropy(output, target)loss.backward()optimizer.step()if (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))
    # 定义测试过程def val(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device).float().unsqueeze(1)output = model(data)# print(output)test_loss += F.binary_cross_entropy(output, target, reduction='mean').item()pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in output]).to(device)correct += pred.eq(target.long()).sum().item()print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    
  8. 定义保存模型和训练
    # 训练
    for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model, DEVICE, train_loader, optimizer, epoch) val(model, DEVICE, test_loader)torch.save(model, 'E:\\Cat_And_Dog\\kaggle\\model.pth')
    
    训练结果在这里插入图片描述

四、实现分类预测测试

  1. 准备预测的图片
  2. 进行测试
    from __future__ import print_function, division
    from PIL import Imagefrom torchvision import transforms
    import torch.nn.functional as Fimport torch
    import torch.nn as nn
    import torch.nn.parallel
    # 定义网络
    class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3)self.max_pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, 3)self.max_pool2 = nn.MaxPool2d(2)self.conv3 = nn.Conv2d(64, 64, 3)self.conv4 = nn.Conv2d(64, 64, 3)self.max_pool3 = nn.MaxPool2d(2)self.conv5 = nn.Conv2d(64, 128, 3)self.conv6 = nn.Conv2d(128, 128, 3)self.max_pool4 = nn.MaxPool2d(2)self.fc1 = nn.Linear(4608, 512)self.fc2 = nn.Linear(512, 1)def forward(self, x):in_size = x.size(0)x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = self.conv3(x)x = F.relu(x)x = self.conv4(x)x = F.relu(x)x = self.max_pool3(x)x = self.conv5(x)x = F.relu(x)x = self.conv6(x)x = F.relu(x)x = self.max_pool4(x)# 展开x = x.view(in_size, -1)x = self.fc1(x)x = F.relu(x)x = self.fc2(x)x = torch.sigmoid(x)return x
    # 模型存储路径
    model_save_path = 'E:\\Cat_And_Dog\\kaggle\\model.pth'# ------------------------ 加载数据 --------------------------- #
    # Data augmentation and normalization for training
    # Just normalization for validation
    # 定义预训练变换
    # 数据预处理
    transform_test = transforms.Compose([transforms.Resize(100),transforms.RandomVerticalFlip(),transforms.RandomCrop(50),transforms.RandomResizedCrop(150),transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])class_names = ['cat', 'dog']  # 这个顺序很重要,要和训练时候的类名顺序一致device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# ------------------------ 载入模型并且训练 --------------------------- #
    model = torch.load(model_save_path)
    model.eval()
    # print(model)image_PIL = Image.open('E:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small\\test\\cats\\cat.1500.jpg')
    #
    image_tensor = transform_test(image_PIL)
    # 以下语句等效于 image_tensor = torch.unsqueeze(image_tensor, 0)
    image_tensor.unsqueeze_(0)
    # 没有这句话会报错
    image_tensor = image_tensor.to(device)out = model(image_tensor)
    pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in out]).to(device)
    print(class_names[pred])
    
  3. 预测结果
    在这里插入图片描述
    在这里插入图片描述
    从实际训练的过程来看,整体看准确度不高。而经过测试发现,该模型只能对于猫进行识别,对于狗则会误判。

五、参考资料

Pytorch自定义模型实现猫狗分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/26872.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

猫狗训练单张图片的测试

猫狗训练的训练模型的建立,模型在整个预测集上的预测效果的测试的程序代码网上或一些书籍上都可查阅,但是对单张或某些图片的分类测试程序不多,这里通过参考博客:https://blog.csdn.net/baidu_35113561/article/details/79371716 …

宠物鼻纹识别及面部识别进一步在城市养犬登记场景落地

最近安阳狗咬人事件造成了极其恶劣的社会影响,大型禁养犬类伤人成为城市治安管理不容忽视的隐患,正威胁人们的生命安全,养犬热潮也给城市管理带来了不小的挑战,粗放式的养犬管理不再适应时代的需求,城市养犬管理改革已…

借助互联网,“宝贝它”欲打造线上宠物交易与服务平台

作为人类忠实的朋友,宠物一直伴随着很多家庭的成长。而随着人们生活节奏的不断加快,电子商务正成为越来越多传统垂直领域的解决方案,宠物交易与服务同样也不例外。上海创业公司宝贝它希望借助互联网,打造线上宠物交易及服务平台。…

小动物领养网站/宠物救助网站

摘 要 本论文对小动物领养网站的开发过程进行了较为详细的论述,采用B/S架构、ssm 框架和 java 开发的 Web 框架,eclipse开发工具。 小动物领养网站,主要的模块包括管理员;首页、个人中心、用户管理、动物展示管理、动物分类管理…

语音合成工具Coqui TTS安装及体验

先介绍两种免费的语音合成工具 balabolka 官网 http://balabolka.site/balabolka.htm 是一种基于微软Speech API (SAPI)的免费语音合成工具,只是简单的发音合成,效果比较生硬 Coqui TTS 官网 https://coqui.ai/ 是基于深度学习的语音合成软件&#x…

音视频进阶教程|如何实现游戏中的实时语音

1 游戏实时语音功能简介 1.1 游戏实时语音概念解释 范围:收听者接收音频的范围。方位:指收听者在游戏世界坐标中的位置和朝向,详情可参考 5.5 初始化设置 中的“步骤 1”。收听者:房间内接收音频的用户发声者:房间内…

通过实时语音驱动人像模拟真人说话

元宇宙的火热让人们对未来虚拟世界的形态充满了幻想,此前我们为大家揭秘了声网自研的 3D 空间音频技术如何在虚拟世界中完美模拟现实听觉体验,增加玩家沉浸感。今天我们暂时离开元宇宙,回到现实世界,来聊聊声网自研的 Agora Lipsy…

聊天语音APP开发|聊天语音软件开发-实时音视频技术

聊天语音软件的开发应该是一个以视频和语音直播为核心的社交系统。对于用户来说,更好的视频和语音直播功能可以增强用户的接受感,让用户持续使用。为了方便视频和语音直播的采用体验,减少直播的延时,聊天语音软件的开发将采用实时…

拿到offer提出离职,公司拖30天才放人,但下家公司等不了30天,怎么办?

拿到offer想跳槽,向公司提出了离职,但公司要拖30天才放人,新公司又等不了30天,offer可能就没有了,这就是一位网友面临的两难局面,这种情况有没有什么解决的好办法呢? 有人安慰楼主,下…

怎么说离职原因新的公司比较能接受?

怎么说离职原因新的公司比较能接受? 我来提供一些格式化的应对方法; 1.实际原因:原单位工资太少。离职原因:我认为我自己已经具备了一定的积累,希望可以迈向一个新的台阶。 2.实际原因:跟同事出不来。离…

我提了离职,公司给我涨薪了,还能待下去吗?

金三银四到了,相信不少同学又开始在物色新的公司。 不少同学反映,在提出离职后,公司给自己加了薪,虽然不多。 那“在职员工,提出辞职被挽留,应该留下吗?” 为什么想要离职? 这个问…

是的,我离职了

终于可以敞开说这件事情了,年后的这一个月,我彻底停更了,并不是偷懒了,而是我要找工作。大家也都知道18年的寒冬,很多大厂开始裁员,所以我要更加认真的学习,毕竟跟大厂出来的相比,自…

办理离职手续流程的详细流程(离职交接的标准流程)

1、正式员工办理离职手续流程 若员工自离,需提前一个月向部门领导提出辞职申请(即时聊天工具或邮件)和《解除劳动合同申请》。 1)面谈:一般领导都会先谈话,确定你离职的时间及安排交接人员进行工作交接。 2…

程序员新公司入职被拒 只因离职证明多了一句话!

程序猿(微信号:imkuqin) 猿妹 整编 新闻报道来自:成都商报 近日,成都一名程序员被新应聘的公司通知入职,然而因为原公司给他出具的一份离职证明上,记载了一句“该员工在项目未完成情况下因个人原…

提交辞职申请时,领导极力挽留,还答应加薪,要不要留下来?

提交辞职申请时,领导极力挽留,还答应加薪,要不要留下来?张工是一名程序员,最近他向领导提交了辞职申请表后,却被领导极力挽留,领导不仅打感情牌,还打加薪牌。就是希望张工能够留下来…

医学影像处理与识别,应用AI模型,探索疾病辅助诊断!

关注公众号,发现CV技术之美 今天(2023.1.9) arXiv.CV 上有7篇医学影像处理与识别相关论文。不过粗略看来,医学影像类的论文,很多都是直接使用已有模型(甚至都不是最先进的模型),加以…

【react从入门到精通】初识React

文章目录 人工智能福利文章前言React技能树什么是 React?安装和配置 React创建 React 组件渲染 React 组件使用 JSX传递属性(Props)处理组件状态(State)处理用户输入(事件处理)组合和嵌套组件写…

JWT续期问题,ChatGPT解决方案

JWT(JSON Web Token)通常是在用户登录后签发的,用于验证用户身份和授权。JWT 的有效期限(或称“过期时间”)通常是一段时间(例如1小时),过期后用户需要重新登录以获取新的JWT。然而&…

可用数据存量不足,还能怎样向AI模型注入人类智能?

作者 | 王昊 出品 | IDEA研究院 在深度学习发展的第三波浪潮中, ChatGPT引发了人们对人工智能前所未有的关注。它的出现意味着基于指令学习和人类反馈的AI技术成为人工智能领域的关键。然而,当前所展示的能力还远不是AI的最终形态,无论是产业…

玩转ChatGPT:基于Mucloud建立本地知识库

一、写在前面 人们普遍认为GPT有潜力颠覆教育行业,然而这种颠覆会以何种方式呈现呢? 在刘慈欣的科幻世界中,三体人拥有知识遗传的能力,这使得他们的技术迭代成本降至最低。然而,我们人类并未具备这样的特性&#xff…