034、test

之——全纪录

目录

之——全纪录

杂谈

正文

1.下载处理数据

2.数据集概览

3.构建自定义dataset

4.初始化网络

5.训练


杂谈

        综合方法试一下。


leaves

1.下载处理数据

        从官网下载数据集:Classify Leaves | Kaggle

        解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

        images,27153个树叶图片:

        test.csv,8800个:

        train.csv,18353个:


2.数据集概览

        训练集、测试集、类别:

#导包
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Imagetrain_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
print("训练集数量:",len(train_images))
n_train=len(train_images)
test_images=test_data.iloc[:,0].values
print("测试集数量:",len(test_images))
n_test=len(test_images)train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
#独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
# print(len(train_labels),train_labels)#记录并排序所有的类别名
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print("总类别:",len(train_labels_header))
classes=len(train_labels_header)


3.构建自定义dataset

       继承 torch.utils.Dataset 类,自定义树叶分类数据集:

#继承 torch.utils.Dataset 类,自定义树叶分类数据集
class leaves_dataset(torch.utils.data.Dataset):#root数据目录, images图片路径, labels图片标签, transform数据增强def __init__(self, root, images, labels, transform):super(leaves_dataset, self).__init__()self.root = rootself.images = imagesif labels is None:self.labels = Noneelse:self.labels = labelsself.transform = transform#获得指定样本def __getitem__(self, index):image_path = self.root + self.images[index]image = Image.open(image_path)#预处理image = self.transform(image)if self.labels is None:return imagelabel = torch.tensor(self.labels[index])return image, label#获得数据集长度def __len__(self):return self.images.shape[0]

        构建读取数据与预处理:

def load_data(images, labels, batch_size, train):aug = []normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])if (train):aug = [torchvision.transforms.CenterCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),transforms.ToTensor(),normalize]else:aug = [torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),transforms.ToTensor(),normalize]transform = transforms.Compose(aug)dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)if train==True:type="训练"else:type="测试"print("载入:",dataset.__len__(),type)return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)train_iter = load_data(train_images, train_labels, 512, train=True)

4.初始化网络

        使用官方预训练模型初始化网络,并修改输出类别数:

#初始化网络
net = torchvision.models.resnet18(pretrained=True)net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc


5.训练

         定义迭代器、优化器以及其他超参数,进行训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,param_group=True):train_slices = random.sample(list(range(n_train)), 15000)test_slices = list(set(range(n_train)) - set(train_slices))train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]#别的层不变,最后一层10倍学习率trainer = torch.optim.Adam([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,weight_decay=0.001)print(111)try:d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)except Exception as e:print(e)#%%#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(net, 1e-3)

        小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


CIFAR-10

1.数据集


2.未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/195501.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux :远程访问的 16 个最佳工具(一)

通过远程桌面协议 (RDP) 可以访问远程 Linux 桌面计算机,这是 Microsoft 开发的专有协议。它为用户提供了一个图形界面,可以通过网络连接连接到另一台/远程计算机。 FreeRDP 是 RDP 的免费实现。 RDP以客户端/服务器模型工作,其中远程计算机必…

服务器集群配置LDAP统一认证高可用集群(配置tsl安全链接)-centos9stream-openldap2.6.2

写在前面 因之前集群为centos6,已经很久没升级了,所以这次配置统一用户认证也是伴随系统升级到centos9时一起做的配套升级。新版的openldap配置大致与老版本比较相似,但有些地方配置还是有变化,另外,铺天盖地的帮助文…

R语言绘制精美图形 | 火山图 | 学习笔记

一边学习,一边总结,一边分享! 教程图形 前言 最近的事情较多,教程更新实在是跟不上,主要原因是自己没有太多时间来学习和整理相关的内容。一般在下半年基本都是非常忙,所有一个人的精力和时间有限&#x…

「Verilog学习笔记」使用8线-3线优先编码器Ⅰ实现16线-4线优先编码器

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 分析 当EI10时、U1禁止编码,其输出端Y为000,GS1、EO1均为0。同时EO1使EI00,U0也禁止编码,其输出端及GS0、EO0均为0。由电路…

Postman内置动态参数以及自定义的动态参数

近期在复习Postman的基础知识,在小破站上跟着百里老师系统复习了一遍,也做了一些笔记,希望可以给大家一点点启发。 一)内置动态参数 {{$timestamp}} 生成当前时间的时间戳{{$randomInt}} 生成0-1000之间的随机数{{$guid}} 生成随…

Ansys Electronics Desktop仿真——HFSS线圈寄生电阻,电感

利用ANSYS Electronics Desktop,可在综合全面、易于使用的设计平台中集成严格的电磁场分析和系统电路仿真。按需求解器技术让您能集成电磁场仿真器和电路及系统级仿真,以探索完整的系统性能。 HFSS(High Frequency Structure Simulator&#…

matplotlib 绘制双纵坐标轴图像

效果图: 代码: 由于使用了两组y axis,如果直接使用ax.legend绘制图例,会得到两个图例。而下面的代码将两个图例合并显示。 import matplotlib.pyplot as plt import numpy as npdata np.random.randint(low0,high5,size(3,4)) …

串口通信原理及应用

Content 1. 前言介绍2. 连接方式3. 数据帧格式4. 代码编写 1. 前言介绍 串口通信是一种设备间非常常用的串行接口,以比特位的形式发送或接收数据,由于成本很低,容易使用,工程师经常使用这种方式来调试 MCU。 串口通信应用广泛&a…

设计模式-备忘录模式-笔记

动机(Motivation) 在软件构建过程中,某些对象的状态在转换过程中,可能由于某种需要,要求程序能够回溯到对象之前处于某个点时的状态。如果使用一些公有接口来让其他对象得到对象的状态,便会暴露对象的细节…

【AI视野·今日Robot 机器人论文速览 第六十五期】Mon, 30 Oct 2023

AI视野今日CS.Robotics 机器人学论文速览 Mon, 30 Oct 2023 Totally 18 papers 👉上期速览✈更多精彩请移步主页 Daily Robotics Papers Gen2Sim: Scaling up Robot Learning in Simulation with Generative Models Authors Pushkal Katara, Zhou Xian, Katerina F…

在listener.ora配置文件中配置listener 1527的监听并且使用tnsnames连接测试

文章目录 前言:一、命令语句实现1、监听介绍2、编辑 listener.ora 文件:寻找配置文件对配置文件进行配置 3、重启监听4、配置TNS 二、图形化界面实现1、listener.ora文件配置2、tnsnames.ora文件配置 三、测试连接 前言: 命令实现和图形化实…

Mysql数据库 16.SQL语言 数据库事务

一、数据库事务 数据库事务介绍——要么全部成功要么全部失败 我们把完成特定的业务的多个数据库DML操作步骤称之为一个事务 事务——就是完成同一个业务的多个DML操作 例: 数据库事务四大特性 原子性(A):一个事务中的多个D…

基于JavaWeb+SSM+购物系统微信小程序的设计和实现

基于JavaWebSSM购物系统微信小程序的设计和实现 源码获取入口前言主要技术系统设计功能截图Lun文目录订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 前言 第一章 绪 论 1.1选题背景 互联网是人类的基本需求,特别是在现代社会,…

亲测一款超实用的在线制作产品册工具,一看就会

最近,我一直在寻找一款简单易用的在线制作产品册工具,终于让我找到了一个超实用的神器!这款工具不仅功能强大,而且操作简单,一看就会。 首先,这款工具提供了丰富的模板和素材,用户可以根据自己的…

自动驾驶汽车:人工智能最具挑战性的任务

据说,自动驾驶汽车是汽车行业梦寐以求的状态,将彻底改变交通运输业。就在几年前,对自动驾驶汽车的炒作风靡一时,那么到底发生了什么呢?这么多公司吹嘘到2021年我们将迎来的无人驾驶汽车革命在何处?事实证明…

基于深度学习的活体人脸识别检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1. 活体人脸识别检测算法概述 4.2. 深度学习在活体人脸识别检测中的应用 4.3. 算法流程 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022a 3.部分核心程序 …

【论文阅读】A Survey on Video Diffusion Models

视频扩散模型(Video Diffusion Model)最新综述GitHub 论文汇总-A Survey on Video Diffusion Models。 paper:[2310.10647] A Survey on Video Diffusion Models (arxiv.org) 0. Abstract 本文介绍了AIGC时代视频扩散模型的全面回顾。简要介…

C++网络编程库编写自动爬虫程序

首先&#xff0c;我们需要使用 C 的网络编程库来编写这个爬虫程序。以下是一个简单的示例&#xff1a; #include <iostream> #include <string> #include <curl/curl.h> #include <openssl/ssl.h>const char* proxy_host "duoip"; const in…

.NET8.0 AOT 经验分享 FreeSql/FreeRedis/FreeScheduler 均已通过测试

2023年11月15日&#xff0c;对.net的开发圈是一个重大的日子&#xff0c;.net 8.0正式版发布。 圈内已经预热了有半个月有余&#xff0c;性能不断超越&#xff0c;开发体验越来越完美&#xff0c;早在.net 5.0的时候就各种吹风Aot编译&#xff0c;直到6.0 7.0使用仍然比较麻烦…

SQL学习之增删改查

文章目录 数据库数据类型建表create table插入数据insert into查询数据select from修改数据update set删除数据delete from备份ctas结果插入iis截断表 truncate table修改表结构alter table添加注释 注&#xff1a;本文的SQL语法是基于Oracle数据库操作的&#xff0c;但是基本的…