似乎是第一个用于医学图像分类的扩散模型嗷~
论文:DiffMIC: Dual-Guidance Diffusion Network for Medical Image Classification
代码:https://github.com/scott-yjyang/DiffMIC
0、摘要
扩散概率模型最近在生成式图像建模中表现出了显著的性能,引起了计算机视觉界的广泛关注。(顶流,绝对的顶流~)然而,大量的基于扩散的研究集中在生成任务上,很少有研究将扩散模型应用于一般的医学图像分类。(我偏要剑走偏锋~)
本文提出了第一个基于扩散的模型(称为 DiffMIC)来实现医学图像分类,DiffMIC 能够消除医学图像中的意外噪声和扰动,并稳健地捕获语义表示。
为此,本文设计了一种双重条件引导策略,通过多个粒度对每个扩散步骤进行条件化,以改善逐步区域注意力。此外,提出了通过在扩散前向过程中执行最大均值差异正则化来学习每个粒度中的互信息。
本文评估了 DiffMIC 在三种不同图像类型的医学分类任务上的有效性,包括胎盘成熟度分级(超声),皮肤病变分类(皮肤镜),糖尿病视网膜病变分级(眼底照相),具有 SOTA 结果。
1、引言
1.1、现有分类方法挑战
(1)不同类型医学图像存在各种模糊的病变和细粒度组织;
(2)硬件限制下生成医学图像会产生噪声和模糊效应,降低图像质量,故需要更有效的特征表示建模来进行鲁棒分类;
1.2、本文贡献
(1)提出了基于扩散模型的医学图像分类模型;(第一个!)
(2)引入一种双粒度条件引导(DCG)策略来指导去噪过程,在扩散过程中同时使用全局和局部先验来调节每个 step;
(3)引入特定条件的最大平均差异(MMD)正则化来学习每个粒度在潜在空间中的互信息;
(4)评估了 DiffMIC 在 3 个 2D 医学图像分类任务上的有效性;
2、方法
图1 为 DiffMIC 的示意图。给定输入图像 x x x,将其传入图像编码器获得图像特征嵌入 ρ ( x ) {ρ(x)} ρ(x),利用双粒度条件引导(DCG)模型来生成全局先验 y ^ g {\hat y_g} y^g 和局部先验 y ^ l {\hat y_l} y^l。
训练时,对 ground truth y 0 {y_0} y0 和不同的先验进行扩散过程,得到三个噪声变量:全局先验 y t g {y^g_t} ytg,局部先验 y t l {y^l_t} ytl,双先验 y t {y_t} yt。随后,将三个噪声变量及其各自的先验相结合,并将它们分别投影到一个潜在空间中。进一步将三种投影嵌入与图像特征嵌入 ρ ( x ) {ρ(x)} ρ(x) 相结合,并预测了从 y t g {y^g_t} ytg, y t l {y^l_t} ytl 和 y t {y_t} yt 采样的噪声分布。
本文针对预测的噪声 y t g {y^g_t} ytg 和 y t l {y^l_t} ytl 设计了特定条件的最大均值差异(MMD)正则化损失,并采用均方误差(MSE)衡量噪声估计损失,3 个损失进行协同训练,以优化 DiffMIC 网络。
Figure 1 | DiffMIC 框架概述:(a)训练和(b)推理过程(颜色越深,特征嵌入的噪声就越大),(c)DCG 模型通过来自原始图像和 ROI 的双重先验来指导扩散过程;
2.1、双粒度条件引导(DCG)策略
(1)DCG模型
在大多数条件 DDPM 中,条件先验将是一个唯一的给定信息。然而,由于对象的模糊性,医学图像分类尤其具有挑战性,很难从背景中区分病变和组织,特别是在低对比度的图像模式中,如超声图像。
此外,在感兴趣的区域(ROI)中可能存在意外的噪声或模糊效应,从而阻碍了对高级语义的理解。在每个扩散步骤中,只取一个原始图像 x x x 作为条件,将不足以稳健地学习细粒度信息,从而导致分类性能下降。
为缓解该问题,本文设计了一个双粒度条件引导(DCG)来编码每个扩散步骤。具体而言,引入了一个 DCG 模型 τ D τ_D τD 来计算扩散过程的全局和局部条件先验。与放射科医生的诊断过程类似,可以从全局先验中获得一个整体的理解,在消除负噪声影响时,也可以集中于局部先验中病变对应的区域。
如 图1(c) 所示,对于全局流,将原始图像数据 x x x 输入全局编码器 τ g τ_g τg ,然后输入 1×1 的卷积层,生成整个图像的显著性映射。然后通过平均池化,从整个显著性映射中预测全局先验 y ^ g {\hat y_g} y^g。对于局部流,进一步裁剪在整个图像的显著性映射中响应显著的 ROI。每个 ROI 被输入局部编码器 τ l τ_l τl,以获得一个特征向量。然后利用门控注意机制,融合 ROI 中的所有特征向量,得到一个加权向量,利用该向量通过一个线性层计算局部先验 y ^ l {\hat y_l} y^l。
(2)去噪模型
基于 DCG 模型计算的全局和局部先验,在扩散过程中对噪声变量 y t {y_t} yt 进行如下采样:
其中, ϵ ∼ N ( ϵ ; 0 , I ) {ϵ \sim \mathcal N(ϵ; 0, I)} ϵ∼N(ϵ;0,I), α ˉ t = ∏ t α t {{\bar \alpha _t} = \prod _{t}{\alpha _t}} αˉt=∏tαt, α t = 1 − β t {{\alpha _t} = 1 - {\beta _t}} αt=1−βt 具有线性噪声调度 { β t } t = 1 : T ∈ ( 0 , 1 ) T {\{ \beta _t \}_{t=1:T} \in (0,1)^T} {βt}t=1:T∈(0,1)T。随后,将噪声变量 y t y_t yt 和双先验的连接向量输入去噪模型 UNet ϵ θ ϵ_θ ϵθ 来估计噪声分布,其公式为:
式中, f ( ⋅ ) {f(·)} f(⋅) 为对潜在空间的投影层。 [ ⋅ ] {[·]} [⋅] 是连接操作。 E ( ⋅ ) {E(·)} E(⋅) 和 D ( ⋅ ) {D(·)} D(⋅) 是 UNet 的编码器和解码器。需要注意的是,图像特征嵌入 ρ ( x ) {ρ(x)} ρ(x) 与 UNet 中的投影噪声嵌入进一步集成,使模型关注高级语义,从而获得更鲁棒的特征表示。在前向过程中,利用最小化噪声估计损失 L ϵ {\mathcal L_ϵ} Lϵ:
本文的方法通过结合从原始图像和感兴趣区域(ROIs)中提取的信息作为先验条件,来优化每一步的估计函数,从而改进了基本的扩散模型。
2.2、条件特异性 MMD 正则化
最大均值差异(MMD)是通过比较两个分布的所有矩来衡量它们之间的相似性,这可以通过核函数有效地实现。受 InfoVAE 启发,本文引入了一对额外的条件特定最大均值差异(MMD)正则化损失,以学习采样噪声分布与高斯分布之间的互信息。
具体而言,从时间步长 t {t} t 的扩散过程中采样噪声变量 y t g {y_t^g} ytg,然后计算一个 MMD 正则化损失为:
其中, K ( ⋅ , ⋅ ) {\mathbb K(·, ·)} K(⋅,⋅)是一个正定核函数,用于在希尔伯特空间中再生分布。特定条件的 MMD 正则化也应用于局部先验,如 图1(a)所示。一般的噪声估计损失 L ϵ {\mathcal L_ϵ} Lϵ 从两个先验中捕获互补信息,而特定条件的 MMD 正则化保持了每个先验和目标分布之间的互信息。这有助于网络更好地建模双先验共享的鲁棒特征表示,并以稳定的方式更快地收敛。
2.3、训练和推理方案
(1)总损失
通过添加噪声估计损失和 MMD 正则化损失,去噪网络的总损失 L d i f f {\mathcal L_{diff}} Ldiff 如下:
其中 λ {λ} λ 是一个平衡超参数,根据经验设置为 λ = 0.5 {λ=0.5} λ=0.5。
(2)训练细节
①采用标准的 DDPM 训练过程,扩散时间步长 t {t} t 从 [ 1 , T ] {[1,T]} [1,T] 的均匀分布中选择,噪声用 β 1 = 1 × 1 0 − 4 {β_1=1×10^{−4}} β1=1×10−4 和 β T = 0.02 {β_T=0.02} βT=0.02 进行线性调度;
②采用 ResNet18 作为图像编码器 ρ ( ⋅ ) {ρ(·)} ρ(⋅);
③连接 y ^ g {\hat y_g} y^g, y ^ l {\hat y_l} y^l 和 y t {y_t} yt,并应用一个输出维数为 6144 的线性层,得到潜在空间中的融合向量;
④为了根据时间步长对响应嵌入进行条件化处理,本文对融合向量和时间步长嵌入执行哈达玛积。然后,对图像特征嵌入和响应嵌入再次执行哈达玛积,将它们整合在一起;输出向量被依次传递通过两个全连接层,每个全连接层之后都与一个时间步长嵌入执行哈达玛积。最后,使用一个全连接层来预测噪声,该层的输出维度等于类别数。值得注意的是,除了输出层之外,所有全连接层都伴随着一个批量归一化层和一个Softplus非线性激活函数。
⑤对于 DCG 模型 τ D τ_D τD,其全局和局部流的主干是 ResNet,采用标准的交叉熵损失作为 DCG 模型的目标。对 DCG 模型进行 10 个 epoch 的预热预训练后,联合训练去噪扩散模型和 DCG 模型,从而得到一个用于医学图像分类的端到端 DiffMIC。
(3)推理阶段
如 图1(b)所示,给定一个输入图像 x x x,首先将其输入到 DCG 模型中,以获得双先验 y ^ g {\hat y_g} y^g, y ^ l {\hat y_l} y^l。随后,根据 DDPM 的 pipeline,将先验 y ^ g {\hat y_g} y^g, y ^ l {\hat y_l} y^l 和图像特征嵌入 ρ ( x ) ρ(x) ρ(x) 输入训练好的条件 UNet,从随机预测 y T y_T yT 中迭代去噪得到最终的预测 y ^ 0 {\hat y_0} y^0。
3、实验
3.1、数据集
(1)PMG2000:胎盘成熟度分级(4类),2098 张超声图像,8:2 划分;
(2)HAM10000:皮肤病变分析(7类),10015 张皮肤病变图像,7:3 划分;
(3)APTOS2019:糖尿病视网膜病变分析(5类),3662 张眼底图像,7:3 划分;
3.2、实施细节
(1)PyTorch框架,NVIDIA RTX 3090 GPU;
(2)中心裁剪 224×224,随机翻转和旋转,6 个 ROI 区域 32×32;
(3)Adam 优化器,batch size=32;
(4)UNet 初始学习率 1 × 1 0 − 3 {1×10^{−3}} 1×10−3,DCG 训练整个网络的学习率 2 × 1 0 − 4 {2×10^{−4}} 2×10−4;
(5)三个数据集训练 epoch=1000;
(6)推理的总扩散时间步长 T T T,PMG2000: T = 100 T =100 T=100,HAM10000: T = 250 T =250 T=250, APTOS2019 : T = 60 T =60 T=60,对空间分辨率为 224×224 的图像进行分类,DiffMIC 的平均运行时间约为 0.056 秒;
3.3、实验结果
(1)与先进方法比较
HAM10000 和 APTOS2019 都存在类不平衡问题。因此,本文将 DiffMIC 与最先进的长尾医学图像分类方法进行了比较;
Table 1 | 与 SOTA 方法在三个分类任务上的定量比较:
(2)消融实验
basic 为经典的 ResNet18,C1 为添加基本扩散过程,C2为进一步添加双粒度条件引导;
Table 2 | DiffMIC 中每个模块在 PMG2000 数据集上的有效性:
(3)扩散过程可视化
随着时间步长编码的进行,去噪扩散模型逐渐从特征表示中去除噪声,从而使得类别从高斯分布中的分布更加清晰;
Figure 2 | 在三个数据集的推理中,通过扩散反向过程去噪特征嵌入得到 t-SNE:
感觉还可以做大做强,再创辉煌~