李宏毅 2022机器学习 HW3 boss baseline 上分记录

作业数据是所有数据都有标签的版本。

李宏毅 2022机器学习 HW3 boss baseline 上分记录

    • 1. 训练数据增强, private 0.76056
    • 2. cross validation&ensemble, private 0.81647
    • 3. test dataset augmentation, private 0.82458
    • 4. resnet, private 0.86555
    • 5. Image Normalization, private 0.87494
    • 6. 减小batch_size, private 0.895
    • 7. label smoothing

1. 训练数据增强, private 0.76056

结论:训练数据增强、更长时间的训练、dropout都证明很有效果,实验效果提升至接近strong baseline

增强1:crop + geometry
增强2:crop + geometry + gray
另外epochs数目增加到100,patience增加到10个epochs,FC层增加 dropout(0.3)

增强代码如下

#训练数据增强代码train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),
])

具体实验结果如下:
在这里插入图片描述

2. cross validation&ensemble, private 0.81647

使用5-fold cross validation,划分的时候使用分层抽样,
2.1)epochs=100, patience=10
训练时发现通常在60 epochs左右就early stop了,最终public score不如之前,但private score有提升,说明cross validation在过拟合上还是有效果的。
在这里插入图片描述
2.2)epochs=100, patience=16,再看看效果
patience增大后,效果有了一个非常明显的提升,超过strong baseline。具体看实验过程,会发现之前patience=10的时候,基本60epochs就停了,而现在patience=100的时候,early stop没有起作用,都是训练满100个epochs。猜测应该是使用5-fold的cross validation时,对比默认的train/valid,一方面训练数据更多,另一方面valid数据变少波动性更大,所以应该给更多的时间训练。
在这里插入图片描述

3. test dataset augmentation, private 0.82458

结论:此方式有效,分数进一步提升
在这里插入图片描述
测试数据的具体增强方式如下:
在步骤2的基础上,对test数据集使用了train数据集的数据增强方式,生成5张图片预测,对预测结果值平均,然后再用这个结果与原预测结果平均。以下为作业PPT相关部分。
在这里插入图片描述

4. resnet, private 0.86555

使用torchvision自带的resnet模型(按照作业要求,pretrained=False),尝试了resnet18和resnet50,效果进一步有了明显提升。public榜上超过bossline,但是从private榜上,可以看出存在一定过拟合。 另外resnet50的效果并没有比resnet18好,可能是小数据集的原因。这里均使用epochs=200,patience=16, lr=0.0003, weight_decay=1e-5。
在这里插入图片描述
在这里插入图片描述

两个注意点:
1,图片size设成224x224(论文中使用的图片尺寸),对比了128和224,两者差别很大。
2,resnet中的全连接层需要从原来的1000改成此次任务预测的类别数目11,代码如下:

def model_resnet():resnet = resnet18(pretrained=False)resnet.fc = nn.Sequential(nn.Linear(resnet.fc.in_features, 512),nn.ReLU(),nn.Dropout(0.3),nn.Linear(512, 11))return resnet

5. Image Normalization, private 0.87494

尝试了image normalization,略微有一些提升,尤其是做了test augmentation的private榜,达到了目前最高分。
在这里插入图片描述
image normalization的代码如下,注意train和test需要加上同样的norm。

height, width = 224, 224# Normally, We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([# transforms.Resize((128, 128)),transforms.Resize((height, width)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])# However, it is also possible to use augmentation in the testing phase.
# You may use train_tfm to produce a variety of images and then test using ensemble methods
train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),# transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),transforms.RandomResizedCrop(size=(height, width), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

6. 减小batch_size, private 0.895

将batch_size从64减小到16,模型效果进一步提升,加上image normalization后,private和public双双达到目前最高分。
在这里插入图片描述

7. label smoothing

label smoothing是为了让模型不要那么武断,加强模型鲁棒性的手段,尤其可以介绍标签错误带来的影响。
尝试了label_smoothing 0.1和0.05,其余超参保持不变,实验结果不如之前。pytorch的crossentropyloss里面有label_smoothing参数,直接设置就好。

# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/154563.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DP4054H完全兼容替代TP4054 36V 耐压 500mA 线性锂电充电芯片

产品概述: DP4054H是一款完整的采用恒定电流/恒定电压单节锂离子电池充电管理芯片。其SOT小封装和较少的外部元件数目使其成为便携式应用的理想器件,DP4054H可以适合USB 电源和适配器电源工作。由于采用了内部PMOSFET架构,加上防倒充电 路&am…

【微服务】RedisSearch 使用详解

目录 一、RedisJson介绍 1.1 RedisJson是什么 1.2 RedisJson特点 1.3 RedisJson使用场景 1.3.1 数据结构化存储 1.3.2 实时数据分析 1.3.3 事件存储和分析 1.3.4 文档存储和检索 二、当前使用中的问题 2.1 刚性数据库模式限制了敏捷性 2.2 基于磁盘的文档存储导致瓶…

【20】c++设计模式——>组合模式

组合模式定义 C组合模式(Composite Pattern)是一种结构型设计模式,他允许将对象组合成树形结构来表示“部分-整体”的层次结构;在组合模式中有两种基本类型的对象:叶子对象和组合对象,叶子对象时没有子对象…

Websocket获取B站直播间弹幕教程——第二篇、解包/拆包

教程一、Websocket获取B站直播间弹幕教程 — 哔哩哔哩直播开放平台 1、封包 我们连接上B站Websocket成功后,要做两件事情: 第一、发送鉴权包。第二、发送心跳包,每30秒一次,维持websocket连接。 这两个包不是直接发送过去&…

无为WiFi的一批服务器

我们在多个地区拥有高速服务器&#xff0c;保证网速给力&#xff0c;刷片无压力 嘿嘿 <?phpinclude("./includes/common.php"); $actisset($_GET[act])?daddslashes($_GET[act]):null; $urldaddslashes($_GET[url]); $authcodedaddslashes($_GET[authcode]);he…

RxJava介绍及基本原理

随着互联网的迅猛发展&#xff0c;Java已成为最广泛应用于后端开发的语言之一。而在处理异步操作和事件驱动编程方面&#xff0c;传统的Java多线程并不总是最佳选择。这时候&#xff0c;RxJava作为一个基于观察者模式、函数式编程和响应式编程理念的库&#xff0c;为我们提供了…

30WSIP网络有源号角 50W SIP网络有源号角

30WSIP网络有源号角 50W SIP网络有源号角 SIP-7044是我司的一款SIP网络有源号角&#xff0c;具有10/100M以太网接口&#xff0c;内置有一个高品质扬声器&#xff0c;将网络音源通过自带的功放和喇叭输出播放&#xff0c;可达到功率30W。SIP-7044作为SIP系统的播放终端&#xf…

【photoshop学习】用 Photoshop 做的 15 件创意事

用 Photoshop 做的 15 件创意事 每个人总是谈论 Photoshop 的无限可能。您可以使用该程序做很多事情&#xff0c;列表几乎是无穷无尽的。 嘿&#xff0c;我是卡拉&#xff01;如果您花过一些时间使用 在线ps&#xff0c;您可能见过我&#xff08;并且注意到我提到了这一点&am…

一个完整的初学者指南Django-part1

源自&#xff1a;https://simpleisbetterthancomplex.com/series/2017/09/04/a-complete-beginners-guide-to-django-part-1.html 一个完整的初学者指南Django - 第1部分 介绍 今天我将开始一个关于 Django 基础知识的新系列教程。这是一个完整的 Django 初学者指南。材料分为七…

iPhone 15分辨率,屏幕尺寸,PPI 详细数据对比 iPhone 15 Plus、iPhone 15 Pro、iPhone 15 Pro Max

史上最全iPhone 机型分辨率&#xff0c;屏幕尺寸&#xff0c;PPI详细数据&#xff01;已更新到iPhone 15系列&#xff01; 点击放大查看高清图 &#xff01;

[ValueError: not enough values to unpack (expected 3, got 2)]

项目场景&#xff1a; 在使用opencv进行关键点识别、边缘轮廓提取的时候&#xff0c;提示以上错误。 import cv2 import numpy as npdef preprocess(image):# 进行图像预处理&#xff08;例如灰度化、高斯模糊等&#xff09;gray cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)blu…

【angular】TodoList小项目(已开源)

参考&#xff1a;https://segmentfault.com/a/1190000013519099 文章目录 准备工作headerTodo、Doing、Done样式&#xff08;HTMLCSS&#xff09;功能&#xff08;TS&#xff09;将输入框内容加入todoList&#xff08;addTodo&#xff09;将todo事件改到doing 服务 参考开源 效…

点向行列连边的网络流图优化成行列连边的二分图:CF1592F2

https://www.luogu.com.cn/problem/CF1592F2 做完F1&#xff0c;然后用1的结论来思考。 场上推了几个性质。首先op4的操作行列必然两两不同&#xff0c;所以op4最多 max ⁡ ( n , m ) \max(n,m) max(n,m) 次。然后手玩发现只有除 ( n , m ) (n,m) (n,m) 的三个格子都为1&am…

MySQL连接方式: Unix套接字 TCP/IP

今天连接mysql数据库使用mysql -u root -p指令的时候遇到了这个问题&#xff1a; 解决之后来总结一下mysql的连接方式 文章目录 1. Unix套接字&#xff08;或Windows命名管道&#xff09;特点&#xff1a;场景&#xff1a; 2. TCP/IP特点&#xff1a;场景&#xff1a; 3.对比总…

CAD(计算机辅助设计)软件的开发框架

CAD&#xff08;计算机辅助设计&#xff09;软件的开发通常使用特定的CAD开发框架和工具。这些框架提供了一组API&#xff08;应用程序编程接口&#xff09;和开发工具&#xff0c;使开发人员能够创建自定义插件、应用程序和功能。以下是一些常见的CAD开发框架和平台&#xff0…

041:mapboxGL移动到到某Layer上,更换鼠标形状

第041个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中通过鼠标hover的方式来更换鼠标形状。 通过mouseenter和mouseleave的方法,经过某个图层上的时候,更换鼠标的形状,从default到pointer。 离开后从pointer到default。 直接复制下面的 vue+mapbox源代码,操…

excel单元格合并策略

excel单元格合并策略 证明112&#xff1f; 要证明112这个问题&#xff0c;首先我们要找到问题的关键。所谓问题的关键呢&#xff0c;就是关键的问题&#xff0c;那么如何找到问题的关键就是这个问题的关键。 比如说&#xff0c;你有一个苹果&#xff0c;我也有一个苹果&#x…

Adobe Premiere Pro:掌控视频剪辑的魔法之手,让你的创作腾飞!

&#x1f389;&#x1f389;欢迎来到我的CSDN主页&#xff01;&#x1f389;&#x1f389; &#x1f3c5;我是尘缘&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f449;点击这里&#xff0c;就可以查看我的主页啦&#xff01;&#x1f447;&#x…

DC-DC模块升压电源直流隔离低压升高压正负输出变换器

特点 效率高达 80%以上1*1英寸标准封装电源正负双输出稳压输出工作温度: -40℃~85℃阻燃封装&#xff0c;满足UL94-V0 要求温度特性好可直接焊在PCB 上 应用 HRA 0.2~8W 系列模块电源是一种DC-DC升压变换器。该模块电源的输入电压分为&#xff1a;4.5~9V、9~18V、及18~36V、36…

基于YOLOv5的火灾烟雾检测系统

目录 1&#xff0c;YOLOv5算法原理介绍 2&#xff0c;代码实现 3&#xff0c;结果展示 1&#xff0c;YOLOv5算法原理介绍 YOLOv5是目前应用广泛的目标检测算法之一&#xff0c;其主要结构分为两个部分&#xff1a;骨干网络和检测头。 骨干网络采用的是CSPDarknet53&#xff…