[论文笔记]Triplet attention and dual-pool contrastive learning for clinic-driven multi-label medical image classification
Abstract
多标签分类Multi-label classification (MLC)可在单张图像上附加多个标签,在医学图像上取得了可喜的成果。但现有的多标签分类方法在实际应用中仍面临着严峻的临床现实挑战,例如:
- 错误分类带来的医疗风险,
- 不同疾病之间的样本不平衡问题
- 无法对未预先定义的疾病(未见疾病)进行分类
设计了一种混合标签,以提高 MLC 方法的灵活性,并缓解样本不平衡问题。具体来说,在标注训练集中,我们为有足够样本的高频疾病保留独立标签,并使用混合标签合并样本较少的低频疾病。混合标签还可用于未见疾病的实际应用。
基于上述标签表示法,提出了用于多标签医学图像分类的三重注意和双池对比学习(TA-DCL)。 - TA-DCL 架构是一个三重注意网络(TAN),它将类别注意、自我注意和交叉注意结合在一起,通过挖掘医学图像中的有效信息,为所有疾病标签学习高质量的标签嵌入。
- DCL 包括双库对比训练(DCT)和双库对比推理(DCI)。DCT 优化了属于不同疾病标签的标签嵌入的聚类中心,从而提高了标签嵌入的分辨能力。DCI 可减轻病例分类错误,降低临床风险,并通过差异对比提高检测未见疾病的能力。
1.Introduction
由于大多数医学图像通常包含多种疾病,因此如图 1 所示,将每幅医学图像与多个疾病标签关联起来更为实用。为每种疾病训练独立的二进制单标签分类器既昂贵又困难。因此,开发能解决医学图像多标签分类(MLC)问题的深度学习技术是真实临床场景中更实际的需求。Transformer 在自然语言处理、计算机视觉、医学图像分析等许多基于自注意和交叉注意的机器学习领域取得了巨大成功。Transformer 也被用于促进 MLC 的发展。基于变换器的 MLC 方法可以通过对标签嵌入和图像特征之间的相互作用进行建模来学习有效的标签嵌入,并一次性对每个标签进行独立预测。
有几个临床现实问题往往容易被忽视:
- 在样本库中,自动系统通常会过滤非紧急病例,并挑选出紧急病例供临床医生优先审阅。如果紧急病例被误判为非紧急病例,错过最佳治疗时机,疾病可能会恶化。相反,如果一个非紧急病例被误判为紧急病例,则可通过临床医生的优先审查及时纠正这一错误。直观地说,前者比后者风险更大,因此自动系统应尽量避免将紧急病例错误分类。
- ==很容易收集到大量高发疾病的样本,但却无法收集到低发疾病的样本。==这将导致严重的样本不平衡问题,造成模型训练的偏差。
- 由于疾病种类繁多,很难为医学图像分类任务预先定义一个完整的疾病标签集。在实际应用中,难免会出现未见过的疾病(不包括在训练集中),而自动方法无法将其归入任何预定义的标签中。
多标签医学影像分类的三重注意和双池对比学习(TA-DCL):
在标注训练集中,根据样本量将所有疾病分为高频疾病和低频疾病,然后为样本量足够多的高频疾病保留独立标签,将样本量较少的低频疾病合并为混合标签。这种标签表示法可以缓解样本不平衡问题,提高 MLC 方法的灵活性。
TA-DCL 架构是一个三重注意网络(TAN),由**图像特征提取器(IFE)和三重注意转换器(TAT)**组成。
- IFE 首先将输入的医学图像转换为图像空间特征和类别注意特征
- 三重注意转换器(TAT)通过类别注意特征强化类别的重要性,并通过自我注意和交叉注意与图像空间特征紧密交互来更新标签嵌入
- 提出双池对比训练(DCT),以优化属于不同标签的标签嵌入的聚类中心,从而进一步提高标签嵌入的辨别能力
- 提出双池对比推理(DCI),以减轻病例分类的误差,提高检测未见疾病的能力。
-
Contributions
- 针对临床驱动的多标签医学影像分类提出了三重注意和双池子对比学习(TA-DCL),在 MLC 方法中考虑了临床实际情况。为了提高 TA-DCL 对临床场景的适应性,我们设计了灵活的标签表示
- 提出了三重注意网络(TAN),它结合了类别注意、自我注意和交叉注意,通过挖掘医学图像中的有效信息来学习高质量的标签嵌入
- 提出了双池对比训练(DCT),通过学习差异来提高标签嵌入的辨别能力
- 提出了双池对比推理(DCI),通过差异对比来降低错误分类造成的临床风险并提高检测未见疾病的能力
- 在两个公共医疗图像数据集上验证了 TA-DCL,结果显示其整体性能优于其他 MLC 方法。在实际应用中,TA-DCL 对未知疾病的适应性也更强
2. Related work
2.1. Multi-label classification
2.1.1. Technical progress
由于 MLC 必须为每个类别训练一个二进制分类器,因此样本不平衡问题成为一个巨大的挑战。最近的一些研究试图修改损失函数,使模型在训练过程中达到动态平衡,如非对称损失和分布平衡损失。标签之间的依赖关系可视为指导模型推理的先验知识,基于图的方法是应用最广泛的方法。
静态图可以在一定程度上改善 MLC,但统计标签共现严重依赖大样本数据,并不总是可靠的。为了缓解静态图的缺点,人们引入了可在训练过程中自动更新的动态图来替代静态图,Transformer 将标签的全局依赖性建模为全连接图。定位图像中与标签相对应的注意力区域对 MLC 有着直观的好处。
- Gao 和 Zhou 使用全局到局部的学习机制来检测候选局部区域,然后提出了一个注意区域模块来保护这些区域的差异。
- You 等人提出了基于邻接的相似性图嵌入模块和跨模态注意模块,以捕捉图像区域和标签嵌入之间的依赖关系。
- 顺序条件预测、结构化输出推理表述和共享嵌入空间学习等也被用于 MLC 的改进。
2.1.2. Applications for medical image analysis
近年来,MLC 已广泛应用于 X 光图像。还应用于其他医学模式,如彩色眼底图像、核磁共振图像、彩色皮肤图像等。
- Chen 等人推进了语义空间中的图推理,对不同病理之间的标签共现和相互依赖关系进行建模。
- Luo 等人通过挖掘外部数据集的附加知识,提出了一种多标签 X 射线图像筛选框架。
- Agu 等人基于 GNN 对标签依赖性和解剖信息之间的关系进行了建模。
- Chen 等人提出了一种新颖的语义相似性图嵌入框架,该框架明确探索图像之间的语义相似性,以优化视觉嵌入。
- Wang 等人将 MLC 转化为每个标签的二元分类问题,然后使用迁移学习和集合学习来整合多个弱分类器,以获得更好的性能。
- Lin 等人使用 GNN 捕捉相关信息,然后使用自监督学习增强泛化能力。
- Zhang 等人设计了一种多标签分类器,用于在个体的核磁共振结构图上标注疾病状态,并通过多域学习方案进一步改进了分类器
- Tang 等人提出了一种两阶段多模态学习方法,通过在不同阶段整合不同模态的信息来进行多标签皮肤病分类。
这些 MLC 方法会产生固定的标签表示,在实际应用中并不灵活
2.2. Transformer for computer vision
2.2.1. Technical progress
Transformer是一种重要的深度学习方法,用于对序列特征的全局依赖性进行建模,已被广泛应用于自然语言处理(NLP)领域。如今,Transformer 也已成功用于计算机视觉任务,并取得了令人瞩目的成果。
- 针对图像分类任务,Dosovitskiy 等人提出了视觉变换器(Vision Transformer,ViT),将每幅图像分成若干局部补丁,所有局部补丁作为序列特征送入堆叠变换器;
- Yuan 等人提出了标记-标记视觉变换器(Token-To-Token Vision Transformer,T2T-ViT),以保留被补丁标记化破坏的重要局部结构;
- Srinivas 等人用多头自注意取代了 ResNet 末端的卷积层,以捕捉更好的全局依赖性。
此外,Transformer 在物体检测、图像分割以及其他一些应用方面也表现出了卓越的性能。
2.2.2. Applications for medical image analysis
在医学图像分析任务方面,Transformer也取得了重大进展。
- Lee 等人提出了模板变换器网络(TTN),将变换器和形状先验相结合,用于冠状动脉管腔结构的分割
- Song 等人提出了用于青光眼诊断的深度关系变换器(DRT),其中使用深度推理机制对 OCT 图像和视野信息之间的全局和局部关系进行建模,并使用变换器进一步增强表征
- Lu 等人提出了轮廓变换器网络(Contour Transformer Network,CTN)来模拟轮廓演变的过程和行为,用于 X 射线图像上解剖结构的一次分割。
- He 等人将全局和局部信息整合到变换器中,根据核磁共振图像估计脑年龄。
- Huang 等人提出了一个关系变换器块和一个全局变换器块,以结合注意力机制并保留糖尿病视网膜病变分割的详细信息
- You 等人提出了一种用于二维医学图像分割的有效对抗变换器,其中设计了类感知变换器模块和对抗训练策略,以更好地学习具有语义结构的物体的判别区域。
2.3. Contrastive learning for medical images
最近,对比学习也被广泛应用于医学图像处理领域
对比学习可以通过优化正对和负对之间的相似性来学习有效的表征
- SimCVD描述了一种无监督训练策略,它采用输入体的两个视图,以对比为目标,预测其对象边界的符号距离图,只有两个独立的滤除作为掩码
- You 等人引入了动量对比,以确保在三维表征维度中促进特征多样性
- Chaitanya 等人提出了一种对比学习框架,以分阶段的方式提取全局和局部线索
- You 等人开发了一种迭代对比提炼算法,通过对负值进行软标记,而不是在正负对之间进行二元监督
3. Preliminary
3.1. Multi-label classification
给定一幅医学图像 𝒙 和 𝐿 个预定义的疾病标签,MLC 要预测每种疾病是否出现在 𝒙 中。𝒙的多标签基本事实可表示为
Y
=
[
y
1
,
y
2
,
.
.
.
,
y
L
~
]
Y=[y_1,y_2,...,y_{\widetilde L}]
Y=[y1,y2,...,yL
],其中y𝑙∈ {0, 1} 是离散的二进制指标。如果出现疾病-𝑙,则y𝑙=1,否则y𝑙=0。MLC 通过学习分类器 F (⋅)来预测Y中出现的所有疾病的概率,尽可能接近基本事实𝐘:
3.2. Our label representation
为了使标签表示在实际应用中更加灵活,重新设计了𝐘。
在有标签的训练集中,我们根据样本量将所有疾病分为高频疾病和低频疾病,然后对有足够样本的高频疾病保留独立标签,对样本较少的低频疾病使用混合标签合并。the ground truth可以改写为
Y
=
[
y
1
,
y
2
,
.
.
.
,
y
L
~
]
Y=[y_1,y_2,...,y_{\widetilde L}]
Y=[y1,y2,...,yL
] ,其中
L
~
=
L
−
L
l
f
+
1
\widetilde L=L-L_{lf}+1
L
=L−Llf+1,
L
l
f
L_{lf}
Llf是低频疾病的数量。
{
y
l
}
l
=
1
L
~
−
1
\{{y_l}\}_{l=1}^{\widetilde L-1}
{yl}l=1L
−1表示高频疾病,最后一个
y
L
~
y_{\widetilde L}
yL
是包含所有低频疾病的混合标签
4. Triplet Attention Network (TAN)
TAN 由四个可学习组件组成: 图像特征提取器(IFE)、标签嵌入提取器(LEE)、三重注意变换器(TAT)和标签预测分类器(LPC)。
- 图像特征提取器(IFE)将医学图像转换为图像空间特征和类别注意力特征**
- 标签嵌入提取器则为所有疾病标签生成与图像特征维度相同的初始标签嵌入
- 图像空间特征、类别注意特征和初始标签嵌入被送入 TAT
- 在 TAT 中,标签嵌入会得到类别注意特征的强化,然后通过自我注意和交叉注意来模拟全局依赖关系以及与图像空间特征的交互
- LPC 通过更新标签嵌入预测每种疾病的概率
4.1. Image Feature Extractor (IFE)
如上图所示,给定一幅医学图像 𝒙,IFE 通过卷积主干输出其深度特征
F
∈
R
h
×
w
×
d
F\in{R^{h\times w\times d}}
F∈Rh×w×d ,其中
h
,
w
,
d
{h, w, d}
h,w,d分别为深度特征的高度、宽度和通道。然后,两个尾分支根据
F
F
F产生图像空间特征
F
s
F_s
Fs和类别注意力特征
F
a
F_a
Fa。
-
Image spatial features F s F_s Fs
图像空间特征 F s F_s Fs保留了感兴趣的图像信息,包含较少的无关信息,可通过以下方法获得:
其中, w s ∈ R d × d w_s \in{ R^{d\times d}} ws∈Rd×d是点对点投影矩阵,而 R e s h a p e ( ⋅ ) Reshape(·) Reshape(⋅)操作将特征大小从 h × w × d h\times w\times d h×w×d变为 h w × d hw\times d hw×d。在 F s F_s Fs中,每个子特征 f s p ∈ R d f^p_s\in{R^d} fsp∈Rd,其中 p ∈ [ 1 , h w ] p \in{ [1,hw]} p∈[1,hw],可视为原始图像空间中一个空间局部区域的集中 -
Category attention features F a F_a Fa
类别关注特征可以表示图像特征对不同类别的重要性。我们首先使用点对点投影矩阵 w a ∈ R d × L ~ w_a\in{R^{d\times{\widetilde L}}} wa∈Rd×L 生成 F ˙ ∈ R h × w × L ~ \dot{F}\in{R^{h\times\ w\times {\widetilde L}}} F˙∈Rh× w×L ,然后应用定制的类别关注(CA)模块生成 F a F_a Fa:CA 的详细信息可参见算法 1。为确保图像注意力特征的有效性,我们直接使用全局平均池化(GAP)和 F ˙ \dot{F} F˙的sigmoid激活来进行中间标签预测。预测结果通过交叉熵(CE)损失与the ground truth进行测量,该损失将作为整体损失的一部分参与模型优化:
4.2. Label Embedding Extractor (LEE)
- 对于每幅医学图像,通过 LEE 为 L ~ \widetilde L L 疾病标签生成初始标签嵌入 E ∈ R L ~ × d = [ e 1 , e 2 , . . . , e L ~ ] E\in{R^{{\widetilde L}\times d}=[e_1, e_2,..., e_{\widetilde L}]} E∈RL ×d=[e1,e2,...,eL ],其中 e l ∈ R d e_l\in{R^d} el∈Rd是疾病-𝑙的标签嵌入。LEE 由 Pytorch 的自动模块 torch.nn.Embedding 实现。
- 所有标签都被分配了唯一的索引,所有索引都被转换成相应的标签嵌入。初始标签嵌入将与图像空间特征和类别注意力特征一起被发送到后续的 TAT 中,以便自我更新。
4.3. Triplet Attention Transformer (TAT)
- 一个完整的Transformer包含一个编码器模块和一个解码器模块。编解码器模块由多个相同架构的编解码器层组成。
- 每个编码器层包含自注意力层和前馈网络( FFN )
- 每个解码器层包含自注意力层、交叉注意力层和FFN。
- 在自注意力层中,通过三个不同的权重矩阵
w
q
w_q
wq,
w
k
w_k
wk和
w
v
w_v
wv将输入特征Z转化为查询特征Q,关键特征K和值特征V:
然后,每个子特征 q p ∈ Q q_{p}∈Q qp∈Q查询K中的所有子特征来计算注意力分数。最后,将注意力分数归一化并与V中相应的子特征相乘。该过程可以混合成一个单一的函数:
式中:d为特征维数。然而,在交叉注意力层中,三元组(Q、K、V)是由两个不同的输入特征 Z ( 1 ) Z ^{(1)} Z(1)和 Z ( 2 ) Z ^{(2)} Z(2)计算得到的:
下一步的注意力计算如( 6 ) 。与自注意力层使用查询特征来检索自身的关键特征不同,跨注意力层使用查询特征从另一个输入特征中检索关键特征。
TAT包括三种注意类型,即类别注意、自我注意和交叉注意,如图4所示。 - 首先通过类别注意力特征
F
a
F_a
Fa强化标签嵌入E,然后将增强后的标签嵌入与图像空间特征
F
s
F_s
Fs级联,通过编码器模块中的自注意力层来建模全局依赖关系:
4.4. Label Prediction Classifier (LPC)
使用一个独立的FFN来预测每个更新的标签嵌入
e
l
′′
∈
E
′′
e^{′′}_{l}∈E^{′′}
el′′∈E′′为正的概率:
5. Dual-pool contrastive learning
5.1. Dual-pool Contrastive Training (DCT)
学习混合标签的负标签嵌入和正标签嵌入之间的差异更为合理:
- 在抽样小批量中,与特定高频疾病相关的所有正标签嵌入都显示出较高的聚合度。然而,对于合并各种低频疾病的混合标签来说,学习聚合的正标签嵌入比较困难
- 分类器无法在训练过程中学习到未见疾病的高响应
如果样本 𝒙 的多标签ground truth Y = [ y 1 , y 2 , . . . , y L ~ ] Y=[y_1,y_2,...,y_{\widetilde L}] Y=[y1,y2,...,yL ]中的所有疾病标签都为 0,就把这个样本 𝒙 称为负样本,否则就把𝒙 称为正样本。即负样本表示健康样本,正样本表示患病样本。发现来自负样本的所有标签嵌入和来自正样本的大部分标签嵌入都是负标签嵌入。
- 将训练集分为负样本池(只包含负样本)和正样本池(只包含正样本)。在训练阶段,我们分别从两个样本池中随机抽取两个独立的小批量样本,然后将它们送入 TAN,得到它们的更新标签嵌入 E n e g a t i v e ′′ = [ e 1 ′ ′ , e 2 ′ ′ , . . . , e L ~ ′ ′ , ] E^{′′}_{negative}= [e^{''}_{1},e^{''}_{2},...,e^{''}_{{\widetilde L}},] Enegative′′=[e1′′,e2′′,...,eL ′′,]和 E p o s i t i v e ′′ = [ e 1 ′ ′ , e 2 ′ ′ , . . . , e L ~ ′ ′ ] E^{′′}_{positive}= [e^{''}_{1},e^{''}_{2},...,e^{''}_{{\widetilde L}}] Epositive′′=[e1′′,e2′′,...,eL ′′]
- 在高维特征空间中,所有的负标签嵌入应该更靠近,距离正标签嵌入应该更远,来自不同高频疾病的正标签嵌入之间也应该存在差异。
- 提出DCT来优化来自不同疾病标签的负标签嵌入和正标签嵌入的聚类中心,通过学习差异来更好地区分一个标签嵌入是负的还是正的,如下图所示。DCT由池间对比损失和池内对比损失组成。
-
池间对比损失:
度量了两个池之间来自同一疾病标签的标签嵌入的相似性。 -
池内对比损失:
池内对比损失保证了正样本池中来自不同疾病标签的正标签嵌入具有区分性。对于任意两个标签嵌 e i ′ ′ , e j ′ ′ ∈ E p o s i t i v e ′′ e^{''}_{i},e^{''}_{j}∈E^{′′}_{positive} ei′′,ej′′∈Epositive′′,如果它们都是负的标签嵌入,我们将它们拉近,否则我们将它们保持距离
池内对比损失可以写为:
5.2. Overall loss for model training
5.3. Dual-pool Contrastive Inference (DCI)
在模型推断阶段,提出DCI来进一步缓解正标签嵌入的错误分类,提高对未见疾病的检测能力。当输入测试医学图像x时,TAN输出其更新的标签嵌入
e
l
′
′
e^{''}_{l}
el′′和疾病的标签预测图
−
l
-l
−l,其中
l
∈
[
1
,
L
~
]
l∈[1,\widetilde L]
l∈[1,L
]。然后,从负样本池中随机选取m个样本,并将其送入TAN中,得到它们关于disease - l的标签嵌入
{
e
l
1
′
′
,
.
.
.
,
e
l
2
′
′
}
\{e^{''}_{l_{1}},...,e^{''}_{l_{2}}\}
{el1′′,...,el2′′}。
6. Experiments
6.1 Datasets
- ODIR
- 共纳入5000例患者的10000张彩色眼底图像,由不同图像分辨率的相机采集
- ODIR数据集中出现了7种眼病,分别是 “Diabetic Retinopathy (DR)”、“Glaucoma”、“Cataract”、“Age-related Macular Degeneration (AMD)”、“Hypertensive Retinopathy (HR)”、“Myopia” 和"Other Diseases"。ODIR使用了混合标签 “Other Diseases” 来存储除前六种以外的其他眼部疾病
- 将ODIR随机划分为训练集80%,验证集10%,测试集10%
- NIH-ChestXray14
- NIH-ChestXray14是由30805例患者的112120张正位X线图像组成的胸部X线图像数据集。
- 14种胸部疾病: 包括 ‘‘Atelectasis’’ , ‘‘Cardiomegaly’’, ‘‘Effusion’’, ‘‘Infiltration’’, ‘‘Mass’’, ‘‘Nodule’’, ‘‘Pneumonia’’, ‘‘Pneumothorax’’, ‘‘Consolidation’’, ‘‘Edema’’, ‘‘Emphysema’’, ‘‘Fibrosis’’, ‘‘Pleural Thickening’’ and ‘‘Hernia’’
- 按照官方数据拆分,训练集、验证集和测试集分别有78468、11219和22433个样本。
- 为了使NIH-ChestXray14能够适用于任务,重新设计了标签表示。将7种低频疾病( (‘‘Cardiomegaly’’, ‘‘Pneumonia’’, ‘‘Edema’’, ‘‘Emphysema’’, ‘‘Fibrosis’’, ‘‘Pleural-Thickening’’ and ‘‘Hernia’’)合并为一个混合标签"‘Other Diseases",并保留其他疾病的独立标签。疾病数量最终固定为8种。
6.2. Metrics and comparisons
- 基于ODIR和NIH-ChestXray14,比较了TA-DCL与其他先进MLC方法的性能,包括DBFocal、C - Tran、Q2L、CheXGCN、AnaXNet、MCG - Net和DRT。
- 采用了几种广泛使用的MLC评价指标:首先计算每个标签的平均精度( AP ),召回率( AR ),F1值( AF1 ),kappa ( AK ),然后找到它们的无权均值。准确率( ACC )是基于样本计算的,我们只认为如果样本中的所有标签都被正确分类,那么这就是正确的。
6.3. Implementation details
- 实验条件:实验是在两个NVIDIA TITAN Xp GPU的硬件条件和Python和Pytorch的软件条件下构建的。
- 模型设定。任何流行的卷积网络,包括Vgg、Xception、ResNet等,都适合作为图像特征提取器的主干。本文采用ResNet101作为主干网络,在ImageNet数据集上通过预训练初始化权重。所有比较先进的方法共享主干。由于ResNet101的输出维度为2048,将标签嵌入的大小d设置为2048。所有医学图像大小调整为640 × 640作为一致的模型输入。在Transformer编解码器模块中,编解码器层数设置为4。还在TAT中应用了多头模块来提高Transformer的鲁棒性,并将头数设置为8。随机负样本数设置为20。在损失函数中,温度因子τ为1,平衡权重λ为0.7。默认阈值为0.5用于标签分类。
- 优化设置。首先预训练仅有40个epochs的图像特征提取器,然后训练整个TA - DCL共计100 epochs。选择初始学习率为10-4,权重衰减为0.1的Adam优化器。DCT中健康样本和患病样本的batch-size为16。在训练过程中采用随机水平翻转进行数据增强。每一个epoch对模型进行验证,选择验证性能最高的epoch进行模型推断。
6.4. Overall evaluation on ODIR & NIH-ChestXray14
- 总体而言,在合理的实验设置下,TA - DCL始终显示出优于对比方法的结果。
- 在4个标签级指标==(AP、AR、AF1、AK)==中,认为AP和AR最能直观地反映性能的提升。更高的AR得分意味着更多的正标签嵌入被正确分类,从而实现更低的临床风险。较高的AP得分表明所有标签嵌入的总精度提高。
- TA - DCL的AR评分在ODIR和NIH - ChestXray14上比其他最好的结果分别提高了约4.8%和2.6%,AP评分在ODIR和NIH - ChestXray14上也分别提高了约2.5%和1.2%。TA - DCL在两个数据集上也取得了最高的样本分类准确率58.91%和64.82%,超过了其他最好的数据集。
6.5. Evaluation of unseen diseases
- 为了对未知疾病的正标签嵌入进行分类,其他MLC方法通过分类器权重来衡量其响应,并进行低相似度匹配。TA - DCL扩大了负标签嵌入和正标签嵌入之间的差距,并进一步与负标签嵌入进行比较,以减少正标签嵌入的错误分类。
- 理论上,TA - DCL比其他MLC方法具有更高的检测未知疾病的能力。基于去除特定高频疾病的训练集和验证集重新训练所有MLC方法,并观察测试集中这些未被发现的疾病是否会被正确分类到"其他疾病"中。
- 表格显示了所有方法在ODIR和NIH - ChestXray14上对未见疾病的平均评估结果。TA - DCL始终优于其他基于定量指标的MLC方法,表现出对未见疾病更好的适应性。
6.6. Ablation analysis
通过ablative experiments来分析可能影响我们模型性能的重要因素。所有ablation analysis均在ODIR和NIH - ChestXray14的验证集上进行。
6.6.1. Ablative evaluation of parameter 𝜆
方程中的参数λ .式( 15 )旨在控制标签预测损失和双池对比损失之间的平衡,对模型优化有重要影响。通过固定其他模型参数,评估了参数λ的不同取值范围,即λ∈{ 0.1,0.2,…,1.0 }。如下图所示,对于两个数据集,λ的最优值都设置为0.7。
6.6.2. Ablative evaluation of triplet attention
- 三元组注意力结合了类别注意力、自注意力和交叉注意力,从图像特征中学习高质量的标签嵌入。为了证明三元组注意力的优势,将三元组注意力替换为其他流行的注意力,包括位置注意力,通道注意力,自注意力和交叉注意力,同时去除DCT和DCI。
- 在ResNet101中分别加入位置注意力和通道注意力,直接输出多标签分类结果,无需进行标签嵌入学习。我们同样按照C - Tran 和Q2L分别进行自注意力评价和交叉注意力评价。如下表所示,三联体注意始终表现出比其他注意机制更好的量化结果。
6.6.3. Ablative evaluation of DCT & DCI
在模型训练过程中,提出DCT来捕捉标签嵌入之间的差异。DCT由池间对比损失 L i r c L_{irc} Lirc和池内对比损失 L i a c L_{iac} Liac组成:
- L i r c L_{irc} Lirc学习同一疾病标签的正标签嵌入和负标签嵌入之间的差异
-
L
i
a
c
L_{iac}
Liac学习不同疾病标签的正标签嵌入之间的差异。
在模型推断阶段提出了DCI,以进一步减少正标签嵌入的错误分类,提高对未见疾病的检测能力。由表5可知,DCT和DCI的使用提高了量化指标,两者的结合取得了最好的效果。
6.6.4. Ablative evaluation of label representation
低频疾病数量对实验的影响:14种疾病根据样本量从低到高排序,然后选择低频疾病以填充混合标签。令低频病害个数
L
l
f
=
{
1
,
3
,
5
,
7
}
L_{lf}=\{ 1,3,5,7\}
Llf={1,3,5,7},表6记录了低频病害个数的ablative results。当我们将更多的低频疾病合并到混合标签中时,TA - DCL的整体分类性能得到了提升。
6.6.5. Ablative evaluation of random negative samples
- DCI旨在通过与m个随机的负样本进行对比,正确地分类更多的正样本。通过调整m值从0到100在两个公共数据集上计算AR度量。
- 在ODIR和NIH - ChestXray14上,当m < 20和m < 15时,AR值保持持续增长。在这项工作中,对两个公共数据集统一设置m = 20。
6.6.6. Computation cost
提出的方法的计算成本:文章来源:https://www.toymoban.com/news/detail-768851.html
- 在模型推断阶段,TA - DCL包含TAN、LPC和DCI,而DCT不参与模型推断。在m = 20的单幅测试图像上,平均推理时间约为0.82 s。在DCI中选择更多的随机负样本进行对比会增加推断时间。例如,当m = 30时,平均推理时间约为1.07 s。
7. Discussion
与SLC相比,MLC可以使用单一模型检测多种疾病,可以节省更多的时间和资源。然而,通过监督学习为医学图像训练一个通用的MLC模型,==需要预先定义一套完整的标签来覆盖所有疾病,并为每个标签收集足够的样本用于模型训练。==事实上,两者在实际场景中都很难实现。==固定的标签表示限制了模型应用的灵活性,样本不平衡问题导致模型训练有偏。==根据与临床医生的意见交换,临床医生表示他们想要的是自动方法可以决定低频疾病是否存在于医学图像中,而不是修改其特定的疾病标签。因此,设计了一种针对低频疾病的混合标签来解决这些问题。在实际使用中,也可以将看不见的疾病放入混合标签中,提高有监督MLC模型的灵活性。
另一个不容忽视的临床关注点是降低误分类(尤其是紧急疾病)带来的风险。直接降低分类阈值并不是一个稳定的策略。当Transformer应用于医学图像的MLC时,其核心是学习具有判别性的标签嵌入,以提高每个疾病标签的二分类精度。基于广义Transformer的MLC方法在单个编码器模块或解码器模块中建模图像特征与标签嵌入之间的关系。提出的TAT引入类别注意力特征来强化标签嵌入的类别重要性,并通过自注意力和交叉注意力深度挖掘图像特征来更新标签嵌入。三种注意类型的组合表现出更好的标签嵌入学习能力,并得到了我们烧蚀研究的验证。
TA - DCL对看不见的疾病有明显的改善。有监督的MLC方法在训练阶段无法学习到未知疾病的有效信息。因此,在模型推断阶段,训练好的FFN分类器对未知疾病的反应很差,进而无法检测出它们。
因此,我们转而通过DCT来学习所有疾病标签的负标签嵌入和正标签嵌入之间的差异。在模型推断阶段,DCI度量测试标签嵌入与负标签嵌入集合之间的相似度得分,并将相似度得分与LPC的预测得分结合进行最终决策。在这个过程中,模型对检验标签嵌入为负的分类更为严格。差异学习与相似性度量相结合的方法对未见病症的检测更加灵活有效。
尽管在多标签分类任务中有更多的正样本被正确分类,但我们的TA - DCL仍然存在一定的局限性。文章来源地址https://www.toymoban.com/news/detail-768851.html
- 首先,DCI模块需要健康样本对未发现的疾病病例进行对比检测。因此,除了测试样本外,还会随机选取若干个负样本进行测试,导致计算成本的增加。这对于硬件条件来说将是一个额外的负担,因为它需要将健康的样本与训练好的模型一起部署。一个潜在的解决方案是,所有的负样本都可以提前通过TAN模块转换为语义嵌入。我们可以使用训练好的模型来部署语义嵌入,以减轻硬件负担并减少负样本的计算代价。此外,我们还可以对包含多个测试医学图像的测试批进行采样操作,而不是对单个测试医学图像进行采样操作。其次,DCL模块旨在提高正样本的分类性能,负样本更容易被误分类为正样本。在未来的工作中,我们将通过深度挖掘图像信息和优化特征聚类策略,进一步提高模型从医学图像中学习判别性标签嵌入的能力。
- 第三,我们探究了合并低频标签对多标签学习性能的影响,但低频疾病的选择完全依赖于自动化统计。引入专家知识进行低频疾病的选择更具有说服力。在未来的工作中,我们将引入专家知识来提高多标签分类模型的交互性和可解释性。
到了这里,关于【论文笔记】Triplet attention and dual-pool contrastive learning for clinic-driven multi-label medical...的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!