基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试

基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试


1. 引言

在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。
在这里插入图片描述


2. 环境准备

首先,确保已安装以下 Python 库:

pip install torch torchvision pandas d2l
  • torch:PyTorch 核心库。
  • torchvision:提供计算机视觉相关的工具。
  • pandas:用于处理 CSV 文件。
  • d2l:深度学习工具库,提供辅助函数。

3. 数据准备

竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public

3.1 数据集结构

假设数据集位于 classify-leaves 目录下,包含以下文件:

classify-leaves/
├── train.csv
├── test.csv
├── images/├── image1.jpg├── image2.jpg...
  • train.csv:包含训练图像的路径和标签。
  • test.csv:包含测试图像的路径。

3.2 数据加载与预处理

import os
import pandas as pd
import randomimgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i
  • num2name:获取所有类别标签,并按类别数量排序。
  • name2num:将类别名称映射到数字编号。

4. 自定义数据集类

为了加载数据,我们需要定义一个自定义数据集类 Leaf_data

from torch.utils.data import Dataset
from d2l import torch as d2lclass Leaf_data(Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)
  • __getitem__:根据索引返回一个样本,包括图像和标签。
  • __len__:返回数据集的长度。

5. 模型定义与初始化

我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:

import torch
import torchvision
from torch import nndef init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
  • init_weight:使用 Xavier 初始化方法初始化全连接层的权重。
  • net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。

6. 训练过程

6.1 优化器与损失函数

lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
  • trainer:使用 Adam 优化器,对全连接层使用更高的学习率。
  • LR_con:使用余弦退火学习率调度器。
  • loss:使用交叉熵损失函数。

6.2 训练函数


def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc
  • train:训练模型,记录损失和准确率,并在验证集上评估模型。

7. 测试与结果保存

在测试集上进行预测,并保存结果到 CSV 文件:

net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
  • test_dataloader:加载测试数据。
  • res:保存预测结果到 CSV 文件。

8. 总结

本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:

  1. 自定义数据集类的实现。
  2. 使用预训练模型进行迁移学习。
  3. 训练模型并保存最佳模型。
  4. 在测试集上进行预测并生成提交文件。

希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊

完整代码

import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i# GPU 检查
def try_gpu():if torch.cuda.device_count() > 0:return torch.device('cuda')return torch.device('cpu')# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')def save_model(net):torch.save(net.state_dict(), model_path)# 自定义数据集类
class Leaf_data(Data.Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc# 模型初始化
def init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(dataset=Leaf_data(imgpath, True, augs),lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)

参考链接:

  • PyTorch 官方文档
  • torchvision 官方文档
  • d2l 深度学习工具库

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/16405.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

团结引擎 Shader Graph:解锁图形创作新高度

Shader Graph 始终致力于为开发者提供直观且高效的着色器构建工具&#xff0c;持续推动图形渲染创作的创新与便捷。在团结引擎1.4.0中&#xff0c;Shader Graph 迎来了重大更新&#xff0c;新增多项强大功能并优化操作体验&#xff0c;助力开发者更轻松地实现高质量的渲染效果与…

C# OpenCV机器视觉:模仿Halcon各向异性扩散滤波

在一个充满创意与挑战的图像处理工作室里&#xff0c;阿强是一位热情的图像魔法师。他总是在追求更加出色的图像效果&#xff0c;然而&#xff0c;传统的图像处理方法有时候并不能满足他的需求。 有一天&#xff0c;阿强听说了 Halcon 中的各向异性扩散滤波功能&#xff0c;它…

超详细的数据结构3(初阶C语言版)栈和队列。

文章目录 栈和队列1.栈1.1 概念与结构1.2 栈的实现 2. 队列2.1 概念与结构2.2 队列的实现 总结 栈和队列 1.栈 1.1 概念与结构 栈&#xff1a;⼀种特殊的线性表&#xff0c;其只允许在固定的⼀端进行插⼊和删除元素操作。进⾏数据插⼊和删除操作的⼀端称为栈顶&#xff0c;另…

利用邮件合并将Excel的信息转为Word(单个测试用例转Word)

利用邮件合并将Excel的信息转为Word 效果一览效果前效果后 场景及问题解决方案 一、准备工作准备Excel数据源准备Word模板 二、邮件合并操作步骤连接Excel数据源插入合并域预览并生成合并文档 效果一览 效果前 效果后 场景及问题 在执行项目时的验收阶段&#xff0c;对于测试…

一个基于ESP32S3和INMP441麦克风实现音频强度控制RGB灯带律动的代码及效果展示

一个基于ESP32S3和INMP441麦克风实现音频强度控制RGB灯带律动的代码示例&#xff0c;使用Arduino语言&#xff1a; 硬件连接 INMP441 VCC → ESP32的3.3VINMP441 GND → ESP32的GNDINMP441 SCK → ESP32的GPIO 17INMP441 WS → ESP32的GPIO 18INMP441 SD → ESP32的GPIO 16RG…

用户认证综合实验

实验需求 需求一&#xff1a;根据下表&#xff0c;完成相关配置 需求二&#xff1a;配置DHCP协议&#xff0c;具体要求如下 需求三&#xff1a;防火墙安全区域配置 需求四&#xff1a;防火墙地址组信息 需求五&#xff1a;管理员 为 FW 配置一个配置管理员。要求管理员可以通…

Curser2_解除机器码限制

# Curser1_无限白嫖试用次数 文末有所需工具下载地址 Cursor Device ID Changer 一个用于修改 Cursor 编辑器设备 ID 的跨平台工具集。当遇到设备 ID 锁定问题时&#xff0c;可用于重置设备标识。 功能特性 ✨ 支持 Windows 和 macOS 系统&#x1f504; 自动生成符合格式的…

linux部署node服务

1、安装nvm管理node版本 # 下载、解压到指定目录 wget https://github.com/nvm-sh/nvm/archive/refs/tags/v0.39.1.tar.gz tar -zxvf nvm-0.39.0.tar.gz -C /opt/nvm # 配置环境 vim ~/.bashrc~&#xff1a;这是一个路径简写符号&#xff0c;代表当前用户的主目录。在大多数 …

Kotlin实战经验:将接口回调转换成suspend挂起函数

在 Kotlin 协程中, suspendCoroutine 和 suspendCancellableCoroutine 是用于将回调或基于 future 的异步操作转换成挂起函数。 suspendCoroutine 用途:将回调式异步操作转换为可挂起函数 行为: 启动一个新的协程来处理基于回调的操作挂起当前协程,直到调用回调回调负责…

【DeepSeek服务器繁忙,请稍后再试...如何解决?】

DeepSeek服务器繁忙&#xff0c;请稍后再试...如何解决&#xff1f; DeepSeek该咋使用&#xff1f;解决办法&#xff1a;本地桌面工具接下来说下&#xff0c;DeepSeek提示词该咋写&#xff1f; DeepSeek该咋使用&#xff1f; 首先&#xff0c;先说下DeepSeek该咋使用&#xff…

SDKMAN! 的英文全称是 Software Development Kit Manager(软件开发工具包管理器)

文章目录 SDKMAN! 的核心功能SDKMAN! 的常用命令SDKMAN! 的优势总结 SDKMAN! 的英文全称是 Software Development Kit Manager。它是一个用于管理多个软件开发工具&#xff08;如 Java、Groovy、Scala、Kotlin 等&#xff09;版本的工具。SDKMAN! 提供了一个简单的方式来安装、…

Python实现GO鹅优化算法优化支持向量机SVM分类模型项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后关注获取。 1.项目背景 随着信息技术的迅猛发展&#xff0c;数据量呈爆炸式增长&#xff0c;如何从海量的数据中提取有价值…

网络安全工程师逆元计算 网络安全逆向

中职逆向题目整理合集 逆向分析&#xff1a;PE01.exe算法破解&#xff1a;flag0072算法破解&#xff1a;flag0073算法破解&#xff1a;CrackMe.exe远程代码执行渗透测试天津逆向re1 re22023江苏省re12023年江苏省赛re2_easygo.exe2022天津市PWN 逆向分析&#xff1a;PE01.exe …

Mysql 函数解析

文章目录 一、模糊匹配【like】二、CASE函数1、简单case2、搜索case3、搜索case 聚合函数 三、日期函数四、字符串处理 一、模糊匹配【like】 一般形式为&#xff1a;列名 [NOT] LIKE ‘%关键字%’&#xff0c;示例如下&#xff1a; like %北京%列名包括北京的字样like ‘北…

C# OpenCV机器视觉:SoftNMS非极大值抑制

嘿&#xff0c;你知道吗&#xff1f;阿强最近可忙啦&#xff01;他正在处理一个超级棘手的问题呢&#xff0c;就好像在一个混乱的战场里&#xff0c;到处都是乱糟糟的候选框&#xff0c;这些候选框就像一群调皮的小精灵&#xff0c;有的重叠在一起&#xff0c;让阿强头疼不已。…

2025届优秀大数据毕业设计

【2025计算机毕业设计】计算机毕业设计100个高通过率选题推荐&#xff0c;毕业生毕设必看选题指导&#xff0c;计算机毕业设计选题讲解&#xff0c;毕业设计选题详细指导_哔哩哔哩_bilibili 985华南理工大学学长 大厂全栈&#xff0c;大数据开发工程师 专注定制化开发

Visual Studio 进行单元测试【入门】

摘要&#xff1a;在软件开发中&#xff0c;单元测试是一种重要的实践&#xff0c;通过验证代码的正确性&#xff0c;帮助开发者提高代码质量。本文将介绍如何在VisualStudio中进行单元测试&#xff0c;包括创建测试项目、编写测试代码、运行测试以及查看结果。 1. 什么是单元测…

最新消息 | 德思特荣获中国创新创业大赛暨广州科技创新创业大赛三等奖!

2024年12月30日&#xff0c;广州市科技局公开第十三届中国创新创业大赛&#xff08;广东广州赛区&#xff09;暨2024年广州科技创新创业大赛决赛成绩及拟获奖企业名单&#xff0c;德思特获得了智能与新能源汽车初创组【第六名】【三等奖】的好成绩&#xff01; 关于德思特&…

DeepSeek模型架构及优化内容

DeepSeek v1版本 模型结构 DeepSeek LLM基本上遵循LLaMA的设计&#xff1a; 采⽤Pre-Norm结构&#xff0c;并使⽤RMSNorm函数. 利⽤SwiGLU作为Feed-Forward Network&#xff08;FFN&#xff09;的激活函数&#xff0c;中间层维度为8/3. 去除绝对位置编码&#xff0c;采⽤了…

内网ip网段记录

1.介绍 常见的内网IP段有&#xff1a; A类&#xff1a; 10.0.0.0/8 大型企业内部网络&#xff08;如 AWS、阿里云&#xff09; 10.0.0.0 - 10.255.255.255 B类&#xff1a;172.16.0.0/12 中型企业、学校 172.16.0.0 - 172.31.255.255 C类&#xff1a;192.168.0.0/16 家庭…