【论文精读】TransE 及其实现

这篇具有很好参考价值的文章主要介绍了【论文精读】TransE 及其实现。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

TransE 及其实现

1. What is TransE?

TransE (Translating Embedding), an energy-based model for learning low-dimensional embeddings of entities.

核心思想:将 relationship 视为一个在 embedding space 的 translation。如果 (h, l, t) 存在,那么 h + l ≈ t h + l \approx t h+lt

Motivation:一是在 Knowledge Base 中,层次化的关系是非常常见的,translation 是一种很自然的用来表示它们的变换;二是近期一些从 text 中学习 word embedding 的研究发现,一些不同类型的实体之间的 1-to-1 的 relationship 可以被 model 表示为在 embedding space 中的一种 translation。

2. Learning TransE

TransE 的训练算法如下:

【论文精读】TransE 及其实现

2.1 输入参数

  • training set S S S:用于训练的三元组的集合,entity 的集合为 E E E,rel. 的集合为 L L L
  • margin γ \gamma γ:损失函数中的间隔,这个在原 paper 中描述很模糊
  • 每个 entity 或 rel. 的 embedding dim k k k

2.2 训练过程

初始化:对每一个 entity 和 rel. 的 embedding vector 用 xavier_uniform 分布来初始化,然后对它们实施 L1 or L2 正则化。

loop

  • 在 entity embedding 被更新前进行一次归一化,这是通过人为增加 embedding 的 norm 来防止 loss 在训练过程中极小化。
  • sample 出一个 mini-batch 的正样本集合 S b a t c h S_{batch} Sbatch
  • T b a t c h T_{batch} Tbatch 初始化为空集,它表示本次 loop 用于训练 model 的数据集
  • for ( h , l , t ) ∈ S b a t c h (h,l,t) \in S_{batch} (h,l,t)Sbatch do:
    • 根据 (h, l, t) 构造出一个错误的三元组 ( h ′ , l , t ′ ) (h', l, t') (h,l,t)
    • 将 positive sample ( h , l , t ) (h,l,t) (h,l,t) 和 negative sample ( h ′ , l , t ′ ) (h',l,t') (h,l,t) 加入到 T b a t c h T_{batch} Tbatch
  • 计算 T b a t c h T_{batch} Tbatch 每一对 positive sample 和 negative sample 的 loss,然后累加起来用于更新 embedding matrix。每一对的 loss 计算方式为: l o s s = [ γ + d ( h + l , t ) − d ( h ′ + l , t ′ ) ] + loss = [\gamma + d(h+l,t) - d(h'+l,t')]_+ loss=[γ+d(h+l,t)d(h+l,t)]+

这个过程中,triplet 的 energy 就是指的 d ( h + l , t ) d(h+l,t) d(h+l,t),它衡量了 h + l h+l h+l t t t 的距离,可以采用 L1 或 L2 norm,即 ∣ ∣ h + r − t ∣ ∣ ||h + r - t|| ∣∣h+rt∣∣ 具体计算方式可见代码实现。

loss 的计算中, [ x ] + = max ⁡ ( 0 , x ) [x]_+ = \max(0,x) [x]+=max(0,x)

关于 margin γ \gamma γ 的含义, 它相当于是一个正确 triple 与错误 triple 之前的间隔修正,margin 越大,则两个 triple 之前被修正的间隔就越大,则对于 embedding 的修正就越严格。我们看 l o s s = [ γ + d ( h + l , t ) − d ( h ′ + l , t ′ ) ] + loss = [\gamma + d(h+l,t) - d(h'+l,t')]_+ loss=[γ+d(h+l,t)d(h+l,t)]+,我们希望是 d ( h + l , t ) d(h+l,t) d(h+l,t) 越小越好, d ( h ′ + l , t ′ ) d(h'+l,t') d(h+l,t) 越大越好,假设 d ( h + l , t ) d(h+l,t) d(h+l,t) 处于理想情况下等于 0,那么由于 γ \gamma γ 的存在, d ( h ′ + l , t ′ ) d(h'+l,t') d(h+l,t) 如果不是很大的话,仍然会产生 loss,只有当 d ( h ′ + l , t ′ ) d(h'+l,t') d(h+l,t) 大于 γ \gamma γ 时才会让 loss = 0,所以 γ \gamma γ 越大,对 embedding 的修正就越严格。

错误三元组的构造方法:将 ( h , l , t ) (h,l,t) (h,l,t) 中的头实体、关系和尾实体其中之一随机替换为其他实体或关系来得到。

2.3 评价指标

链接预测是用来预测三元组 (h,r,t) 中缺失实体 h, t 或 r 的任务,对于每一个缺失的实体,模型将被要求用所有的知识图谱中的实体作为候选项进行计算,并进行排名,而不是单纯给出一个最优的预测结果。

  1. Mean rank - 正确三元组在测试样本中的得分排名,越小越好

首先对于每个 testing triple,以预测 tail entity 为例,我们将 ( h , r , t ) (h,r,t) (h,r,t) 中的 t 用 KG 中的每个 entity 来代替,然后通过 f r ( h , t ) f_r(h,t) fr(h,t) 来计算分数,这样就可以得到一系列的分数,然后将这些分数排列。我们知道 f 函数值越小越好,那么在前面的排列中,排地越靠前越好。重点来了,我们去看每个 testing triple 中正确答案(也就是真实的 t)在上述序列中排多少位,比如 t 1 t_1 t1 排 100, t 2 t_2 t2 排 200, t 3 t_3 t3 排 60 …,之后对这些排名求平均,就得到 mean rank 值了。

  1. Hits@10 - 得分排名前 n 名的三元组中,正确三元组的占比,越大越好

还是按照上述进行 f 函数值排列,然后看每个 testing triple 正确答案是否排在序列的前十,如果在的话就计数 +1,最终 (排在前十的个数) / (总个数) 就等于 Hits@10。

在原论文中,由于这个 model 比较老了,其 baseline 也没啥参考性,就不做研究了,具体的实验可参考论文。

3. TransE 优缺点

优点:与以往模型相比,TransE 模型参数较少,计算复杂度低,却能直接建立实体和关系之间的复杂语义联系,在 WordNet 和 Freebase 等 dataset 上较以往模型的 performance 有了显著提升,特别是在大规模稀疏 KG 上,TransE 的性能尤其惊人。

缺点:在处理复杂关系(1-N、N-1 和 N-N)时,性能显著降低,这与 TransE 的模型假设有密切关系。假设有 (美国,总统,奥巴马)和(美国,总统,布什),这里的“总统”关系是典型的 1-N 的复杂关系,如果用 TransE 对其进行学习,则会有:

【论文精读】TransE 及其实现
那么这将会使奥巴马和布什的 vector 变得相同。所以由于这些复杂关系的存在,导致 TransE 学习得到的实体表示区分性较低。

4. TransE 实现

这里选择用 pytorch 来实现 TransE 模型。

4.1 __init__ 函数

其参数有:

  • ent_num:entity 的数量
  • rel_num:relationship 的数量
  • dim:每个 embedding vector 的维度
  • norm:在计算 d ( h + l , t ) d(h+l,t) d(h+l,t) 时是使用 L1 norm 还是 L2 norm,即 d ( h + l , t ) = ∣ ∣ h + l − t ∣ ∣ L 1   o r   L 2 d(h+l,t)=||h+l-t||_{L1 \ or \ L2} d(h+l,t)=∣∣h+ltL1 or L2
  • margin:损失函数中的间隔,是个 hyper-parameter
  • α \alpha α:损失函数计算中的正则化项参数
class TransE(nn.Module):
    def __init__(self, ent_num, rel_num, device, dim=100, norm=1, margin=2.0, alpha=0.01):
        super(TransE, self).__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.dim = dim
        self.norm = norm # 使用L1范数还是L2范数
        self.margin = margin
        self.alpha = alpha

        # 初始化实体和关系表示向量
        self.ent_embeddings = nn.Embedding(self.ent_num, self.dim)
        torch.nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        self.ent_embeddings.weight.data = F.normalize(self.ent_embeddings.weight.data, 2, 1)

        self.rel_embeddings = nn.Embedding(self.rel_num, self.dim)
        torch.nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        self.rel_embeddings.weight.data = F.normalize(self.rel_embeddings.weight.data, 2, 1)

        # 损失函数
        self.criterion = nn.MarginRankingLoss(margin=self.margin)

初始化 embedding matrix 时,直接用 nn.Embedding 来完成,参数分别是 entity 的数量和每个 embedding vector 的维数,这样得到的就是一个 ent_num * dim 大小的 Embedding Matrix。

torch.nn.init.xavier_uniform_ 是一个服从均匀分布的 Glorot 初始化器,在这里做的就是对 Embedding Matrix 中每个位置填充一个 xavier_uniform 初始化的值,这些值从均匀分布 U ( − a , a ) U(-a,a) U(a,a) 中采样得到,这里的 a a a 是:

a = g a i n × 6 f a n _ i n + f a n _ o u t a = gain \times \sqrt{\frac{6}{fan\_in + fan\_out}} a=gain×fan_in+fan_out6

在这里,对于 Embedding 这样的二维矩阵来说,fan_in 和 fan_out 就是矩阵的长和宽,gain 默认为 1。其完整具体行为可参考 pytorch 初始化器文档。

F.normalize(self.ent_embeddings.weight.data, 2, 1) 这一步就是对 ent_embeddings 的每一个值除以 dim = 1 上的 2 范数值,注意 ent_embeddings.weight.data 的 size 是 (ent_num, embs_dim)。具体来说就是这一步把每行都除以该行下所有元素平方和的开方,也就是 l ← l / ∣ ∣ l ∣ ∣ l \leftarrow l / ||l|| ll/∣∣l∣∣

损失函数这里先跳过,之后计算损失的步骤一同来看。

4.2 从 ent_idx 到 ent_embs

由于 network 的输入是 ent_idx,因此需要将其根据 embedding matrix 转换成 ent_embs。我们通过 get_ent_resps 函数来完成,其实就是个静态查表的操作:

class TransE(nn.Module):
	...
	def get_ent_resps(self, ent_idx): #[batch]
        return self.ent_embeddings(ent_idx) # [batch, emb]

4.3 计算 energy d ( h + l , t ) d(h+l, t) d(h+l,t)

它衡量了 h + l h+l h+l t t t 的距离,可以采用 L1 或 L2 norm 来算,具体采用哪个由 __init__ 函数中的 self.norm 来决定:

class TransE(nn.Module):
	...
	def distance(self, h_idx, r_idx, t_idx):
        h_embs = self.ent_embeddings(h_idx) # [batch, emb]
        r_embs = self.rel_embeddings(r_idx) # [batch, emb]
        t_embs = self.ent_embeddings(t_idx) # [batch, emb]
        scores = h_embs + r_embs - t_embs
		
		# norm 是计算 loss 时的正则化项
        norms = (torch.mean(h_embs.norm(p=self.norm, dim=1) - 1.0)
                 + torch.mean(r_embs ** 2) +
                 torch.mean(t_embs.norm(p=self.norm, dim=1) - 1.0)) / 3

        return scores.norm(p=self.norm, dim=1), norms

4.4 计算 loss

self.criterion 是通过实例化 MarginRankingLoss 得到的,这个类的初始化接收 margin 参数,实例化得到 self.criterion,其计算方式如下:

c r i t e r i o n ( x 1 , x 2 , y ) = max ⁡ ( 0 , − y × ( x 1 − x 2 ) + m a r g i n ) criterion(x_1,x_2,y) = \max(0, -y \times (x_1 - x_2) + margin) criterion(x1,x2,y)=max(0,y×(x1x2)+margin)

借助于此,我们可以实现计算 loss 的代码:

class TransE(nn.Module):
	...
	def loss(self, positive_distances, negative_distances):
        target = torch.tensor([-1], dtype=torch.float, device=self.device)
        return self.criterion(positive_distances, negative_distances, target)

positive_distances 就是 d ( h + l , t ) d(h+l,t) d(h+l,t),negative_distances 就是 d ( h ′ + l , t ′ ) d(h'+l, t') d(h+l,t),target = [-1],代入 criterion 的计算公式就是我们计算 一对正样本和负样本的 loss 了。

4.5 forward

class TransE(nn.Module):
	...
	def forward(self, ph_idx, pr_idx, pt_idx, nh_idx, nr_idx, nt_idx):
        pos_distances, pos_norms = self.scoring(ph_idx, pr_idx, pt_idx)
        neg_distances, neg_norms = self.scoring(nh_idx, nr_idx, nt_idx)

        tmp_loss = self.loss(pos_distances, neg_distances)
        tmp_loss += self.alpha * pos_norms   # 正则化项
        tmp_loss += self.alpha * neg_norms   # 正则化项

        return tmp_loss, pos_distances, neg_distances

以上我们讲完了 TransE 模型的定义,接下来就是讲对 TransE 模型的训练了,只要理解了 TransE 模型的定义,其训练应该不是难事。文章来源地址https://www.toymoban.com/news/detail-415948.html

到了这里,关于【论文精读】TransE 及其实现的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Transformer模型原理—论文精读

    今天来看一下Transformer模型,由Google团队提出,论文名为《Attention Is All You Need》。论文地址。 正如标题所说的,注意力是你所需要的一切,该模型摒弃了传统的RNN和CNN结构,网络结构几乎由Attention机制构成,该论文的亮点在于提出了 Multi-head attention 机制,其又包含了 self-a

    2024年02月08日
    浏览(59)
  • 【论文精读】ESViT

           基于transformer的SSL方法在ImageNet线性检测任务上取得了最先进的性能,其关键原因在于使用了基于对比学习方法训练单尺度Transformer架构。尽管其简单有效,但现有的基于transformer的SSL(自监督学习)方法需要大量的计算资源才能达到SoTA性能。        故认为SSL系统的

    2024年02月20日
    浏览(22)
  • 论文精读--Autoformer

    标题:Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting 作者:Haixu Wu, Jiehui Xu, Jianmin Wang, Mingsheng Long(Tsinghua University) 发表刊物:NeurIPS 2021 论文下载地址:https://arxiv.org/abs/2106.13008 作者代码地址:GitHub - thuml/Autoformer: About Code release for \\\"Autoformer: Decompo

    2024年02月09日
    浏览(23)
  • 【论文精读】BERT

           以往的预训练语言表示应用于下游任务时的策略有基于特征和微调两种。其中基于特征的方法如ELMo使用基于上下文的预训练词嵌入拼接特定于任务的架构;基于微调的方法如GPT使用未标记的文本进行预训练,并针对有监督的下游任务进行微调。        但上述两种策

    2024年02月19日
    浏览(36)
  • 论文精读--MAE

    BERT在Transformer的架构上进行了掩码操作,取得了很好的效果。如果对ViT进行掩码操作呢? 分成patch后灰色表示遮盖住,再将可见的patch输入encoder,把encoder得到的特征拉长放回原本在图片中的位置,最后由decoder去重构图片  图二的图片来自ImageNet,没有经过训练,是验证集。左

    2024年02月21日
    浏览(38)
  • 论文精读之BERT

    目录 1.摘要(Abstract) 2.引言(Introduction): 3.结论(Conlusion): 4.BERT模型算法: 5.总结 与别的文章的区别是什么:BERT是用来设计去训练深的 双向的 表示,使用没有标号的数据,再联合左右的上下文信息。(改进在什么地方) 效果有多好:在11个NLP任务上取得了很好的效果。需要

    2024年02月15日
    浏览(34)
  • BERT 论文精读与理解

    1.论文题目 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 2.论文摘要 本文引入了一种名为 BERT 的新语言表示模型,它代表 Transformers 的双向编码器表示。与最近的语言表示模型(Peters et al., 2018a;Radford et al., 2018)不同,BERT 旨在通过联合调节所有层中的左右上

    2024年02月13日
    浏览(42)
  • 【论文精读】NeRF详解

    最近阅读了开启三维重建新纪元的经典文章《NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis》,接下来会​更新NeRF 系列 的论文精读、代码详解,力求做到全网最细!欢迎大家关注和交流! 论文链接:论文 代码链接:Github (这是官方代码,是tensorflow版本) 文章提出

    2024年02月05日
    浏览(40)
  • ViT 论文逐段精读

    https://www.bilibili.com/video/BV15P4y137jb Vision Transformer 挑战了 CNN 在 CV 中绝对的统治地位。Vision Transformer 得出的结论是如果在足够多的数据上做预训练,在不依赖 CNN 的基础上,直接用自然语言上的 Transformer 也能 CV 问题解决得很好。Transformer 打破了 CV、NLP 之间的壁垒。 先理解题目

    2024年02月05日
    浏览(36)
  • 深度学习论文精读[7]:nnUNet

    相较于常规的自然图像,以UNet为代表的编解码网络在医学图像分割中应用更为广泛。常见的各类医学成像方式,包括计算机断层扫描(Computed Tomography, CT)、核磁共振成像(Magnetic Resonance Imaging, MRI)、超声成像(Ultrasound Imaging)、X光成像(X-ray Imaging)和光学相干断层扫描(

    2024年02月05日
    浏览(98)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包