kaggle叶子分类比赛(易理解)

说实话网上很多关于叶子分类比赛的代码能取得的成绩都很好,但对于我这个业余人员太专业了,而且很多文章都有自己的想法,这让我这个仿写沐神代码的小菜鸡甚是头痛。
但好在我还是完成了,虽然结果并不是很好,但是如果跟着沐神走的同学在学习上应该没什么大问题。于是这篇文章的重点不是调参获得一个好成绩,而是把牵扯到的难点与思路好好的解释一下,方便同学们模仿。

竞赛地址:https://www.kaggle.com/c/classify-leaves

文章目录

  • 第一部分 加载并读取数据
  • 第二部分 定义网络
  • 第三部分 损失函数,验证函数,优化器
  • 第四部分 训练
  • 可能出现的bug
  • 拓展内容
    • 正常加载图像数据的其他方式
    • 类别索引在做什么
    • train_iter迭代器在迭代时__getitem_在干什么

第一部分 加载并读取数据

难点:如何接受并处理图像数据–使用自定义函数进行处理

import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Imageclass CustomDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):"""初始化数据集。Args:csv_file (str): 数据集的csv文件路径,其中包含图像的文件名和标签。root_dir (str): 图像文件的根目录路径。transform (callable, optional): 一个可选的转换函数,用来对图像进行处理。"""# 读取csv文件,并将数据存储到pandas DataFrame中。self.data_frame = pd.read_csv(csv_file)# 存储图像文件的根目录路径。self.root_dir = root_dir# 存储可选的图像转换函数。self.transform = transform# 将字符串类型的标签转换为整数索引,同时获取标签到整数索引的映射。self.data_frame['label'], self.label_mapping = pd.factorize(self.data_frame['label'])def __len__(self):"""返回数据集中的样本数。"""return len(self.data_frame)def __getitem__(self, idx):"""根据给定的索引idx获取对应的数据项。Args:idx (int): 数据项的索引。Returns:tuple: 包含图像和其对应标签的元组。"""# 如果idx是torch tensor类型,先转换为列表。if torch.is_tensor(idx):idx = idx.tolist()# 构建图像文件的完整路径。img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])# 打开图像文件。image = Image.open(img_name)# 获取对应的标签(整数索引)。label = self.data_frame.iloc[idx, 1]# 如果有转换函数,应用之。if self.transform:image = self.transform(image)# 返回图像和标签。return image, labeldef get_num_classes(self):"""返回数据集中不同类别的总数。"""return len(self.label_mapping)transform = transforms.Compose([transforms.Resize(256),transforms.RandomRotation(20),transforms.RandomHorizontalFlip(),transforms.CenterCrop(224),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 创建图像数据集实例
dataset = CustomDataset(csv_file='C:/Users/xiaox/pytorch/SucTest/train.csv',root_dir='C:/Users/xiaox/pytorch/SucTest',transform=transform)num_classes = dataset.get_num_classes()
print(f"Total number of classes: {num_classes}")# 数据加载和划分
from torch.utils.data import DataLoader, random_split
total_size = len(dataset)
train_size = int(total_size * 0.8)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])# 加载数据集
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)
1.如何自定义Dataset以用来灵活处理图像数据[数据的变化]1.定义__init__函数(读取csv文件,图像文件,transform,标签编码) [相当于将csv文件读取到Dataframe数据类型中,将标签映射为整数] 2.定义__getitem__函数(获取图像并转换,与图像对应的标签的索引)[相当于返回一张被转换的图片 与图片对应的Label对应的整数索引]2.为什么使用类别索引将字符串映射成整数最重要的一点:神经网络中字符串无法转化为tensor类型,无法加入到net网络中3.为什么选择类别索引而不是独热编码[独热编码就是预测房价中对于各个字符串标签的处理方法]独热编码在交叉熵损失函数中不适用可拓展内容:
1.正常加载图像数据的其他方式(dataset,compose,data_loader的关系)
2.类别索引在做什么
3.train_iter迭代器在迭代时__getitem_在干什么

第二部分 定义网络

使用了Resnet50

from torch import nn
from d2l import torch as d2l
from torch.nn import functional as F
import torchvision.models as modelsmodel = models.resnet50(weights=None)  # 使用预训练的ResNet-50# 首先获取全连接层的输入特征数量
num_ftrs = model.fc.in_features# 使用Dropout层和新的全连接层创建一个新的Sequential模块
model.fc = nn.Sequential(nn.Dropout(0.5),nn.Linear(num_ftrs, 176)
)

第三部分 损失函数,验证函数,优化器

这里我使用了Adam作为优化器

#这是评估模型平均准确率的函数
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):#1net.eval()  # 设置为评估模式#2if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量#3metric = d2l.Accumulator(2)#4## 4.1with torch.no_grad():## 4.2for X, y in data_iter:### 4.2.1if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)### 4.2.2y = y.to(device)### 4.2.3 注意:d2l原有库可能表示:acc = d2l.accuracy(net(X), y) metric.add(acc * y.numel(), y.numel())print(d2l.accuracy(net(X), y))metric.add(d2l.accuracy(net(X), y), y.numel())#5return metric[0] / metric[1]#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device,weight_decay):"""用GPU训练模型(在第六章定义)"""#1def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)#2net.to(device)#更改了优化器#optimizer = torch.optim.SGD(net.parameters(), lr=lr)optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)#4loss = nn.CrossEntropyLoss()#5animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])#6timer, num_batches = d2l.Timer(), len(train_iter)#7for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数#7.1metric = d2l.Accumulator(3)#7.2net.train()#7.3for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():#7.4metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()#7.5train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]#7.6if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

第四部分 训练

## 开始训练
lr = 1e-4
batch_size = 128
num_epochs = 20
weight_decay = 1e-3train_ch6(model, train_loader, test_loader, num_epochs, lr, d2l.try_gpu(),weight_decay)

可能出现的bug

在这里插入图片描述

CUDA错误
1.检查数据类型与形状是否合理(断言测试)
2.检查网络输出种类是否正常(获取类别个数)
3.检查网络是否正常(前向输出测试)
4.检查网络每一层是否正常(循环测试)首先:可以尝试重启,有可能是把内存用完了,重启试一下在进行下面的排查

拓展内容

正常加载图像数据的其他方式


#### 当图片所在文件夹代表一个标签时使用或数据集有对应的加载函数
import torch 
from torchvision import transforms,datasets
from torch import nn
from d2l import torch as d2l# 0.定义载入图像的格式 AlexNet的输入是227
transform = transforms.Compose([transforms.Resize(256),                    # 将图像缩放,使最短边为256像素transforms.CenterCrop(227),                # 从图像中心裁剪224x224大小的图像transforms.ToTensor(),                     # 将图像转换为PyTorch张量transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化处理
])# 例子:读取图像数据(图片所在文件夹代表一个标签)
dataset = datasets.ImageFolder(root='C:\\Users\\xiaox\\pytorch\\SucTest\\img\\', transform=transform)# 例子:加载 CIFAR-10 数据集(数据集有对应的加载函数)
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 2.定义迭代器
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

类别索引在做什么

## 标签编码定义
在处理分类问题时,尤其是在使用机器学习或深度学习模型时,通常需要将文本或字符串类型的标签(labels)转换成整数索引。这是因为大多数算法都优化以处理数值数据,而不是文本数据。在你的代码中,这个转换是通过 Pandas 的 `factorize` 函数实现的。### `pd.factorize()`
这个函数用于将一个具有重复值的数组转换为一个整数数组,其中每个唯一值都被分配一个整数标识符。它还返回一个包含原始数据中唯一值的数组,这可以作为标签到整数的映射。#### 示例解释假设你有一个CSV文件,其中包含如下的数据,其中每行代表一个样本,第一列是图像的文件名,第二列是图像的标签(如动物种类):```
image_name, label
cat001.jpg, cat
dog001.jpg, dog
cat002.jpg, cat
bird001.jpg, bird
```使用 `pd.factorize()` 函数处理 `label` 列时,会发生以下操作:labels, label_mapping = pd.factorize(['cat', 'dog', 'cat', 'bird'])
```结果:
- `labels` 会是 `[0, 1, 0, 2]`。这里,'cat' 被映射为 0,'dog' 被映射为 1,'bird' 被映射为 2。注意,第一个出现的标签('cat')是第一个被赋予新索引的。
- `label_mapping` 会是 `['cat', 'dog', 'bird']`,这是一个数组,其中索引位置对应于在 `labels` 中分配给每个唯一标签的整数。通过这种方式,原始的字符串标签被转换为整数,使得它们可以更容易地被模型处理,同时你还保持了一个从整数索引回到原始标签的映射,这在模型预测结束后,将预测的整数标签转换回人类可读的标签时非常有用。

train_iter迭代器在迭代时__getitem_在干什么

for X,y in train_iter:做了什么DataLoader 创建一个迭代器。每次迭代时,从数据集(通过 Dataset 对象)中请求下一批数据。
(既向dataset对象随即指定batch_size个索引,来获取数据)数据集的 __getitem__ 方法按索引获取数据和标签,这通常是随机访问,支持数据的随机打乱和批处理。
(dataset通过idx与__getitem__获得指定的数据然后返回给dataloader直到所有的batch_size个数据都被返回)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/321739.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么跑腿越来越受到年轻人的青睐

跑腿服务越来越受到年轻人的青睐,主要源于以下几个方面的原因: 1. 便捷快速:在快节奏的现代生活中,年轻人追求的是效率和速度。跑腿服务提供了一种即时、便捷的解决方案,使他们能够在繁忙的生活和工作中节省时间和精力…

VS Code中PlatformIO IDE的安装并开发Arduino

VS Code中PlatformIO IDE的安装并开发Arduino VS Code的安装 略 PlatformIO IDE的安装 PlatformIO IDE是是什么 PlatformIO IDE 是一个基于开源的跨平台集成开发环境(IDE),专门用于嵌入式系统和物联网(IoT)开发。…

C语言 函数概述

好 接下来 我们来讲函数 构建C程序的最佳方式 就是模块化程序设计 C语言中 最基本的程序模块被称为 函数 所以 这个知识点的重要性不言而喻 这里 我们讲个故事 诸葛亮六出祁山时 为了逼司马懿出战 派人送给力司马懿一件女人衣服 司马懿只是为使者 诸葛亮的饮食起居 使者感叹…

适合小白使用的编译器(c语言和Java编译器专属篇)

本节课主要讲如何安装适合编程小白的编译器 废话不多说,我们现在开始 c/c篇 首先,进入edge浏览器,在搜索框输入visual studio ,找到带我画圈的图标,点击downloads 找到community版(社区版)的下…

简易录制视频做3D高斯

系统环境 ubuntu20 ,cuda11.8,anaconda配置好了3D高斯的环境。 具体参考3D高斯环境配置:https://blog.csdn.net/Son_of_the_Bronx/article/details/138527329?spm1001.2014.3001.5501 colmap安装:https://blog.csdn.net/Son_of…

最后一块石头的重量 II ,目标和,一和0

最后一块石头的重量 II(0-1背包问题 将石头尽可能分为两堆重量一样的,进行相撞则为0 class Solution {public int lastStoneWeightII(int[] stones) {int sum0;for(int x:stones){sumx;}int targetsum/2;int[] dpnew int[target1];//dp[j]表示最大石堆的…

分享5款对工作学习有帮助的效率软件

​ 今天再来推荐5个超级好用的效率软件,无论是对你的学习还是办公都能有所帮助,每个都堪称神器中的神器,用完后觉得不好用你找我。 1.文件复制——ClipClip ​ ClipClip是一款功能强大、操作简便的文件复制与管理软件。它改变了传统的复制粘…

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具)

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具) 场景来源 去年单位内部的一次素拓活动,分工负责策划设置其中的“你画我猜”环节,网络上搜集到题目文字后,想着如何快速做成对应一页一页的PPT。第一时间想…

java入门详细教程——day01

目录 1. Java入门 1.1 Java是什么? 1.2 Java语言的历史 1.3 Java语言的分类 1.4 Java语言的特点 1.4.1 先编译再解释运行 1.4.2 跨平台 1.5 JRE和JDK(记忆) 1.6 JDK的下载和安装(应用) 1.6.1 下载 1.6.2 安…

SAP 【MM】移动类型的科目确定<转载>

原文链接:https://blog.csdn.net/zhongguomao/article/details/134387102 移动类型的科目确定 SAP中支持控制不同移动类型所确定的总分类帐科目和账户分配,同时也支持控制用户能否改变总分类帐科目和账户分配默认值。 1、控制能否手动输入总分类帐科目…

Golang | Leetcode Golang题解之第74题搜索二维矩阵

题目&#xff1a; 题解&#xff1a; func searchMatrix(matrix [][]int, target int) bool {m, n : len(matrix), len(matrix[0])i : sort.Search(m*n, func(i int) bool { return matrix[i/n][i%n] > target })return i < m*n && matrix[i/n][i%n] target }

一起刷C语言菜鸟教程100题(15-26含解析)

五一过的好快&#xff0c;五天假期说没就没&#xff0c;因为一些事情耽搁到现在&#xff0c;不过还是要继续学习的&#xff0c;之后就照常更新&#xff0c;先说一下&#xff0c;这个100题是菜鸟教程里面的&#xff0c;但是有一些题&#xff0c;我加入了自己的理解&#xff0c;甚…

责任链模式和观察者模式

1、责任链模式 1.1 概述 在现实生活中&#xff0c;常常会出现这样的事例&#xff1a;一个请求有多个对象可以处理&#xff0c;但每个对象的处理条件或权限不同。例如&#xff0c;公司员工请假&#xff0c;可批假的领导有部门负责人、副总经理、总经理等&#xff0c;但每个领导…

第80天:WAF 攻防-漏洞利用HPP 污染分块传输垃圾数据

案例一&#xff1a;安全狗-SQL 注入-知识点 正常访问会被拦截 like绕过 对比成功&#xff0c;正常返回 对比失败&#xff0c;不返回 post绕过 这里需要支持post注入。这里是我自己改的REQUEST 这里其实安全狗可以开启post验证&#xff0c;看别人知不知道能开启了 过滤了 模拟…

贪心算法应用例题

最优装载问题 #include <stdio.h> #include <algorithm>//排序int main() {int data[] { 8,20,5,80,3,420,14,330,70 };//物体重量int max 500;//船容最大总重量int count sizeof(data) / sizeof(data[0]);//物体数量std::sort(data, data count);//排序,排完数…

荟敏堂·中医优势专科建设新质生产力发展论坛在京召开

原题&#xff1a;《荟敏堂中医优势专科建设新质生产力发展论坛在京召开——周超凡中医治则学思想传承研讨会成功举办》 会议现场照片 仟江水商业电讯&#xff08;5月8日 北京 委托发布&#xff09;日前&#xff0c;周超凡中医治则学思想传承研讨会暨中医优势专科建设新质生产力…

QT实现Home框架的两种方式

在触摸屏开发QT界面一般都是一个Home页面&#xff0c;然后button触发进入子页面显示&#xff0c;下面介绍这个home框架实现的两种方式&#xff1a; 1.方式一&#xff1a;用stackedWidget实现 &#xff08;1&#xff09;StackedWidget控件在Qt框架中是一个用于管理多个子窗口或…

数据挖掘流程是怎样的?数据挖掘平台基本功能有哪些?

数据挖掘是从大量的、不完全的、有噪声的、模糊的、随机的数据中提取隐含在其中的、人们事先不知道的、但又是潜在有用的信息和知识的过程。 数据挖掘的流程是&#xff1a; 清晰地定义出业务问题&#xff0c;确定数据挖掘的目的。 数据准备: 数据准备包括&am…

记一次java进程频繁挂掉问题排查修复

前言 最近业务部门有个java服务进程会突然无缘无故的挂掉&#xff0c;然后这个服务会产生一堆类似hs_err_pid19287.log这样的日志。业务部门负责人就把hs_err_pidxxx的日志发给我&#xff0c;让我帮忙看下问题。本文就来回顾一下&#xff0c;我是如何帮业务部门进行问题排查 …

PyGame 文字显示问题及解决方法

在 Pygame 中显示文字时可能会遇到一些问题&#xff0c;例如文字显示不清晰、字体不正确或者文字位置不准确等。以下是一些常见的问题及其解决方法&#xff0c;具体情况可以看看情况。 1、问题背景 一位用户在使用 PyGame 库进行游戏开发时&#xff0c;遇到了一个问题&#xf…