【PyTorch】多对象分割项目

 【PyTorch】单对象分割项目

对象分割任务的目标是找到图像中目标对象的边界。实际应用例如自动驾驶汽车和医学成像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务。多对象分割的主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义,在分割掩码中属于目标对象的所有像素基于预定义的标记被标记为相同。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

from torchvision.datasets import VOCSegmentation
from PIL import Image   
from torchvision.transforms.functional import to_tensor, to_pil_imageclass myVOCSegmentation(VOCSegmentation):def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:augmented= self.transforms(image=np.array(img), mask=np.array(target))img = augmented['image']target = augmented['mask']                  target[target>20]=0img= to_tensor(img)            target= torch.from_numpy(target).type(torch.long)return img, targetfrom albumentations import (HorizontalFlip,Compose,Resize,Normalize)mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h,w=520,520transform_train = Compose([ Resize(h,w),HorizontalFlip(p=0.5), Normalize(mean=mean,std=std)])transform_val = Compose( [ Resize(h,w),Normalize(mean=mean,std=std)])            path2data="./data/"    
train_ds=myVOCSegmentation(path2data, year='2012', image_set='train', download=False, transforms=transform_train) 
print(len(train_ds)) val_ds=myVOCSegmentation(path2data, year='2012', image_set='val', download=False, transforms=transform_val)
print(len(val_ds)) 
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")def show_img_target(img, target):if torch.is_tensor(img):img=to_pil_image(img)target=target.numpy()for ll in range(num_classes):mask=(target==ll)img=mark_boundaries(np.array(img) , mask,outline_color=COLORS[ll],color=COLORS[ll])plt.imshow(img)def re_normalize (x, mean = mean, std= std):x_r= x.clone()for c, (mean_c, std_c) in enumerate(zip(mean, std)):x_r [c] *= std_cx_r [c] += mean_creturn x_r

 展示训练数据集示例图像

img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

展示验证数据集示例图像

img, mask = val_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

创建数据加载器

 通过torch.utils.data针对训练和验证集分别创建Dataloader,打印示例观察效果

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)break

创建模型

创建并打印deeplab_resnet模型结构,使用预训练权重

from torchvision.models.segmentation import deeplabv3_resnet101
import torchmodel=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)

部署模型

在验证数据集的数据批次上部署模型观察效果 

from torch import nnmodel.eval()
with torch.no_grad():for xb, yb in val_dl:yb_pred = model(xb.to(device))yb_pred = yb_pred["out"].cpu()print(yb_pred.shape)    yb_pred = torch.argmax(yb_pred,axis=1)break
print(yb_pred.shape)plt.figure(figsize=(20,20))n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

可见勾勒对象方面效果很好 

定义损失函数和优化器

from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)def loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), Nonefrom torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

训练和验证模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)["out"]loss_b, _ = loss_batch(loss_func, output, yb, opt)running_loss += loss_bif sanity_check is True:breakloss=running_loss/float(len_data)return loss, Noneimport copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f" %(train_loss))print("val loss: %.6f" %(val_loss))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 10,"optimizer": opt,"loss_func": criterion,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": True,"lr_scheduler": lr_scheduler,"path2weights": path2models+"sanity_weights.pt",
}model, loss_hist, _ = train_val(model, params_train)

绘制了训练和验证损失曲线 

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/389668.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025浙江(杭州)国际安防产品展览会(浙江安博会)

2025浙江&#xff08;杭州&#xff09;国际智慧城市与安防产品展览会 2025hangzhou smart city And Security Expo 时间:2025年4月23-25日 地点:杭州国际博览中心 展会介绍 浙江&#xff08;杭州&#xff09;国际智慧城市及安防产品博览会&#xff08;简称:浙江安博会&#…

告别繁琐,AI助你轻松制作PPT!2024四大工具推荐

PPT是现代商务和教育领域中不可或缺的工具。然而&#xff0c;制作一份高质量的PPT往往需要花费大量的时间和精力。AI PPT制作工具的出现可以很好地解决这一问题。下面为大家推荐几个AI PPT制作工具。 笔灵AIPPT&#xff1a;智能设计&#xff0c;一键生成 链接&#xff1a;htt…

基于JSP、java、Tomcat三者的项目实战--校园交易网(3)主页--添加商品功能

技术支持&#xff1a;JAVA、JSP 服务器&#xff1a;TOMCAT 7.0.86 编程软件&#xff1a;IntelliJ IDEA 2021.1.3 x64 前文三篇登录和注册功能的实现 基于JSP、java、Tomcat、mysql三层交互的项目实战--校园交易网&#xff08;1&#xff09;-项目搭建&#xff08;前期准备工作…

Golang | Leetcode Golang题解之第309题买卖股票的最佳时机含冷冻期

题目&#xff1a; 题解&#xff1a; func maxProfit(prices []int) int {if len(prices) 0 {return 0}n : len(prices)f0, f1, f2 : -prices[0], 0, 0for i : 1; i < n; i {newf0 : max(f0, f2 - prices[i])newf1 : f0 prices[i]newf2 : max(f1, f2)f0, f1, f2 newf0, n…

【秋招笔试】24-07-27-OPPO-秋招笔试题(后端卷)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 💻 ACM金牌团队🏅️ | 多次AK大厂笔试 | 编程一对一辅导 ✨ 本系列打算持续跟新 秋招笔试题 👏 感谢大家的订阅➕ 和 喜欢💗 ✨ 笔试合集传送们 -> 🧷春秋招笔试合集 💡 01.二进制反转游戏 问题描述 K小姐…

体验教程:通义灵码陪你备战求职季

本场景将带大家体验在技术面试准备场景下&#xff0c;如何通过使用阿里云通义灵码实现高效的编程算法题练习 、代码优化、技术知识查询等工作&#xff0c;帮助开发者提升实战能力&#xff0c;更加从容地应对面试挑战。主要包括&#xff1a; 1、模拟题练习&#xff1a;精心挑选…

卸载Windows软件的正确姿势,你做对了吗?

前言 今天有小伙伴突然问我&#xff1a;她把软件都卸载了&#xff0c;但是怎么软件都还在运行&#xff1f; 这个问题估计很多小伙伴都是遇到过的&#xff0c;对于电脑小白来说&#xff0c;卸载Windows软件真的真的真的是一件很难的事情。所以&#xff0c;今天咱们就来讲讲&am…

2024最简七步完成 将本地项目提交到github仓库方法

2024最简七步完成 将本地项目提交到github仓库方法 文章目录 2024最简七步完成 将本地项目提交到github仓库方法一、前言二、具体步骤1、github仓库创建2、将远程仓库拉取并合并&#xff08;1&#xff09;初始化本地仓库&#xff08;2&#xff09;本地仓库与Github仓库关联&…

数据可视化工具,免费无限制制作报表

许多企业在报表制作上投入了大量资金&#xff0c;使用各种收费软件&#xff0c;往往只能满足基本需求&#xff0c;且操作复杂&#xff0c;让人感到无比头疼。不过最近我发现之前一直在做数据大屏的山海鲸可视化&#xff0c;现在新增了报表功能&#xff0c;不仅各种功能都可以免…

远程访问mysql数据库的正确打开方式

为了安全&#xff0c;mysql数据库默认只能本机登录&#xff0c;但是在有些时候&#xff0c;我们会有远程登录mysql数据库的需求&#xff0c;这时候应该怎么办呢&#xff1f; 远程访问mysql数据&#xff0c;需要两个条件&#xff1a; 首先需要mysql服务器将服务绑定到0.0.0.0…

【JVM】类加载器和双亲委派模型

什么是类加载器 如果想要了解什么是类加载器就需要清楚一个Java文件是如何运行的。我们可以看下图&#xff1a; 首先要知道操作系统是不能直接运行Java文件的&#xff0c;所以就需要通过JVM将Java文件转换为操作系统可以运行的文件类型&#xff0c;步骤如下&#xff1a; 类加…

入门必读:11个最受欢迎的UI设计网页软件详细评测

即时设计首发UI设计网页软件的轻量化和在线协作已成为当前UI设计网页软件的发展方向。网页UI设计不容易&#xff0c;实用的网页UI制图软件更难找到。随着网络的快速发展&#xff0c;网站迅速崛起&#xff0c;网页UI设计也很流行。UI设计网页软件即时设计是一种协同设计工具&…

MySQL是怎样运行的——第2章 启动选项和系统变量

文章目录 2.1 在命令行上使用选项2.1.1 选项的长形式和短形式 2.2 配置文件中使用选项2.2.1 配置文件的路径2.2.2 配置文件的内容2.2.3 配置文件的优先级 2.3 命令行和配置文件中启动选项的区别2.4 系统变量2.4.1 简介2.4.2 查看系统变量2.4.3 设置系统变量2.4.4 启动选项和系统…

PyQt5新手教程(五万字)

文章目录 PyQt界面开发的两种方式&#xff1a;可视化UI 编程式UI一、PyQt 简介二、PyQt 与 Qt 的蒙娜丽莎三、PyQt 布局管理器&#xff08;Layout Manager&#xff09;3.1、简介3.1.1、布局管理器的定义3.1.2、布局管理器的类型3.1.3、布局管理器的使用方法 3.2、项目实战3.2.…

查物流信息用什么软件

在电子商务日益繁荣的今天&#xff0c;快递物流信息的查询成为了我们日常生活中不可或缺的一部分。无论是网购达人还是商家&#xff0c;都需要随时掌握货物的物流动态。然而&#xff0c;如何快速、准确地查询物流信息却是一个令人头疼的问题。今天&#xff0c;我将为大家介绍一…

AI大模型需要什么样的数据?

数据将是未来AI大模型竞争的关键要素 人工智能发展的突破得益于高质量数据的发展。例如&#xff0c;大型语言模型的最新进展依赖于更高质量、更丰富的训练数据集&#xff1a;与GPT-2相比&#xff0c;GPT-3对模型架构只进行了微小的修改&#xff0c;但花费精力收集更大的高质量…

C++11深度剖析

目录 &#x1f680; 前言&#xff1a;C11简介 一&#xff1a; &#x1f525; 统一的列表初始化&#x1f4ab; 2.1 &#xff5b;&#xff5d;初始化 二&#xff1a; &#x1f525; std::initializer_list &#x1f4ab; 2.1 std::initializer_list是什么类型&#x1f4ab; 2.2 s…

正则采集器之三——前端搭建

前端使用有名的饿了么管理后台&#xff0c;vue3版本vue3-element-admin&#xff0c;首先从gitee中克隆一个vue3-element-admin模板代码vue3-element-admin: Vue3 Element Admin开箱即用的中后台管理系统前端解决方案&#xff0c;然后在此基础上进行开发。 1、修改vite.config.…

【C++】初识类和对象

本篇介绍一下C的自定义类型&#xff0c;类和对象。 1.类的定义 1.1 类定义格式 class 为定义类的关键字&#xff0c;Stack为类的名字&#xff0c;类名随便取&#xff0c;{}中为类的主体&#xff0c;类定义结束时后面的分号不可省略。类主体中内容称为类的成员&#xff1a;类中…

C++开源界面库duilib的使用细节与实战技巧总结(实战经验分享)

目录 1、使用CEditUI编辑框 2、使用CLabelUI或CTextUI的Html文本效果 3、使用CTextUI控件对文字宽度自适应的特性 4、CRichEditUI富文本控件使用注意点 4.1、指定CRichEditUI加在2.0版本的Riched20.dll库 4.2、解决向CRichEditUI中插入文字后显示空白的问题 5、设置窗口…