(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

5. 读取数据

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, device):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(device[0]), labels.to(device[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None :valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201844.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法】链表-20231123

这里写目录标题 一、19. 删除链表的倒数第 N 个结点二、21. 合并两个有序链表三、24. 两两交换链表中的节点 一、19. 删除链表的倒数第 N 个结点 提示 中等 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 输入:head [1,…

Web实战:基于Django与Bootstrap的在线计算器

文章目录 写在前面实验目标实验内容1. 创建项目2. 导入框架3. 配置项目前端代码后端代码 4. 运行项目 注意事项写在后面 写在前面 本期内容:基于Django与Bootstrap的在线计算器 实验环境: vscodepython(3.11.4)django(4.2.7)bootstrap(3.4.1)jquery(3…

【计算机网络笔记】路由算法之层次路由

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

推荐一款适合做智慧旅游的前端模板

目录 前言 一、功能介绍 二、前端技术介绍 三、功能及界面设计介绍 1、数据概览 2、车辆监控 3、地图界面 4、其它功能 四、扩展说明 总结 前言 智慧旅游是一种全新的旅游业务模式,它充分利用先进的信息技术,提升旅游体验,优化旅游管…

【JavaEE】Spring更简单的存储和获取对象(类注解、方法注解、属性注入、Setter注入、构造方法注入)

一、存储Bean对象 在这篇文章中我介绍了Spring最简单的创建和使用:Spring的创建和使用 其中存储Bean对象是这样的: 1.1 配置扫描路径 想要成功把对象存到Spring中,我们需要配置对象的扫描包路径 这样的话,就只有被配置了的包…

shell 条件语句

目录 测试 test测试文件的表达式 是否成立 格式 选项 比较整数数值 格式 选项 字符串比较 常用的测试操作符 格式 逻辑测试 格式 且 (全真才为真) 或 (一真即为真) 常见条件 双中括号 [[ expression ]] 用法 &…

从0开始学习JavaScript--JavaScript迭代器

JavaScript迭代器(Iterator)是一种强大的编程工具,它提供了一种统一的方式来遍历不同数据结构中的元素。本文将深入探讨JavaScript迭代器的基本概念、用法,并通过丰富的示例代码展示其在实际应用中的灵活性和强大功能。 迭代器的…

MFS分布式文件系统

目录 集群部署 Master Servers ​Chunkservers ​编辑Clients Storage Classes LABEL mfs高可用 pacemaker高可用 ​编辑ISCSI 添加集群资源 主机 ip 角色 server1 192.168.81.11 Master Servers server2 192.168.81.12 Chunkservers server3 192.168.81.13 Chunkserver…

前端 webpack 面试题

文章目录 webpack打包流程webpack声明周期自开发 webpack 插件loader和plugin的区别Loader(加载器):Plugin(插件):总结区别: webpack如何热启动及原理HMR(热更新实现的原理)websocketfs.watch 说说一些常用的loader和p…

[数据结构]—栈和队列

💓作者简介🎉:在校大二迷茫大学生 💖个人主页🎉:小李很执着 💗系列专栏🎉:数据结构 每日分享✨:到头来,有意义的并不是结果,而是我们度…

TCP/IP

分层模型 TCP 传输控制协议 UDP 用户数据包协议 四层 应用层 负责发送/接收消息 传输层 负责拆分和组装 .期间会有编号 网络层 TCP/UDP 属于网络层, 不会判断和处理编号 数据链路层 以太网 ,网络设备 TCP 连接 TCP连接需要端口,进行通信 Java 通过Socket 接收消息 发送 …

【码神之路】【Golang】博客网站的搭建【学习笔记整理 持续更新...】

介绍 一个用原生GO开发的博客网站,涉及Golang Web开发、Web服务器搭建和HTTP请求处理、模板与静态资源处理等 技术栈 后端:Go、Go并发机制前端:HTML模版链接直达 Golang搭建博客网站的学习视频 注:这里我只记录我实质✅学习到…

PDF Reader Pro 3.0.1.0(pdf阅读器)

PDF Reader Pro是一款功能强大的PDF阅读、注释、填写表单&签名、转换、OCR、合并拆分PDF页面、编辑PDF等软件。 它支持多种颜色的高亮、下划线,可以按需选择,没有空白处可以进行注释,这时候便签是你最佳的选择,不点开时自动隐…

Kubernetes容器状态探测的艺术

在Kubernetes集群中维护容器状态更像是一种艺术,而不是科学。原文: The Art and Science of Probing a Kubernetes Container[1] 在Kubernetes集群中维护容器状态更像是一种艺术,而不是科学。 本文将带你深入理解容器探测[2],并特别关注相对较…

【iOS】知乎日报

文章目录 前言一、首页1.网络的异步请求2.避免同一网络请求执行多次3.下拉刷新与上拉加载的实现下拉刷新上拉加载 二、网页1.webView的实现2.webView的滑动加载3.网页与首页内容的同步更新 三、评论区Masonory实现行高自适应 四、收藏中心通过FMDB实现数据持久化1.创建或打开数…

内存可见性与指令重排序

文章目录 内存可见性内存可见性问题代码演示JMM(Java Memory Model) 指令重排序指令重排序问题代码演示指令重排序分析 volatile关键字volatile 保证内存可见性 & 禁止指令重排序volatile 不保证原子性 在上一节介绍线程安全问题的过程中&#xff0c…

分享一篇很就以前的文档-VMware Vsphere菜鸟篇

PS:由于内容是很久以前做的记录,在整理过程中发现了一些问题,简单修改后分享给大家。首先ESXI节点和win7均运行在VMware Workstation上面,属于是最底层,而新创建的CentOS则是嵌套后创建的操作系统,这点希望…

Dubbo从入门到上天系列第十八篇:Dubbo引入注册中心简介以及DubboAdmin简要介绍,为后续详解Dubbo各种注册中心做铺垫!

一:Dubbo注册中心引言 1:什么是Dubbo的注册中心? Dubbo注册中心是Dubbo服务治理中极其重要的一个概念。它主要是用于对Rpc集群应用实例进行管理。 对于我们的Dubbo服务来讲,至少有两部分构成,一部分是Provider一部分是…

看不惯AI版权作品被白嫖!Stability AI副总裁选择了辞职,曾领导开发Stable Audio

近日,OpenAI的各种大瓜真是让人吃麻了。 而就在Sam Altmam被开除前两天,可能没太多人注意到Stability AI副总裁Newton—Rex因看不惯StabilityAI在版权保护上的行为选择辞职一事。 大模型研究测试传送门 GPT-4传送门(免墙,可直接…

基于VM虚拟机下Ubuntu18.04系统,Hadoop的安装与详细配置

参考博客: https://blog.csdn.net/duchenlong/article/details/114597944 与上面这个博客几乎差不多,就是java环境配置以及后面的hadoop的hdfs-site.xml文件有一些不同的地方。 准备工作 1.更新 # 更新 sudo apt update sudo apt upgrade2.关闭防火…