论文地址:https://www.nature.com/articles/s41551-023-01045-x
代码地址:https://github.com/RL4M/IRENE
基于Transformer的表示学习模型,作为临床诊断辅助工具,以统一的方式处理多模态输入。将图像与文字转化为visual tokens和text tokens,通过一个双向的跨模态注意力机制块共同学习不同信息间的整体特征和其关联性来做出决策。
第一个以统一方式使用人工智能处理多模态信息,在临床上辅助医生进行决策诊断。为后续医学领域人工智能处理多模态信息提供一种新的思路。
Data
胸腔医学中,除了胸部X射线,医生还需要考虑患者的人口统计学信息(如年龄和性别)、主诉(如现病史和既往病史)以及实验室检查报告,以便做出准确的诊断决策。实际上,医生会首先将异常的放射学图像模式与主诉中提到的症状或实验室检查报告中的异常结果相关联。然后,医生依靠他们丰富的领域知识和多年的培训,通过共同解释这些多模态数据来做出最佳诊断。
在医学临床领域,常见的数据类型有三种:
图像(Radiograph)
主诉(Chief complaint):非结构化信息,现病史和既往病史
人口统计信息和实验室检查结果(Demographicsand lab test results):结构化信息,性别年龄等
当前的临床辅助决策系统,常采用非同一的方式。首先将非结构化的主诉转化成结构化的数据,然后将不同模态的数据输入到不同的机器学习模块中,产生特定模态的特征。最后使用融合模块对这些特征进行融合。
但是这样做的一个问题是,特定模态模型的训练与融合过程相分离,导致不能获取不同模态之间的联系与关联。
本文提出的 IRENE 共同学习图像、非结构化主诉和结构化临床信息的整体表示来进行决策。
Network structure
IRENE由嵌入层、两个多模态注意力块、10个自注意力块和一个输出层组成。
- free-form embedding 将非结构化与结构化的文字转化成text tokens, image embedding将图像转化成image tokens
-
bidirectional multimodal blocks 不仅计算同一模态内部的注意力,还计算不同模态之间的注意力
bidirectional 指的是text tokens要与image token做注意力,同时image token也要与text token做注意力 - Self-attention blocks 经过两个双向的多模态注意力块后,连接text tokens与image tokens,然后进行自注意力计算
-
图像:图片经过一个卷积层
-
文字:
主诉(ChiComp):经过bert获得token_id, max_len设为40
实验室检查结果(LabTest):经过bert获得token_id max_len设为92
人口统计信息(Sex、Age):经过一个Liner获得 长度为1
最终得到一个长度为40+92+1+1的一维向量,再送入到嵌入层
3. 双向多模态注意力块
自注意力:输入序列经过一个线性映射,得到K,Q,V (n, d)
①Q与K相乘,计算相似度,得到权重分布 (n, d) * (d, n) = (n,n)
②权重分布经过softmax进行归一化
③权重分布与V相乘,加权求和 (n, n) * (n,d) = (n, d)
经过自注意力机制,可以捕捉到输入序列中不同位置之间的关系和依赖
现在有两个K,Q,V:
text tokens的 KT, QT, VT
image tokens的KI, QI, VI
所以:
QI与KI, VI计算注意力捕捉图像之间的依赖关系, QI与KT, VT计算注意力捕捉图像与文本之间的依赖关系
QT与KT, VT计算注意力捕捉文本之间的依赖关系, QT与KI, VI计算注意力捕捉文本与图像之间的依赖关系
在本文中,取λ为1。在第一层双向多模态注意力层中Xi 和Xt分别取平均送入第二层双向多模态注意力块,
在第二层双向多模态注意力层中Xi 和Xt分别取平均后,进行拼接,送入自注意力块中
实验
对比实验:
- 只使用图像
- 非统一方式早期融合
- 非统一方式晚期融合
- 多模态模型GIT: 在大量的图像-文本对中训练 问题:临床医学数据难获取
- 多模态模型Perceiver: 将不同模态的数据进行拼接作为输入。 问题:某个模态数据较少时,关注度低
消融实验: - 不加入双向多模态注意力
- 加入单向多模态注意力
- 加入6层双向多模态注意力
- 不加入主诉
- 不加入实验室检查结果
- 不加入图像数据
代码解读
整体架构
class IRENE(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(IRENE, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size, vis)
self.head = Linear(config.hidden_size, num_classes)
# 经过一个自定义的Transformer,再经过一个Linear
def forward(self, x, cc=None, lab=None, sex=None, age=None, labels=None):
x, attn_weights = self.transformer(x, cc, lab, sex, age)
logits = self.head(torch.mean(x, dim=1))
if labels is not None:
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.float())
return loss
else:
return logits, attn_weights, torch.mean(x, dim=1)
自定义的Transformes:Embeddings和Encoder组成
class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)
def forward(self, input_ids, cc=None, lab=None, sex=None, age=None):
embedding_output, cc, lab, sex, age = self.embeddings(input_ids, cc, lab, sex, age)
text = torch.cat((cc, lab, sex, age), 1)
encoded, attn_weights = self.encoder(embedding_output, text)
return encoded, attn_weights
Encoder文章来源:https://www.toymoban.com/news/detail-732580.html
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
for i in range(config.transformer["num_layers"]):
if i < 2: # 两个双向多模态注意力
layer = Block(config, vis, mm=True)
else: # 自注意力
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))
def forward(self, hidden_states, text=None):
attn_weights = []
for (i, layer_block) in enumerate(self.layer):
if i == 2: #在第二个双向多模态注意力块后,拼接img与text,送入自注意力块
hidden_states = torch.cat((hidden_states, text), 1)
hidden_states, weights = layer_block(hidden_states)
elif i < 2: # hidden_states:img
hidden_states, text, weights = layer_block(hidden_states, text)
else:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights
双向多模态注意力文章来源地址https://www.toymoban.com/news/detail-732580.html
# img-img
'''
需要计算四个注意力:
text-text text-img img-img img-text
'''
attention_scores_img = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores_img = attention_scores_img / math.sqrt(self.attention_head_size)
attention_probs_img = self.softmax(attention_scores_img)
weights = attention_probs_img if self.vis else None
attention_probs_img = self.attn_dropout(attention_probs_img)
context_layer_img = torch.matmul(attention_probs_img, value_layer_img)
context_layer_img = context_layer_img.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_img.size()[:-2] + (self.all_head_size,)
context_layer_img = context_layer_img.view(*new_context_layer_shape)
# text-text
attention_scores_text = torch.matmul(query_layer_text, key_layer_text.transpose(-1, -2))
attention_scores_text = attention_scores_text / math.sqrt(self.attention_head_size)
attention_probs_text = self.softmax(attention_scores_text)
attention_probs_text = self.attn_dropout_text(attention_probs_text)
context_layer_text = torch.matmul(attention_probs_text, value_layer_text)
context_layer_text = context_layer_text.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_text.size()[:-2] + (self.all_head_size,)
context_layer_text = context_layer_text.view(*new_context_layer_shape)
# img-text
attention_scores_it = torch.matmul(query_layer_img, key_layer_text.transpose(-1, -2))
attention_scores_it = attention_scores_it / math.sqrt(self.attention_head_size)
attention_probs_it = self.softmax(attention_scores_it)
attention_probs_it = self.attn_dropout_it(attention_probs_it)
context_layer_it = torch.matmul(attention_probs_it, value_layer_text)
context_layer_it = context_layer_it.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_it.size()[:-2] + (self.all_head_size,)
context_layer_it = context_layer_it.view(*new_context_layer_shape)
# text-img
attention_scores_ti = torch.matmul(query_layer_text, key_layer_img.transpose(-1, -2))
attention_scores_ti = attention_scores_ti / math.sqrt(self.attention_head_size)
attention_probs_ti = self.softmax(attention_scores_ti)
attention_probs_ti = self.attn_dropout_ti(attention_probs_ti)
context_layer_ti = torch.matmul(attention_probs_ti, value_layer_img)
context_layer_ti = context_layer_ti.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_ti.size()[:-2] + (self.all_head_size,)
context_layer_ti = context_layer_ti.view(*new_context_layer_shape)
# img-img 与 img-text取平均
attention_output_img = self.out((context_layer_img + context_layer_it)/2)
# text-text 与 text-img取平均
attention_output_text = self.out((context_layer_text + context_layer_ti)/2)
attention_output_img = self.proj_dropout(attention_output_img)
attention_output_text = self.proj_dropout_text(attention_output_text)
return attention_output_img, attention_output_text, weights
到了这里,关于好文推荐 A transformer-based representation-learning model with unified processing of multimodal input的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!