医学图像分割实战——使用U-Net实现肾脏CT分割

使用U-Net实现肾脏CT分割

  • 数据集准备
    • 数据来源
    • 数据预处理
  • 网络结构及代码
    • 网络结构
    • 训练代码
  • 训练过程
    • 参数设置:
    • 可视化
  • 结果分析

数据集准备

数据来源

MICCAI KiTS19(Kidney Tumor Segmentation Challenge):https://kits19.grand-challenge.org/

KiTS2019是MICCAI19的一个竞赛项目,项目的任务是对3D-CT数据进行肾脏和肾脏肿瘤的分割,官方的数据集提供了210个case作为训练集,90个case作为测试集。共有800多人报名参加了这一竞赛,最终提交的结果的team有126支,其中被认定有效的为100个记录入leaderboard。目前这一竞赛状态为开放性质的,有兴趣的可以参与一下。这一挑战将于2021年继续举办,将提供更多的数据集和标注,任务也将变得更加有挑战,有兴趣的同学可以跟进关注一下。

感谢评论里weixin_40621562老哥提供的数据集百度云版:链接: https://pan.baidu.com/s/1AOQDjPz9ye32DH-oDS0WDw 提取码: d7jk

数据预处理

KiTS19提供的数据是3D CT图像,我们要训练的是最简单的2D U-Net,因此要从3D CT体数据中读取2D切片。数据集的提供方在其Github上很贴心的提供了可视化的代码,是用python调用了nibabel库处理.nii格式的体数据得到2D的.png格式的切片。可视化的结果如下图所示,需要对切片进行筛选。另外需要补充的是在KiTS的数据集中分割的标签有三类:背景、肾脏、肾脏肿瘤,我们想进行的是简单的背景与肾脏二分类问题而不是多分类问题,因此在可视化过程中比较简单粗暴的将肿瘤视为肾脏的一部分。

一共有210个3D CT的体数据,从每个体数据中选取了10个slice,一共得到了2100张2D的.png格式图像。
在这里插入图片描述

网络结构及代码

网络结构

U-Net的结构时最简单的encoder-decoder结构,再加上越级连接,详细的网络结构请见我的另一篇博客深度学习图像语义分割网络总结:U-Net与V-Net的Pytorch实现。

训练代码

代码参照了github上的https://github.com/milesial/Pytorch-UNet

import argparse
import logging
import os
import sysimport numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdmfrom eval import eval_net
from unet import UNetfrom visdom import Visdom
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_splitdir_img = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\slice_png'
dir_mask = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\mask_png'
dir_checkpoint = 'checkpoints/'def train_net(net,device,epochs=5,batch_size=1,lr=0.1,val_percent=0.2,save_cp=True,img_scale=1):dataset = BasicDataset(dir_img, dir_mask, img_scale)n_val = int(len(dataset) * val_percent)n_train = len(dataset) - n_valtrain, val = random_split(dataset, [n_train, n_val])train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)#writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')viz=Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))viz.line([0.], [0.], win='learning_rate', opts=dict(title='learning_rate'))viz.line([0.], [0.], win='Dice/test', opts=dict(title='Dice/test'))global_step = 0logging.info(f'''Starting training:Epochs:          {epochs}Batch size:      {batch_size}Learning rate:   {lr}Training size:   {n_train}Validation size: {n_val}Checkpoints:     {save_cp}Device:          {device.type}Images scaling:  {img_scale}''')optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)if net.n_classes > 1:criterion = nn.CrossEntropyLoss()else:criterion = nn.BCEWithLogitsLoss()for epoch in range(epochs):net.train()epoch_loss = 0with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:for batch in train_loader:imgs = batch['image']true_masks = batch['mask']assert imgs.shape[1] == net.n_channels, \f'Network has been defined with {net.n_channels} input channels, ' \f'but loaded images have {imgs.shape[1]} channels. Please check that ' \'the images are loaded correctly.'imgs = imgs.to(device=device, dtype=torch.float32)mask_type = torch.float32 if net.n_classes == 1 else torch.longtrue_masks = true_masks.to(device=device, dtype=mask_type)masks_pred = net(imgs)#print('mask_pred',masks_pred.shape)#print('masks_pred',masks_pred.shape)#print('true_masks', true_masks.shape)viz.image(imgs, win='imgs/train')viz.image(true_masks, win='masks/true/train')viz.image(masks_pred, win='masks/pred/train')loss = criterion(masks_pred, true_masks)epoch_loss += loss.item()#writer.add_scalar('Loss/train', loss.item(), global_step)viz.line([loss.item()],[global_step],win='train_loss',update='append')pbar.set_postfix(**{'loss (batch)': loss.item()})optimizer.zero_grad()loss.backward()#nn.utils.clip_grad_value_(net.parameters(), 0.1)optimizer.step()pbar.update(imgs.shape[0])global_step += 1if global_step % (n_train // (10 * batch_size)) == 0:# for tag, value in net.named_parameters():#     tag = tag.replace('.', '/')#     writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)#     writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)val_score = eval_net(net, val_loader, device)scheduler.step(val_score)#writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)viz.line([optimizer.param_groups[0]['lr']], [global_step], win='learning_rate', update='append')if net.n_classes > 1:logging.info('Validation cross entropy: {}'.format(val_score))#writer.add_scalar('Loss/test', val_score, global_step)else:logging.info('Validation Dice Coeff: {}'.format(val_score))#writer.add_scalar('Dice/test', val_score, global_step)viz.line([val_score], [global_step], win='Dice/test', update='append')viz.image(imgs, win='images')if net.n_classes == 1:print('true_mask',true_masks.shape,true_masks.type)viz.image( true_masks, win='masks/true')print('pred',(torch.sigmoid(masks_pred) > 0.5).squeeze(0).shape)viz.images((torch.sigmoid(masks_pred) > 0.5),win='masks/pred')if save_cp:try:os.mkdir(dir_checkpoint)logging.info('Created checkpoint directory')except OSError:passtorch.save(net.state_dict(),dir_checkpoint + f'CP_epoch{epoch + 1}.pth')logging.info(f'Checkpoint {epoch + 1} saved !')#writer.close()def eval_net(net, loader, device):"""Evaluation without the densecrf with the dice coefficient"""net.eval()mask_type = torch.float32 #if net.n_classes == 1 else torch.longn_val = len(loader)  # the number of batchtot = 0with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:for batch in loader:imgs, true_masks = batch['image'], batch['mask']imgs = imgs.to(device=device, dtype=torch.float32)true_masks = true_masks.to(device=device, dtype=mask_type)with torch.no_grad():mask_pred = net(imgs)#['out']# if net.n_classes > 1:#     tot += F.cross_entropy(mask_pred, true_masks).item()# else:pred = torch.sigmoid(mask_pred)pred = (pred > 0.5).float()tot += dice_coeff(pred, true_masks).item()pbar.update()net.train()return tot / n_valdef get_args():parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,help='Number of epochs', dest='epochs')parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,help='Batch size', dest='batchsize')parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,help='Learning rate', dest='lr')parser.add_argument('-f', '--load', dest='load', type=str, default=False,help='Load model from a .pth file')parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,help='Downscaling factor of the images')parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,help='Percent of the data that is used as validation (0-100)')return parser.parse_args()if __name__ == '__main__':logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')args = get_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device {device}')# Change here to adapt to your data# n_channels=3 for RGB images# n_classes is the number of probabilities you want to get per pixel#   - For 1 class and background, use n_classes=1#   - For 2 classes, use n_classes=1#   - For N > 2 classes, use n_classes=Nnet = UNet(n_channels=1, n_classes=1, bilinear=True)logging.info(f'Network:\n'f'\t{net.n_channels} input channels\n'f'\t{net.n_classes} output channels (classes)\n'f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')if args.load:net.load_state_dict(torch.load(args.load, map_location=device))logging.info(f'Model loaded from {args.load}')net.to(device=device)# faster convolutions, but more memory# cudnn.benchmark = Truetry:train_net(net=net,epochs=args.epochs,batch_size=args.batchsize,lr=args.lr,device=device,img_scale=args.scale,val_percent=args.val / 100)except KeyboardInterrupt:torch.save(net.state_dict(), 'INTERRUPTED.pth')logging.info('Saved interrupt')try:sys.exit(0)except SystemExit:os._exit(0)

训练过程

参数设置:

训练与验证比例: 8:2 (1680:420)
batch_size: 2
学习率:torch.optim.lr_scheduler.ReduceLROnPlateau,当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能损失函数:BCEWithLogitsLoss 衡量目标和输出之间的二进制交叉熵

可视化

使用visdom进行可视化
一开始的训练状态,左边为真实的mask,右边为网络的输出,可以看到一开始网络的输出还是不太行的。
在这里插入图片描述
当进行完第一轮训练之后训练的结果如图所示,红色所框的为训练过程,蓝色所框为验证过程,包括了原图、真实的mask T、预测的mask P。训练和验证过程中预测mask的差异来自于是否进行了二值化处理。
在这里插入图片描述
第四轮训练之后的结果,预测的mask与真实的mask已经很接近了。
在这里插入图片描述

结果分析

在这里插入图片描述
实验结果:Dice系数:0.832
结果分析:
1.原数据为三维,本次实验只使用的二维切片
2.原数据的mask肿瘤和肾脏是分开的,在数据处理过程中统一化为了肾脏。
3.没有做数据增强、参数调整,训练不够充分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/48328.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

医学图像分割之 Dice Loss

文章目录 医学图像分割之 Dice Loss1. Dice coefficient 定义1.1. Dice 系数计算示例1.2. Dice-coefficient loss function vs cross-entropy 2. Dice 系数的 Pytorch 实现2.1. Dice 系数2.2. Dice Loss2.3. BCELoss2d 3. Dice 系数的 Keras 实现4. Dice 系数的 TensorFlow 实现…

医学图像分割常见评价指标(单目标)——包含源码讲解和指标缺陷

单目标分割常见评价指标 1 知道4个常见指标,TP,TN,FP,FN2 评价分割区域准确率2.1 Recall Sensitivity TPR(True Positive Rate)2.2 Specificity (True Negative Rate)2.3 Precision (PPV, 精确率)2.4 Dice Coefficient…

医学图像分割之Attention U-Net

目录 一、背景 二、问题 三、解决问题 四、Attention U-Net网络结构 简单总结Attention U-Net的操作:增强目标区域的特征值,抑制背景区域的目标值。抑制也就是设为了0。 一、背景 为了捕获到足够大的、可接受的范围和语义上下文信息,在标…

常用的医学图像分割评价指标

常用的图像分割评价指标非常多,论文中常用的指标包括像素准确率,交并比(IOU),Dice系数,豪斯多夫距离,体积相关误差。 下面提到的所有案例都是二分类,标签中只有0和1 目录 一:像素…

医学图像分割评判标准及程序代码

文章目录 1.图像分割指标2. 两个问题3.IOU和假阳性率4. 准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure 参考资源: 1.https://blog.csdn.net/zichen_ziqi/article/details/80408465 2.https://blog.csdn.net/HXG2006/article/details/79649154 …

基于Android studio开发的图灵智能聊天机器人

前言 在人工智能时代,开发一款自己的智能问答机器人,既可以提升自己的编程能力,又可以作为开发项目的实战练习。 百度有小度,小米有小爱,VIVO有小V,总之类似的智能聊天机器人是越来越多了。面对这些智能的机…

短视频矩阵源码开发部署--开原

短视频矩阵源码是一种常见的视频编码标准,它通过将视频分成多个小块并对每个小块进行压缩来实现高效的视频传输。在本文中,我们将介绍短视频矩阵的原理和实现,并提供示例代码。 开发链路解析 短视频矩阵系统源码开发链路包括需求分析、技术…

ChatGPT危了!注意力机制的神秘bug曝光!Transformer模型恐大受冲击...

点击下方卡片,关注“CVer”公众号 AI/CV重磅干货,第一时间送达 点击进入—>【Transformer】微信交流群 转载自:新智元 【导读】「注意力公式」存在8年的bug首现,瞬间引爆舆论。爆料者称,基于Transformer架构打造的模…

短视频抖音seo矩阵源码如何搭建开发?

抖音SEO矩阵源码排名逻辑采用一系列算法进行生成,其中包括用户行为、关键词匹配和内容质量等多维度指标的衡量。首先,用户行为是决定视频排名的主要因素,包括点赞数、评论数、观看时长和转发次数等。其次,关键词匹配也是影响排名的…

chatgpt赋能python:使用Python让照片动起来:一种新颖的SEO方法

使用Python让照片动起来:一种新颖的SEO方法 在当今数字时代,社交媒体已经成为营销策略中不可或缺的一部分。人们越来越喜欢以图像的形式来获取信息。然而,在面对大量的图像时,如何让自己的图片和品牌脱颖而出?答案是&…

chatgpt赋能python:PythonWand:用Python实现的ImageMagick工具箱

Python Wand: 用Python实现的ImageMagick工具箱 ImageMagick是一款强大的图像处理工具箱,经常被用于缩放、裁剪和转换图像等任务。Python Wand是对ImageMagick命令行工具的Python封装,使得Python程序员能够使用Python代码来操作图像。 为什么使用Pytho…

短视频如何进行高效制作?元引擎助你一臂之力

在当今社会,视频制作已经成为了一种非常流行和重要的创意方式。越来越多的人开始尝试制作自己的短视频,但是对于很多新手小白来说,短视频制作可能是一项相对困难的任务。但是现在,使用元引擎AI一键生成原创视频系统,可…

Python预测彩票中奖

文章目录[隐藏] python来解答你有生之年可以中双色球 python来解答你有生之年可以中双色球 昨天买了几注双色球开奖了,规划好了中奖后怎么花,紧张又刺激的等待后,狗带…… 到底我们能不能中双色球呢,用Python来验证一下吧&#xf…

基于GPT-4的 IDEA 神仙插件,无需魔法,非常不错!

大家好,我是不才陈某~ 最近发现了一款很厉害的 Intellij IDEA 插件——Bito。 Bito 插件无需魔法,亲测有效,可以基于 GPT-4 来写代码同时还提供了一些有用的功能,如自动补全提交信息、快速查看历史记录等。 没使用魔法的情况下&am…

IDEA懒人必备插件:自动生成单元测试,太爽了!

程序员的成长之路 互联网/程序员/技术/资料共享 关注 阅读本文大概需要 7 分钟。 来自:blog.csdn.net/sun5769675/article/details/111043213 今天来介绍一款工具Squaretest,它是一款自动生成单元测试的插件,会用到它也是因为最近公司上了代…

Mac Automator 图片自动压缩上传 COS

把个人博客放在了 netilfy 托管,它给了一个优化建议,可以压缩图片节省带宽。但是每次截图后都要再压缩下图片有点麻烦,于是想着应该可以偷偷懒。目标很明确,截图后图片传到我的 git 仓库 images 目录后,能给我自动压缩…

IDEA 28 个天花板技巧,YYDS!

因公众号更改推送规则,请点“在看”并加“星标”第一时间获取精彩技术分享 点击关注#互联网架构师公众号,领取架构师全套资料 都在这里 0、2T架构师学习资料干货分 上一篇:ChatGPT研究框架(80页PPT,附下载)…

CSDN 去除图片水印

想要保存 CSDN 博客中的一张图片时,发现图片上有水印,想要删除,怎么办呢? 如下图 右击图片 --> [在新标签中打开图片] 把问号以及问号后面的内容全部删掉,再访问 呐,水印不见了。 PS:写博…

免费的图片去水印消除水印清除水印去水印方法去水印软件免费下载

是一款免费的图片去水印工具。适用于微博下载的(偷的)图,从别的地方下载的(偷的)图等。 ** 直接说下载链接:请点击链接里的普通下载,(其他是别的下载器) 当然&#xff…

批量图片去水印,操作简单,赶紧收藏!

图片怎么去水印?在平时的日常生活中,我们有时候需要用到一些图片或者视频,但是这些视频或者图片往往会有烦人的水印,我们需要去除水印后才能更好来使用~那么你都是怎么去除水印的呢?有什么比较简单轻松的批量图片去水印方法吗?小编这里有一…