HW03 -实物图像识别-改进:图像增强、网络架构,K折交叉验证

在这里插入图片描述

修改模型架构或者进行图像增强

# Normally, We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([transforms.Resize((128, 128)),transforms.ToTensor(),#transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])# However, it is also possible to use augmentation in the testing phase.
# You may use train_tfm to produce a variety of images and then test using ensemble methods
train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)#transforms.CenterCrop()transforms.RandomResizedCrop((128, 128), scale=(0.7, 1.0)),#transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),#transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),transforms.RandomHorizontalFlip(0.5),transforms.RandomVerticalFlip(0.5),transforms.RandomRotation(180),transforms.RandomAffine(30),#transforms.RandomInvert(p=0.2),#transforms.RandomPosterize(bits=2),#transforms.RandomSolarize(threshold=192.0, p=0.2),#transforms.RandomEqualize(p=0.2),transforms.RandomGrayscale(p=0.2),transforms.ToTensor(),#transforms.RandomApply(torch.nn.ModuleList([]))# You may add some transforms here.# ToTensor() should be the last one of the transforms.
])

修改模型架构, 建立 resudial

class Residual_Block(nn.Module):def __init__(self, ic, oc, stride=1):# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)# torch.nn.MaxPool2d(kernel_size, stride, padding)super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(ic, oc, kernel_size=3, stride=stride, padding=1),nn.BatchNorm2d(oc),nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(oc, oc, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(oc),)self.relu = nn.ReLU(inplace=True)self.downsample = Noneif stride != 1 or (ic != oc):self.downsample = nn.Sequential(nn.Conv2d(ic, oc, kernel_size=1, stride=stride),nn.BatchNorm2d(oc),)def forward(self, x):residual = xout = self.conv1(x)out = self.conv2(out)if self.downsample:residual = self.downsample(x)out += residualreturn self.relu(out)class Classifier(nn.Module):def __init__(self, block, num_layers, num_classes=11):super().__init__()self.preconv = nn.Sequential(nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False),nn.BatchNorm2d(32),nn.ReLU(inplace=True),)self.layer0 = self.make_residual(block, 32, 64,  num_layers[0], stride=2)self.layer1 = self.make_residual(block, 64, 128, num_layers[1], stride=2)self.layer2 = self.make_residual(block, 128, 256, num_layers[2], stride=2)self.layer3 = self.make_residual(block, 256, 512, num_layers[3], stride=2)#self.avgpool = nn.AvgPool2d(2)self.fc = nn.Sequential(            nn.Dropout(0.4),nn.Linear(512*4*4, 512),nn.BatchNorm1d(512),nn.ReLU(inplace=True),nn.Dropout(0.2),nn.Linear(512, 11),)def make_residual(self, block, ic, oc, num_layer, stride=1):layers = []layers.append(block(ic, oc, stride))for i in range(1, num_layer):layers.append(block(oc, oc))return nn.Sequential(*layers)def forward(self, x):# [3, 128, 128]out = self.preconv(x)  # [32, 64, 64]out = self.layer0(out) # [64, 32, 32]out = self.layer1(out) # [128, 16, 16]out = self.layer2(out) # [256, 8, 8]out = self.layer3(out) # [512, 4, 4]#out = self.avgpool(out) # [512, 2, 2]out = self.fc(out.view(out.size(0), -1)) return out

修改损失函数,使用Focalloss

import torch.nn.functional as F
from torch.autograd import Variableclass FocalLoss(nn.Module):def __init__(self, class_num, alpha=None, gamma=2, size_average=True):super().__init__()if alpha is None:self.alpha = Variable(torch.ones(class_num, 1))else:if isinstance(alpha, Variable):self.alpha = alphaelse:self.alpha = Variable(alpha)self.gamma = gammaself.class_num = class_numself.size_average = size_averagedef forward(self, inputs, targets):N = inputs.size(0)C = inputs.size(1)P = F.softmax(inputs, dim=1)class_mask = inputs.data.new(N, C).fill_(0)class_mask = Variable(class_mask)ids = targets.view(-1, 1)class_mask.scatter_(1, ids.data, 1.)if inputs.is_cuda and not self.alpha.is_cuda:self.alpha = self.alpha.cuda()alpha = self.alpha[ids.data.view(-1)]probs = (P*class_mask).sum(1).view(-1, 1)log_p = probs.log()batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_pif self.size_average:loss = batch_loss.mean()else:loss = batch_loss.sum()return lossclass MyCrossEntropy(nn.Module):def __init__(self, class_num):pass

K 折交叉验证

k_fold = 4
num = len(total_files) // k_fold
# "cuda" only when GPUs are available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)# The number of training epochs and patience.# Initialize a model, and put it on the device specified.#from torchsummary import summary
#summary(model, (3, 128, 128))
# For the classification task, we use cross-entropy as the measurement of performance.
#criterion = nn.CrossEntropyLoss()# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.# Initialize trackers, these are not parameters and should not be changedtest_fold = k_foldfor i in range(test_fold):fold = i+1print(f'\n\nStarting Fold: {fold} ********************************************')model = Classifier(Residual_Block, num_layers).to(device)criterion = FocalLoss(11, alpha=alpha)optimizer = torch.optim.Adam(model.parameters(), lr=0.0004, weight_decay=2e-5) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=16, T_mult=1)stale = 0best_acc = 0val_data = total_files[i*num: (i+1)*num]train_data = total_files[:i*num] + total_files[(i+1)*num:]train_set = FoodDataset(tfm=train_tfm, files=train_data)train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)valid_set = FoodDataset(tfm=test_tfm, files=val_data)valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)for epoch in range(n_epochs):# ---------- Training ----------# Make sure the model is in train mode before training.model.train()# These are used to record information in training.train_loss = []train_accs = []lr = optimizer.param_groups[0]["lr"]pbar = tqdm(train_loader)pbar.set_description(f'T: {epoch+1:03d}/{n_epochs:03d}')for batch in pbar:# A batch consists of image data and corresponding labels.imgs, labels = batch#imgs = imgs.half()#print(imgs.shape,labels.shape)# Forward the data. (Make sure data and model are on the same device.)logits = model(imgs.to(device))# Calculate the cross-entropy loss.# We don't need to apply softmax before computing cross-entropy as it is done automatically.loss = criterion(logits, labels.to(device))# Gradients stored in the parameters in the previous step should be cleared out first.optimizer.zero_grad()# Compute the gradients for parameters.loss.backward()# Clip the gradient norms for stable training.grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)# Update the parameters with computed gradients.optimizer.step()# Compute the accuracy for current batch.acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()# Record the loss and accuracy.train_loss.append(loss.item())train_accs.append(acc)pbar.set_postfix({'lr':lr, 'b_loss':loss.item(), 'b_acc':acc.item(),'loss':sum(train_loss)/len(train_loss), 'acc': sum(train_accs).item()/len(train_accs)})scheduler.step()# Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.model.eval()# These are used to record information in validation.valid_loss = []valid_accs = []# Iterate the validation set by batches.pbar = tqdm(valid_loader)pbar.set_description(f'V: {epoch+1:03d}/{n_epochs:03d}')for batch in pbar:# A batch consists of image data and corresponding labels.imgs, labels = batch#imgs = imgs.half()# We don't need gradient in validation.# Using torch.no_grad() accelerates the forward process.with torch.no_grad():logits = model(imgs.to(device))# We can still compute the loss (but not the gradient).loss = criterion(logits, labels.to(device))# Compute the accuracy for current batch.acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()# Record the loss and accuracy.valid_loss.append(loss.item())valid_accs.append(acc)pbar.set_postfix({'v_loss':sum(valid_loss)/len(valid_loss), 'v_acc': sum(valid_accs).item()/len(valid_accs)})#break# The average loss and accuracy for entire validation set is the average of the recorded values.valid_loss = sum(valid_loss) / len(valid_loss)valid_acc = sum(valid_accs) / len(valid_accs)if valid_acc > best_acc:print(f"Best model found at fold {fold} epoch {epoch+1}, acc={valid_acc:.5f}, saving model")torch.save(model.state_dict(), f"Fold_{fold}_best.ckpt")# only save best to prevent output memory exceed errorbest_acc = valid_accstale = 0else:stale += 1if stale > patience:print(f"No improvment {patience} consecutive epochs, early stopping")break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/222189.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue中的加密方式(js-base64、crypto-js、jsencrypt、bcryptjs)

目录 1.安装js-base64库 2. 在Vue组件中引入js-base64库 3.使用js-base64库进行加密 4.Vue中其他加密方式 1.crypto-js 2.jsencrypt 3.bcryptjs 1.安装js-base64库 npm install js-base64 --save-dev 2. 在Vue组件中引入js-base64库 import { Base64 } from js-ba…

大型语言模型:SBERT — Sentence-BERT

slavahead 一、介绍 Transformer 在 NLP 方面取得了进化进步,这已经不是什么秘密了。基于转换器,许多其他机器学习模型已经发展起来。其中之一是BERT,它主要由几个堆叠的变压器编码器组成。除了用于情感分析或问答等一系列不同的问题外&#…

AI数字人互动大屏支持多种场景交互!

互动大屏(技术支持:zhibo175)本身具有令人瞩目的效果,再配置丰富多彩的多媒体,如引人注目的广告、特效或游戏等,可起到很好的引流作用。在空间开阔且客流密集的场所,使用各种形态的大面积屏幕&a…

【AIGC重塑教育】AI大模型驱动的教育变革与实践

文章目录 🍔现状🛸解决方法✨为什么要使用ai🎆彩蛋 🍔现状 AI正迅猛地改变着我们的生活。根据高盛发布的一份报告,AI有可能取代3亿个全职工作岗位,影响全球18%的工作岗位。在欧美,或许四分之一…

【2023 英特尔On技术创新大会直播 |我与英特尔的初次相遇】—— AIPC探索下一代的物联网时代

🌈个人主页: Aileen_0v0 🔥系列专栏:英特尔技术学习专栏 💫个人格言:"没有罗马,那就自己创造罗马~" 目录 硅谷经济的发展与挑战 Intel开发者云与AI技术的应用 AI压缩技术的发展与应用 英特尔与阿里巴巴在AI领域的合作 AIPC时代的…

flink sql1.18.0连接SASL_PLAINTEXT认证的kafka3.3.1

阅读此文默认读者对docker、docker-compose有一定了解。 环境 docker-compose运行了一个jobmanager、一个taskmanager和一个sql-client。 如下: version: "2.2" services:jobmanager:image: flink:1.18.0-scala_2.12container_name: jobmanagerports:…

C语言——数组

一、数组介绍 C 语言支持数组数据结构,它可以存储一个固定大小的相同类型元素的顺序集合。数组是用来存储一系列数据,但它往往被认为是一系列相同类型的变量。 ps:再C99之前的标准不支持变长数组,C99及之后的标准支持变长数组&a…

软件系统质量保证计划书

本计划描述了信息系统项目质量保证工作相关的一些情况,是软件质量保证过程和方针在项目中的具体实施计划。 计划中阐述了质量保证工作的基本目标;项目的基本情况;质量保证工作所需的资源;质量保证的主要工作;工作量估算…

openGuass:极简版安装

目录 一、openGauss简介 二、初始化安装环境 1.创建安装用户 2.修改文件句柄设置 ​3.修改SEM内核参数 4.关闭防火墙 6.禁用SELINUX 7.安装依赖软件 8.重启服务器 三、安装数据库 1.下载安装包 2.创建安装目录 3.解压安装包 4.执行安装 5.验证安装 四、gsql工具…

循环神经网络中的梯度消失或梯度爆炸问题产生原因分析(二)

上一篇中讨论了一般性的原则,这里我们具体讨论通过时间反向传播(backpropagation through time,BPTT)的细节。我们将展示目标函数对于所有模型参数的梯度计算方法。 出于简单的目的,我们以一个没有偏置参数的循环神经…

医院设置反馈投诉建议二维码的好处

将投诉建议的记录单制作成二维码,放在导医台、挂号窗口、门诊门口、电梯等公共区域,群众在就医过程中遇到的种种难点、堵点,皆可通过扫码进行评价、投诉,医院会及时收到信息安排员工第一时间与投诉人联系沟通解决,做到“码”上建议,马上落实。…

什么牌子的猫罐头健康又实惠?五大猫罐头推荐排行榜

新手养猫很容易陷入疯狂购买的模式,但有些品牌真的不能乱买!现在的大环境不太好,我们需要学会控制自己的消费欲望,把钱花在刀刃上!现在宠物市场真的很内卷,很多品牌都在比拼产品的数据和营养成分。很多铲屎…

PyQt6 基类QObject类介绍以及应用

锋哥原创的PyQt6视频教程: 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计51条视频,包括:2024版 PyQt6 Python桌面开发 视频教程(无废话版…

Ubuntu 22.04 LTS上安装Docker-ce

在Ubuntu 22.04 LTS上安装Docker-ce Docker是一个开源平台,用于自动化应用程序的部署、扩展和管理。它使用容器技术,使开发、测试和部署过程更加简化和可靠。本文将介绍在Ubuntu 22.04 LTS上安装Docker-ce的步骤。 步骤1:更新软件包列表 …

qt-C++笔记之使用QLabel和QPushButton实现一个bool状态的指示灯

qt-C笔记之使用QLabel和QPushButton实现一个bool状态的指示灯 code review! 文章目录 qt-C笔记之使用QLabel和QPushButton实现一个bool状态的指示灯1.QPushButton实现2.QLabel实现2.QLabel实现-对错符号 1.QPushButton实现 运行 代码 #include <QtWidgets>class Ind…

学习Java第74天,Ajax简介

什么是ajax AJAX Asynchronous JavaScript and XML&#xff08;异步的 JavaScript 和 XML&#xff09;。 AJAX 不是新的编程语言&#xff0c;而是一种使用现有标准的新方法。 AJAX 最大的优点是在不重新加载整个页面的情况下&#xff0c;可以与服务器交换数据并更新部分网页…

Web前端-JavaScript(对象)

文章目录 1.对象1.1 概念1.2 创建对象三种方式**对象字面量创建对象**&#xff1a;new Object创建对象构造函数创建对象 1.3 遍历对象 2.作用域1.1 概述1.2 全局作用域1.3 局部作用域1.4 JS没有块级作用域1.5 变量的作用域1.6 作用域链1.7 预解析 1.对象 1.1 概念 什么是对象 …

如何衡量和提高测试覆盖率?

衡量和提高测试覆盖率&#xff0c;对于尽早发现软件缺陷、提高软件质量和用户满意度&#xff0c;都具有重要意义。如果测试覆盖率低&#xff0c;意味着用例未覆盖到产品的所有代码路径和场景&#xff0c;这可能导致未及时发现潜在缺陷&#xff0c;代码中可能存在逻辑错误、边界…

通讯录应用程序开发指南

目录 一、前言 二、构建通讯录应用程序 2.1通讯录框架 (1)打印菜单 (2) 联系人信息的声明 (3)创建通讯录 (4)初始化通讯录 2.2功能实现 (1)增加联系人 (2)显示联系人 (3)删除联系人 (4)查找联系人 (5)修改联系人 (6)排序联系人 三、通讯录的优化 3.1 文件存储 …

2. 创建型模式 - 抽象工厂模式

亦称&#xff1a; Abstract Factory 意图 抽象工厂模式是一种创建型设计模式&#xff0c; 它能创建一系列相关的对象&#xff0c; 而无需指定其具体类。 问题 假设你正在开发一款家具商店模拟器。 你的代码中包括一些类&#xff0c; 用于表示&#xff1a; 一系列相关产品&…