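Transformer models have been very successful in medical image segmentation. UNETR++ introduces an efficient paired attention (EPA) module that uses spatial and channel attention to learn spatial and channel features efficiently. The model achieves strong results on the Synapse, BTCV, ACDC and BraTS datasets.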
Paper: UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
Code: https://github.com/Amshaker/unetr_plus_plus
I. Paper Notes
First, the model architecture. Overall it is still a UNet-style encoder-decoder, with the proposed EPA module inserted into it.
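The core of the paper is the EPA module, which is designed to solve two problems:
1. Efficiency: standard self-attention is expensive, and even more so for 3D medical images. EPA projects the self-attention keys K and values V to a lower dimension before computing attention, so the spatial attention map grows linearly rather than quadratically with the number of tokens.
2. Richer spatial and channel representations: self-attention in a Transformer is essentially a spatial attention mechanism and ignores channel-wise features. EPA combines spatial and channel attention in one block.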
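To put numbers on point 1, the short sketch below counts the entries in one head's attention matrix for full self-attention (n × n, with n = H·W·D tokens) and for EPA-style projected spatial attention (n × p). The 32×32×32 feature map and p = 64 are hypothetical values chosen for illustration, not settings taken from the paper.

```python
def attention_entries(h: int, w: int, d: int, p: int):
    """Return (full, projected) entry counts of one head's attention matrix.

    full:      standard self-attention, Q K^T has shape (n, n) with n = h*w*d
    projected: K and V are first projected from n rows to p rows,
               so the attention matrix has shape (n, p)
    """
    n = h * w * d
    return n * n, n * p


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    full, projected = attention_entries(32, 32, 32, p=64)
    print(f"full (n*n):      {full:,}")
    print(f"projected (n*p): {projected:,}")
    print(f"reduction:       {full // projected}x")
```

With n = 32,768 tokens and p = 64 the matrix shrinks by a factor of n / p = 512, and because p is fixed, doubling the volume only doubles the projected matrix instead of quadrupling it.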
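Now a closer look at the EPA diagram: the upper part (blue background) is the spatial attention branch and the lower part (green background) is the channel attention branch. In the spatial branch, to reduce the cost of self-attention, the keys K and values V of size HWD×C are projected down to p×C.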
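Reading the two branches off the official code below, per attention head and with the batch dimension dropped, gives the following. Here n = HWD, d is the per-head channel count, Q̄ and K̄ are Q and K with each channel (column) L2-normalized over the n tokens, E is the shared projection self.EF, τ1 and τ2 are temperature and temperature2, and softmax is taken row-wise. This is a reading of the code, not the paper's own notation.

```latex
Q, K, V_s, V_c \in \mathbb{R}^{n \times d}, \qquad E \in \mathbb{R}^{n \times p}

K_p = E^\top K \in \mathbb{R}^{p \times d}, \qquad V_p = E^\top V_s \in \mathbb{R}^{p \times d}

\hat{X}_s = \operatorname{softmax}\!\left(\tau_2 \, \bar{Q} K_p^\top\right) V_p \in \mathbb{R}^{n \times d}

A_c = \operatorname{softmax}\!\left(\tau_1 \, \bar{Q}^\top \bar{K}\right) \in \mathbb{R}^{d \times d}, \qquad \hat{X}_c = V_c A_c^\top \in \mathbb{R}^{n \times d}

\hat{X} = \hat{X}_s + \hat{X}_c
```

The spatial map is n×p instead of n×n, and the channel map is d×d, independent of the number of tokens.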
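The code is shown below. In the class, self.EF projects K and V to the lower dimension (one shared matrix serves as both projections E and F), and Q and K are shared between the spatial and channel attention branches: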
import math

import torch
import torch.nn as nn


def init_(tensor):
    # ASSUMPTION: init_ is not shown in the original snippet; this Linformer-style
    # uniform initialization is a stand-in so the class runs on its own.
    std = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-std, std)
    return tensor


class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Learnable temperatures scaling the channel (1) and spatial (2) attention logits
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x):
        # x: (B, N, C) with N = H*W*D tokens; N must equal input_size
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)  # (4, B, heads, N, C/heads)
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        # Put tokens on the last axis: (B, heads, C/heads, N)
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        # Project keys and spatial values along the token axis: N -> proj_size
        proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)
        k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # Channel attention: a (C/heads x C/heads) map per head
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # Spatial attention: an (N x proj_size) map per head
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        # Note: the product below is (B, heads, N, C/heads), so permute(0, 3, 1, 2) gives
        # (B, C/heads, heads, N); permute(0, 2, 1, 3) would give (B, N, heads, C/heads).
        # Kept unchanged from the original snippet.
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}
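To follow how tensors flow through forward() without installing PyTorch, the pure-Python shape trace below mirrors each operation of the EPA class above. The helper names (permute, swap_last2, matmul, trace_epa_shapes) are hypothetical and exist only for this illustration. It makes two practical points visible: input_size must equal the token count N = H·W·D at the stage where the block is used, and the channel branch builds a d×d map per head while the spatial branch builds an N×p map.

```python
def permute(shape, *dims):
    """Shape of tensor.permute(*dims)."""
    return tuple(shape[i] for i in dims)


def swap_last2(shape):
    """Shape of tensor.transpose(-2, -1)."""
    return shape[:-2] + (shape[-1], shape[-2])


def matmul(a, b):
    """Shape of a @ b when both operands share the same leading (batch, head) dims."""
    assert a[:-2] == b[:-2], f"leading dims differ: {a} vs {b}"
    assert a[-1] == b[-2], f"inner dims differ: {a} @ {b}"
    return a[:-1] + (b[-1],)


def trace_epa_shapes(B, N, C, num_heads, input_size, proj_size):
    """Follow the tensor shapes of EPA.forward for an input of shape (B, N, C)."""
    assert C % num_heads == 0, "hidden size must be divisible by num_heads"
    d = C // num_heads
    qkvv = permute((B, N, 4, num_heads, d), 2, 0, 3, 1, 4)  # (4, B, h, N, d)
    q = k = v_ca = v_sa = swap_last2(qkvv[1:])              # each (B, h, d, N)

    # einsum('bhdn,nk->bhdk', ., EF): the token axis must match EF's first dim
    ef = (input_size, proj_size)
    assert k[-1] == ef[0], f"N={N} must equal input_size={input_size}"
    k_proj = v_proj = k[:-1] + (ef[1],)                     # each (B, h, d, p)

    attn_ca = matmul(q, swap_last2(k))                      # (B, h, d, d)
    x_ca = permute(matmul(attn_ca, v_ca), 0, 3, 1, 2)       # (B, N, h, d)

    attn_sa = matmul(permute(q, 0, 1, 3, 2), k_proj)        # (B, h, N, p)
    x_sa = permute(matmul(attn_sa, swap_last2(v_proj)), 0, 3, 1, 2)  # (B, d, h, N)

    return {"attn_CA": attn_ca, "attn_SA": attn_sa,
            "x_CA_before_reshape": x_ca, "x_SA_before_reshape": x_sa}


if __name__ == "__main__":
    for name, shape in trace_epa_shapes(2, 64, 32, 4, 64, 16).items():
        print(name, shape)
```

For example, with B = 2, N = 64 tokens, C = 32, 4 heads and proj_size = 16, attn_CA is (2, 4, 8, 8) and attn_SA is (2, 4, 64, 16).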
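II. Running the Code
The official repository provides working examples for the Synapse, BTCV, ACDC and BraTS datasets. Here I only run BraTS; the steps for the other datasets are the same.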
1. Set up the environment
Create the environment with conda:
conda create --name unetr_pp python=3.10
conda activate unetr_pp
Install PyTorch:
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
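Install the dependencies (run this from the root of the cloned repository, where requirements.txt is located):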
pip install -r requirements.txt
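After installation, a quick sanity check can confirm that the interpreter and PyTorch match what was requested above (Python 3.10 and a CUDA 11.8 build). check_env is a hypothetical helper, not part of the repository; it relies only on the standard library plus torch.__version__, torch.version.cuda and torch.cuda.is_available().

```python
import importlib.util
import sys


def check_env(min_python=(3, 10)):
    """Report whether the Python version and PyTorch/CUDA setup look as expected."""
    report = {
        "python": sys.version.split()[0],
        "python_ok": tuple(sys.version_info[:2]) >= tuple(min_python),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["torch_version"] = torch.__version__
        report["torch_cuda_build"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_env().items():
        print(f"{key}: {value}")
```

If torch_cuda_build is None or cuda_available is False inside the unetr_pp environment, the CPU-only PyTorch build was probably installed instead of the pytorch-cuda=11.8 one.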
2. Prepare the dataset
The authors provide preprocessed datasets that can be downloaded directly:
| Dataset | Link |
|---|---|
| Synapse | OneDrive |
| ACDC | OneDrive |
| Decathlon-Lung | OneDrive |
| BraTS | OneDrive |
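For this walkthrough I downloaded the BraTS dataset and placed it as DATASET_Tumor in the repository root (the evaluation script refers to it as ../DATASET_Tumor from the evaluation_scripts directory).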
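Before training, it can help to confirm that the folders the scripts point to actually exist. The sub-paths below are copied from the run_evaluation_tumor.sh script used in the evaluation step, not from an official layout specification, and missing_dataset_dirs is a hypothetical helper; the archive you download may organize things differently.

```python
from pathlib import Path

# Sub-directories of DATASET_Tumor referenced by run_evaluation_tumor.sh.
# ASSUMPTION: derived from that script's paths, not from an official layout spec.
EXPECTED_DIRS = (
    "unetr_pp_raw",
    "unetr_pp_raw/unetr_pp_raw_data/Task03_tumor",
    "unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs",
)


def missing_dataset_dirs(repo_root, expected=EXPECTED_DIRS):
    """Return the expected sub-directories that are absent under <repo_root>/DATASET_Tumor."""
    base = Path(repo_root) / "DATASET_Tumor"
    return [rel for rel in expected if not (base / rel).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs(".")  # run from the repository root
    print("dataset layout OK" if not missing else f"missing: {missing}")
```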
3. Training
Since the goal is only to get everything running, change the number of epochs in unetr_plus_plus/unetr_pp/training/network_training/unetr_pp_trainer_tumor.py to 10.
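If you prefer to make that change with a script, the sketch below rewrites the epoch assignment in place. It assumes the trainer stores its epoch budget in an assignment like self.max_num_epochs = 1000, following the nnU-Net trainer convention; open the file and confirm this before relying on it. set_max_epochs and patch_trainer_file are hypothetical helpers.

```python
import re
from pathlib import Path

# ASSUMPTION: the trainer sets its epoch budget with an assignment such as
# "self.max_num_epochs = 1000" (nnU-Net-style trainer convention).
EPOCH_ASSIGNMENT = re.compile(r"(self\.max_num_epochs\s*=\s*)\d+")


def set_max_epochs(source, epochs):
    """Return `source` with every max_num_epochs assignment set to `epochs`."""
    new_source, count = EPOCH_ASSIGNMENT.subn(rf"\g<1>{epochs}", source)
    if count == 0:
        raise ValueError("no 'self.max_num_epochs = <int>' assignment found")
    return new_source


def patch_trainer_file(path, epochs=10):
    """Rewrite the trainer file in place (use git to undo the change later)."""
    file = Path(path)
    file.write_text(set_max_epochs(file.read_text(), epochs))
```

Raising an error when nothing matches is deliberate: a silent no-op would leave the original epoch count in place and you would only notice after a very long run.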
Training itself is simple: go to the training-script directory and run the script:
cd training_scripts
bash run_training_tumor.sh
Training then starts.
4. Evaluation
First, put your own trained weights where the evaluation expects them: move the unetr_pp folder from output_tumor into unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint.
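The same move can be scripted as a copy. This is a minimal sketch that assumes the training output folder output_tumor sits in the repository root; adjust output_dir if your training results went elsewhere. copy_trained_model is a hypothetical helper, not part of the repository.

```python
import shutil
from pathlib import Path


def copy_trained_model(repo_root, output_dir="output_tumor"):
    """Copy <repo_root>/<output_dir>/unetr_pp into the evaluation checkpoint folder.

    ASSUMPTION: training results are in <repo_root>/output_tumor.
    """
    root = Path(repo_root)
    src = root / output_dir / "unetr_pp"
    dst = root / "unetr_pp" / "evaluation" / "unetr_pp_tumor_checkpoint" / "unetr_pp"
    if not src.is_dir():
        raise FileNotFoundError(f"trained model folder not found: {src}")
    # dirs_exist_ok (Python 3.8+) lets you repeat the copy after retraining.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst
```

Copying rather than moving keeps the original training output intact in case the evaluation setup needs to be redone.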
Next, modify unetr_plus_plus/unetr_pp/inference/predict.py in two places.
Then go to the evaluation-script directory and run the script:
cd evaluation_scripts
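Edit the run_evaluation_tumor.sh script and replace the paths with your own (the script shipped with the repository did not work for me, but feel free to try it yourself):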
#!/bin/sh

DATASET_PATH=../DATASET_Tumor

export PYTHONPATH=.././
export RESULTS_FOLDER=../unetr_pp/evaluation/unetr_pp_tumor_checkpoint
export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor
export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw

# Only for Tumor, it is recommended to train unetr_plus_plus first, and then use the provided checkpoint to evaluate. It might raise issues regarding the pickle files if you evaluated without training

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference/predict_simple.py \
    -i ../DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs \
    -o ../unetr_pp/evaluation/unetr_pp_tumor_checkpoint/inferTs \
    -m 3d_fullres -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference_tumor.py 0
Change the dataset path in unetr_plus_plus/unetr_pp/inference_tumor.py to match your setup.
Run the script:
bash run_evaluation_tumor.sh
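After it finishes, a dice_five.txt file appears in the inference output directory unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/ and contains the resulting metrics. Because the model was trained for only 10 epochs, the results are poor.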
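dice_five.txt reports Dice scores. For reference, here is a minimal sketch of the Dice coefficient, 2|P ∩ T| / (|P| + |T|), on flattened binary masks. It is not the code in inference_tumor.py, which may treat empty masks and the BraTS tumor sub-regions differently; the convention that two empty masks score 1.0 is an assumption of this sketch.

```python
def dice(pred, target):
    """Dice = 2|P & T| / (|P| + |T|) for flat binary masks (sequences of 0/1)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    # ASSUMPTION: two empty masks are treated as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


if __name__ == "__main__":
    print(dice([1, 1, 0, 0], [1, 0, 1, 0]))  # one shared voxel out of 2 + 2
```

A score of 1.0 means perfect overlap and 0.0 means none, so a short 10-epoch run typically lands well below the numbers reported in the paper.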
That's all for this post.