KGAT: Knowledge Graph Attention Network for Recommendation


[1905.07854] KGAT: Knowledge Graph Attention Network for Recommendation (arxiv.org)

LunaBlack/KGAT-pytorch (github.com)

Contents

1. Background

2. Task Definition

3. Model

3.1 Embedding layer

3.2 Attentive Embedding Propagation Layers

3.3 Model Prediction

3.4 Optimization

4. Code Walkthrough

4.1 Datasets

4.2 Data Processing

4.3 Model

4.4 Model Training


1. Background

CF methods recommend based on similar users or similar items, and cannot exploit attributes or other kinds of side information. For example, because u1 and u2 are similar users, i2 may be recommended to u1.


Feature-based supervised-learning (SL) models such as FM, NFM, and Wide&Deep can use side information: since i1 and i2 share attribute e1, i2 is recommended.


However, feature-based SL models treat each instance independently, without modeling the interactions between instances, so they cannot extract useful information from users' collective behavior — for example, u1 is hard to model from its own records alone, even though e1 serves as the bridge connecting the director and actor fields. The authors therefore argue that these methods do not fully explore high-order connectivity and fail to exploit compositional high-order relations.

To address the limitations of feature-based SL models, the knowledge graph can be combined with the user-item graph into a hybrid structure called the collaborative knowledge graph (CKG). Two challenges follow: 1) the number of nodes that have high-order relations with a target user grows dramatically as the order increases, imposing a heavy computational load on the model; 2) high-order relations contribute unequally to the prediction, so the model must carefully weight (or select) them.

Several families of methods already perform recommendation over a CKG:

(1) Path-based methods

These methods extract paths that carry high-order information and feed them into the prediction model. However, the path selection in the first stage has a large impact on performance, and defining effective meta-paths requires domain knowledge, making the approach labor-intensive.

(2) Regularization-based methods

These methods add an extra loss term that captures the KG structure to regularize the recommendation model, jointly training the recommendation task and the knowledge graph completion (KGC) task with a shared item embedding. Rather than directly plugging high-order relations into the model optimized for recommendation, they encode them only implicitly. Due to the lack of explicit modeling, neither are long-range connectivities guaranteed to be captured, nor are the results of high-order modeling interpretable.

Given these limitations, the authors develop a model that exploits the high-order information in a KG in an efficient, explicit, and end-to-end manner.

2. Task Definition

User-Item Bipartite Graph: in recommendation scenarios we typically have historical user-item interactions (e.g., purchases and clicks), represented as a user-item bipartite graph:

$$ \mathcal{G}_1 = \{(u, y_{ui}, i) \mid u \in \mathcal{U}, i \in \mathcal{I}\} $$ where $y_{ui} = 1$ indicates an observed interaction between user u and item i, and 0 otherwise.

Knowledge Graph: besides the user-item interactions, there is side information for the items (e.g., item attributes and external knowledge). Such auxiliary data typically consists of real-world entities and the relations among them. The authors organize the side information into a knowledge graph:

$$ \mathcal{G}_2 = \{(h, r, t) \mid h, t \in \mathcal{E}, r \in \mathcal{R}\} $$

together with a set of item-entity alignments:

$$ \mathcal{A} = \{(i, e) \mid i \in \mathcal{I}, e \in \mathcal{E}\} $$ where (i, e) means item i can be aligned with entity e in the KG.

Collaborative Knowledge Graph: the CKG encodes user behaviors and item knowledge as one unified relational graph. Each user behavior is first represented as a triple (u, Interact, i), where y_ui = 1 is expressed as an additional relation Interact between user u and item i. Then, based on the item-entity alignment set, the user-item graph is seamlessly integrated with the KG into a unified graph:

$$ \mathcal{G} = \{(h, r, t) \mid h, t \in \mathcal{E}', r \in \mathcal{R}'\}, \quad \mathcal{E}' = \mathcal{E} \cup \mathcal{U}, \quad \mathcal{R}' = \mathcal{R} \cup \{\text{Interact}\} $$

The recommendation task solved in this paper is formulated as follows:

• Input: the collaborative knowledge graph G, consisting of the user-item bipartite graph G1 and the knowledge graph G2.

• Output: a prediction function ŷ(u, i) = F(u, i | Θ, G) that predicts the probability that user u will adopt item i.

3. Model

The key design is that the user-item interactions are merged into the KG and processed jointly there.

(Figure: the KGAT framework — an embedding layer, stacked attentive embedding propagation layers, and a prediction layer.)

3.1 Embedding layer

TransR is used for the KG embedding: the head entity e_h and tail entity e_t are first projected into the relation's space by the relation-specific projection matrix W_r, and a plausibility score is then computed for the triple (the closer the projected head plus the relation vector is to the projected tail, the more plausible the triple; smaller g is better):

$$ g(h, r, t) = \lVert W_r e_h + e_r - W_r e_t \rVert_2^2 \qquad \text{(Equation 1)} $$

The embedding loss takes a pairwise ranking (contrastive) form: $$ \mathcal{L}_{KG} = \sum_{(h,r,t,t')} -\ln \sigma\big(g(h,r,t') - g(h,r,t)\big) \qquad \text{(Equation 2)} $$ where (h, r, t') is a broken triple whose tail t' is sampled randomly, and σ is the sigmoid function.
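As a sanity check, here is a minimal PyTorch sketch of Equations (1)-(2); the dimensions and names are purely illustrative, and the repository's real implementation is calc_kg_loss in section 4.3:

import torch
import torch.nn.functional as F

def transr_score(e_h, e_r, e_t, W_r):
    # g(h, r, t) = ||W_r e_h + e_r - W_r e_t||_2^2 ; smaller means more plausible
    return ((e_h @ W_r + e_r - e_t @ W_r) ** 2).sum(-1)

e_h, e_t, e_neg = torch.randn(8), torch.randn(8), torch.randn(8)   # entity dim 8
e_r, W_r = torch.randn(4), torch.randn(8, 4)                       # relation dim 4

pos = transr_score(e_h, e_r, e_t, W_r)      # observed triple
neg = transr_score(e_h, e_r, e_neg, W_r)    # corrupted tail
kg_loss = -F.logsigmoid(neg - pos)          # Equation (2): push pos below neg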

3.2 Attentive Embedding Propagation Layers

(1) Attention weights

tanh is chosen as the nonlinear activation. This makes the attention score depend on the distance between e_h and e_t in the space of relation r, so that closer entities propagate more information; for simplicity, only the inner product is used:

$$ \pi(h, r, t) = (W_r e_t)^\top \tanh(W_r e_h + e_r) \qquad \text{(Equation 4)} $$

followed by a softmax normalization over all triples connected with h:

$$ \tilde{\pi}(h, r, t) = \frac{\exp \pi(h, r, t)}{\sum_{(h, r', t') \in \mathcal{N}_h} \exp \pi(h, r', t')} \qquad \text{(Equation 5)} $$

The final attention scores suggest which neighbor nodes should receive more attention in order to capture the collaborative signal. During forward propagation, the attention flow hints at which parts of the data to focus on, which can be viewed as an explanation behind the recommendation.
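A small sketch of Equations (4)-(5) for a single head entity with three neighbors (toy values; the real implementation is update_attention_batch in section 4.3):

import torch

r_mul_h = torch.randn(4)       # W_r e_h, already projected into relation-r space
r_mul_t = torch.randn(3, 4)    # W_r e_t, one row per tail entity in N_h
e_r = torch.randn(4)

pi = (r_mul_t * torch.tanh(r_mul_h + e_r)).sum(dim=1)   # Equation (4), shape (3,)
attn = torch.softmax(pi, dim=0)                         # Equation (5), sums to 1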

(2) Message passing

To characterize the first-order connectivity structure of entity h, a linear combination over h's ego network is computed (the ego network is the set of triples in which h is the head entity):

$$ e_{\mathcal{N}_h} = \sum_{(h, r, t) \in \mathcal{N}_h} \tilde{\pi}(h, r, t)\, e_t \qquad \text{(Equation 3)} $$

(3) Aggregation

The final phase aggregates the entity's own representation e_h and its ego-network representation e_{N_h} into the new representation of entity h. More formally:

  • GCN aggregator: sum the two representation vectors, then apply a nonlinear activation:

$$ f_{\mathrm{GCN}} = \mathrm{LeakyReLU}\big(W(e_h + e_{\mathcal{N}_h})\big) \qquad \text{(Equation 6)} $$

  • GraphSage aggregator: concatenate the two representation vectors, then apply a nonlinearity:

$$ f_{\mathrm{GraphSage}} = \mathrm{LeakyReLU}\big(W(e_h \,\Vert\, e_{\mathcal{N}_h})\big) \qquad \text{(Equation 7)} $$

  • Bi-Interaction aggregator: two interactions are designed — the sum of the two vectors, and their element-wise product — each passed through a nonlinearity and then added:

$$ f_{\mathrm{Bi\text{-}Interaction}} = \mathrm{LeakyReLU}\big(W_1(e_h + e_{\mathcal{N}_h})\big) + \mathrm{LeakyReLU}\big(W_2(e_h \odot e_{\mathcal{N}_h})\big) \qquad \text{(Equation 8)} $$

(4) High-order propagation

The above shows a single step of first-order propagation and aggregation; it generalizes readily to higher orders.

At step l, an entity's representation is defined recursively as:

$$ e_h^{(l)} = f\big(e_h^{(l-1)}, e_{\mathcal{N}_h}^{(l-1)}\big) $$

where the information propagated within entity h's ego network is:

$$ e_{\mathcal{N}_h}^{(l-1)} = \sum_{(h, r, t) \in \mathcal{N}_h} \tilde{\pi}(h, r, t)\, e_t^{(l-1)} $$

3.3 Model Prediction

After running L layers, we obtain multiple representations of user node u, namely {e_u^(1), ..., e_u^(L)}; analogously for item node i we obtain {e_i^(1), ..., e_i^(L)}. Since the output of the l-th layer aggregates messages over the depth-l tree rooted at u (or i), the outputs of different layers emphasize connectivity information of different orders. A layer-aggregation mechanism is therefore adopted: the representations from every step are concatenated into a single vector:

$$ e_u^* = e_u^{(0)} \,\Vert\, \cdots \,\Vert\, e_u^{(L)}, \qquad e_i^* = e_i^{(0)} \,\Vert\, \cdots \,\Vert\, e_i^{(L)} \qquad \text{(Equation 11)} $$

In this way, the initial embeddings are enriched by the embedding-propagation operations, and the propagation strength can be controlled by tuning L.

Finally, the inner product of the user and item representations gives the predicted matching score:

$$ \hat{y}(u, i) = {e_u^*}^\top e_i^* \qquad \text{(Equation 12)} $$
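A quick sketch of Equations (11)-(12) with made-up layer dimensions:

import torch

# per-layer representations for one user and one item, dims [64, 32, 16]
u_layers = [torch.randn(64), torch.randn(32), torch.randn(16)]
i_layers = [torch.randn(64), torch.randn(32), torch.randn(16)]

e_u = torch.cat(u_layers)     # e_u* = e_u(0) || e_u(1) || e_u(2), dim 112
e_i = torch.cat(i_layers)
score = torch.dot(e_u, e_i)   # Equation (12)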

3.4 Optimization

The recommendation model is optimized with the BPR loss, which assumes that observed interactions, indicating stronger user preference, should be assigned higher prediction scores than unobserved ones:

$$ \mathcal{L}_{CF} = \sum_{(u, i, j) \in O} -\ln \sigma\big(\hat{y}(u, i) - \hat{y}(u, j)\big) \qquad \text{(Equation 13)} $$

Here (u, i) is an observed (positive) interaction and (u, j) is an unobserved (negative-sample) interaction.

The final objective jointly combines the embedding loss, the recommendation loss, and a regularization term:

$$ \mathcal{L}_{KGAT} = \mathcal{L}_{KG} + \mathcal{L}_{CF} + \lambda \lVert \Theta \rVert_2^2 $$

During training, the KG embedding loss and the CF recommendation loss are optimized alternately.

4. Code Walkthrough

4.1 Datasets

The final CKG consists of the user-item interaction bipartite graph plus a KG that supplies the item side information.

  • train.txt/test.txt

Training/test interactions: each line holds a user id followed by the list of item ids that user interacted with (a hypothetical example follows the file list below). Interactions appearing in the training and test sets are positive samples; unobserved interactions serve as negative samples.

  • user_list.txt

Maps each original user id to its remapped id in the CKG dataset: org_id remap_id

  • item_list.txt

Maps each original item id to its remapped id in the CKG dataset, together with the item's corresponding Freebase id: org_id remap_id freebase_id

  • entity_list.txt

Lists the KG entities, mapping each original Freebase entity id to its remapped id in the CKG dataset: org_id remap_id

  • relation_list.txt

Lists the KG relations, mapping each original Freebase relation id to its remapped id in the CKG dataset: org_id remap_id
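For illustration, here are hypothetical rows of these files (all ids are invented):

# train.txt — user id, then the item ids the user interacted with
0 104 1892 683
1 77 104

# item_list.txt — org_id remap_id freebase_id
2355 0 m.04ctbw8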

4.2 Data Processing

  • Add inverse relations to the KG and renumber all relations by shifting them by 2; merge the user-item interaction graph into the KG by renumbering each user id as user id + total number of entities; encode the user→item interaction as relation 0 and the item→user interaction as relation 1 (a worked example follows this list).
  • Sample a batch of batch_size users, and for each user sample positive and negative user-item interactions.
  • Sample positive and negative examples from the resulting CKG {h: (r, t)}.
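A small worked example of the renumbering, with made-up counts:

# Suppose the original KG has 3 relations (ids 0, 1, 2) and 10,000 entities.
# Step 1 - add inverse triples: the inverse of relation r gets id r + 3   -> ids 0..5
# Step 2 - shift every relation id by 2 to reserve ids 0 and 1            -> ids 2..7
#          relation 0: user -> item "interact"; relation 1: item -> user
# Step 3 - offset user ids so users and entities share one id space:
n_entities = 10000
ckg_user_id = 42 + n_entities   # user 42 becomes node 10042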
loader_base.py

    def load_cf(self, filename):
        """
        Parse a user-item interaction file (one user per line: user_id item ids...).
        Return:
            (user, item) - parallel id arrays: user[k] interacted with item[k]
            user_dict - {user_id: [item1, item2, ...], ...}
        """
        user = []
        item = []
        user_dict = dict()

        lines = open(filename, 'r').readlines()
        for l in lines:
            tmp = l.strip()
            inter = [int(i) for i in tmp.split()]

            if len(inter) > 1:
                user_id, item_ids = inter[0], inter[1:]
                item_ids = list(set(item_ids))

                for item_id in item_ids:
                    user.append(user_id)
                    item.append(item_id)
                user_dict[user_id] = item_ids

        user = np.array(user, dtype=np.int32)
        item = np.array(item, dtype=np.int32)
        return (user, item), user_dict
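    # e.g. the line "0 104 1892" yields user = [0, 0], item = [104, 1892] and
    # user_dict = {0: [104, 1892]}; note that list(set(...)) may reorder the item ids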

    def statistic_cf(self):
        """
        Count the numbers of users and items, and the sizes of the training and test interaction sets.
        """
        self.n_users = max(max(self.cf_train_data[0]), max(self.cf_test_data[0])) + 1
        self.n_items = max(max(self.cf_train_data[1]), max(self.cf_test_data[1])) + 1
        self.n_cf_train = len(self.cf_train_data[0])
        self.n_cf_test = len(self.cf_test_data[0])


    def load_kg(self, filename):
        """
        Read the KG triple file (h r t per line) into a deduplicated DataFrame.
        """
        kg_data = pd.read_csv(filename, sep=' ', names=['h', 'r', 't'], engine='python')
        kg_data = kg_data.drop_duplicates()
        return kg_data


    def sample_pos_items_for_u(self, user_dict, user_id, n_sample_pos_items):
        """
        Sample n_sample_pos_items distinct positive items from the user's interaction list.
        """
        pos_items = user_dict[user_id]
        n_pos_items = len(pos_items)

        sample_pos_items = []
        while True:
            if len(sample_pos_items) == n_sample_pos_items:
                break

            pos_item_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
            pos_item_id = pos_items[pos_item_idx]
            if pos_item_id not in sample_pos_items:
                sample_pos_items.append(pos_item_id)
        return sample_pos_items


    def sample_neg_items_for_u(self, user_dict, user_id, n_sample_neg_items):
        """
        Sample n_sample_neg_items distinct negative items, i.e., items the user never interacted with.
        """
        pos_items = user_dict[user_id]

        sample_neg_items = []
        while True:
            if len(sample_neg_items) == n_sample_neg_items:
                break

            neg_item_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
            if neg_item_id not in pos_items and neg_item_id not in sample_neg_items:
                sample_neg_items.append(neg_item_id)
        return sample_neg_items


    def generate_cf_batch(self, user_dict, batch_size):
        """
        Sample batch_size users; for each, sample one positive and one negative item.
        """
        exist_users = list(user_dict.keys())   # materialize: random.sample/choice require a sequence
        if batch_size <= len(exist_users):
            batch_user = random.sample(exist_users, batch_size)
        else:
            batch_user = [random.choice(exist_users) for _ in range(batch_size)]

        batch_pos_item, batch_neg_item = [], []
        for u in batch_user:
            # sample one positive and one negative item for each sampled user
            batch_pos_item += self.sample_pos_items_for_u(user_dict, u, 1)
            batch_neg_item += self.sample_neg_items_for_u(user_dict, u, 1)

        batch_user = torch.LongTensor(batch_user)
        batch_pos_item = torch.LongTensor(batch_pos_item)
        batch_neg_item = torch.LongTensor(batch_neg_item)
        return batch_user, batch_pos_item, batch_neg_item


    def sample_pos_triples_for_h(self, kg_dict, head, n_sample_pos_triples):
        """
        Sample n_sample_pos_triples positive (relation, tail) pairs for head h from the training CKG (interaction edges included).
        """
        pos_triples = kg_dict[head]
        n_pos_triples = len(pos_triples)

        sample_relations, sample_pos_tails = [], []
        while True:
            if len(sample_relations) == n_sample_pos_triples:
                break

            pos_triple_idx = np.random.randint(low=0, high=n_pos_triples, size=1)[0]
            tail = pos_triples[pos_triple_idx][0]       # train_kg_dict stores (tail, relation) pairs
            relation = pos_triples[pos_triple_idx][1]

            if relation not in sample_relations and tail not in sample_pos_tails:
                sample_relations.append(relation)
                sample_pos_tails.append(tail)
        return sample_relations, sample_pos_tails


    def sample_neg_triples_for_h(self, kg_dict, head, relation, n_sample_neg_triples, highest_neg_idx):
        """
        Sample n_sample_neg_triples negative tails for (head, relation), avoiding observed tails.
        """
        pos_triples = kg_dict[head]

        sample_neg_tails = []
        while True:
            if len(sample_neg_tails) == n_sample_neg_triples:
                break

            tail = np.random.randint(low=0, high=highest_neg_idx, size=1)[0]
            if (tail, relation) not in pos_triples and tail not in sample_neg_tails:
                sample_neg_tails.append(tail)
        return sample_neg_tails


    def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
        """为训练集CKG中每一个头实体采样一个正例的(r,t),一个负例的t"""
        exist_heads = kg_dict.keys()
        if batch_size <= len(exist_heads):
            batch_head = random.sample(exist_heads, batch_size)
        else:
            batch_head = [random.choice(exist_heads) for _ in range(batch_size)]

        batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
        for h in batch_head:
            relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
            batch_relation += relation
            batch_pos_tail += pos_tail

            neg_tail = self.sample_neg_triples_for_h(kg_dict, h, relation[0], 1, highest_neg_idx)
            batch_neg_tail += neg_tail

        batch_head = torch.LongTensor(batch_head)
        batch_relation = torch.LongTensor(batch_relation)
        batch_pos_tail = torch.LongTensor(batch_pos_tail)
        batch_neg_tail = torch.LongTensor(batch_neg_tail)
        return batch_head, batch_relation, batch_pos_tail, batch_neg_tail

loader_kgat.py
    def construct_data(self, kg_data):
        """
        Build the CKG: add inverse edges to the KG and merge in the user-item interaction graph.
        """
        # add inverse kg data
        n_relations = max(kg_data['r']) + 1
        inverse_kg_data = kg_data.copy()
        inverse_kg_data = inverse_kg_data.rename({'h': 't', 't': 'h'}, axis='columns')
        inverse_kg_data['r'] += n_relations
        kg_data = pd.concat([kg_data, inverse_kg_data], axis=0, ignore_index=True, sort=False)

        # shift all relation ids by 2, reserving ids 0 and 1 for the interact relations added below
        kg_data['r'] += 2
        self.n_relations = max(kg_data['r']) + 1
        self.n_entities = max(max(kg_data['h']), max(kg_data['t'])) + 1
        self.n_users_entities = self.n_users + self.n_entities

        # re-map user ids: CKG user id = original user id + n_entities
        self.cf_train_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_train_data[0]))).astype(np.int32), self.cf_train_data[1].astype(np.int32))
        self.cf_test_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_test_data[0]))).astype(np.int32), self.cf_test_data[1].astype(np.int32))

        self.train_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()}
        self.test_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}

        # add interactions to kg data
        # merge the user-item interactions into the KG: user->item edges use relation 0, item->user edges relation 1
        cf2kg_train_data = pd.DataFrame(np.zeros((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        cf2kg_train_data['h'] = self.cf_train_data[0]
        cf2kg_train_data['t'] = self.cf_train_data[1]

        inverse_cf2kg_train_data = pd.DataFrame(np.ones((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        inverse_cf2kg_train_data['h'] = self.cf_train_data[1]
        inverse_cf2kg_train_data['t'] = self.cf_train_data[0]

        self.kg_train_data = pd.concat([kg_data, cf2kg_train_data, inverse_cf2kg_train_data], ignore_index=True)
        self.n_kg_train = len(self.kg_train_data)

        # construct kg dict
        h_list = []
        t_list = []
        r_list = []

        self.train_kg_dict = collections.defaultdict(list)
        self.train_relation_dict = collections.defaultdict(list)

        for _, row in self.kg_train_data.iterrows():
            h, r, t = row['h'], row['r'], row['t']   # access by column name rather than position
            h_list.append(h)
            t_list.append(t)
            r_list.append(r)

            self.train_kg_dict[h].append((t, r))
            self.train_relation_dict[r].append((h, t))

        self.h_list = torch.LongTensor(h_list)
        self.t_list = torch.LongTensor(t_list)
        self.r_list = torch.LongTensor(r_list)
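The h_list / t_list / r_list tensors enumerate every edge of the training CKG (interaction edges included); after each training epoch they are passed to update_attention (section 4.3) to refresh the attention matrix A_in.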

4.3 Model

  • Attention weights: similarity is measured by an inner product; the more similar a tail entity is, the more of the message it should pass, i.e., the larger its weight should be.


  • Message passing and aggregation


The weighted sum of the ego-network embeddings is aggregated with the node's own embedding.

  • Loss function: the CF loss and the KGC loss, plus L2 regularization of the parameters.


  • Prediction

The message-passing outputs of all layers are concatenated, and the inner product then gives the score that the user clicks the item.


KGAT.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def _L2_loss_mean(x):
    # mean over the batch of the halved squared L2 norm
    return torch.mean(torch.sum(torch.pow(x, 2), dim=1, keepdim=False) / 2.)


class Aggregator(nn.Module):

    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type

        self.message_dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()

        if self.aggregator_type == 'gcn':
            self.linear = nn.Linear(self.in_dim, self.out_dim)       # W in Equation (6)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'graphsage':
            self.linear = nn.Linear(self.in_dim * 2, self.out_dim)   # W in Equation (7)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'bi-interaction':
            self.linear1 = nn.Linear(self.in_dim, self.out_dim)      # W1 in Equation (8)
            self.linear2 = nn.Linear(self.in_dim, self.out_dim)      # W2 in Equation (8)
            nn.init.xavier_uniform_(self.linear1.weight)
            nn.init.xavier_uniform_(self.linear2.weight)

        else:
            raise NotImplementedError


    def forward(self, ego_embeddings, A_in):
        """
        ego_embeddings:  (n_users + n_entities, in_dim)
        A_in:            (n_users + n_entities, n_users + n_entities), torch.sparse.FloatTensor
        """
        # Equation (3)
        side_embeddings = torch.matmul(A_in, ego_embeddings)

        if self.aggregator_type == 'gcn':
            # Equation (6) & (9)
            embeddings = ego_embeddings + side_embeddings
            embeddings = self.activation(self.linear(embeddings))

        elif self.aggregator_type == 'graphsage':
            # Equation (7) & (9)
            embeddings = torch.cat([ego_embeddings, side_embeddings], dim=1)
            embeddings = self.activation(self.linear(embeddings))

        elif self.aggregator_type == 'bi-interaction':
            # Equation (8) & (9)
            sum_embeddings = self.activation(self.linear1(ego_embeddings + side_embeddings))
            bi_embeddings = self.activation(self.linear2(ego_embeddings * side_embeddings))
            embeddings = bi_embeddings + sum_embeddings

        embeddings = self.message_dropout(embeddings)           # (n_users + n_entities, out_dim)
        return embeddings
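# Example (illustrative sizes): a quick smoke test of the aggregator.
#   agg = Aggregator(in_dim=8, out_dim=4, dropout=0.1, aggregator_type='bi-interaction')
#   ego = torch.randn(5, 8)              # 5 nodes, dim 8
#   A = torch.eye(5).to_sparse()         # stand-in attention matrix
#   out = agg(ego, A)                    # -> shape (5, 4)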


class KGAT(nn.Module):

    def __init__(self, args,
                 n_users, n_entities, n_relations, A_in=None,
                 user_pre_embed=None, item_pre_embed=None):

        super(KGAT, self).__init__()
        self.use_pretrain = args.use_pretrain

        self.n_users = n_users
        self.n_entities = n_entities
        self.n_relations = n_relations

        self.embed_dim = args.embed_dim
        self.relation_dim = args.relation_dim

        self.aggregation_type = args.aggregation_type
        self.conv_dim_list = [args.embed_dim] + eval(args.conv_dim_list)
        self.mess_dropout = eval(args.mess_dropout)
        self.n_layers = len(eval(args.conv_dim_list))

        self.kg_l2loss_lambda = args.kg_l2loss_lambda
        self.cf_l2loss_lambda = args.cf_l2loss_lambda

        self.entity_user_embed = nn.Embedding(self.n_entities + self.n_users, self.embed_dim)
        self.relation_embed = nn.Embedding(self.n_relations, self.relation_dim)
        self.trans_M = nn.Parameter(torch.Tensor(self.n_relations, self.embed_dim, self.relation_dim))

        if (self.use_pretrain == 1) and (user_pre_embed is not None) and (item_pre_embed is not None):
            other_entity_embed = nn.Parameter(torch.Tensor(self.n_entities - item_pre_embed.shape[0], self.embed_dim))
            nn.init.xavier_uniform_(other_entity_embed)
            entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
            self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
        else:
            nn.init.xavier_uniform_(self.entity_user_embed.weight)

        nn.init.xavier_uniform_(self.relation_embed.weight)
        nn.init.xavier_uniform_(self.trans_M)

        self.aggregator_layers = nn.ModuleList()
        for k in range(self.n_layers):
            self.aggregator_layers.append(Aggregator(self.conv_dim_list[k], self.conv_dim_list[k + 1], self.mess_dropout[k], self.aggregation_type))

        # A_in is the sparse attention (adjacency) matrix over users + entities; it is refreshed outside backprop
        self.A_in = nn.Parameter(torch.sparse.FloatTensor(self.n_users + self.n_entities, self.n_users + self.n_entities))
        if A_in is not None:
            self.A_in.data = A_in
        self.A_in.requires_grad = False


    def calc_cf_embeddings(self):
        """
        Run all propagation layers and concatenate the (L2-normalized) per-layer outputs.
        """
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]

        for idx, layer in enumerate(self.aggregator_layers):
            ego_embed = layer(ego_embed, self.A_in)
            norm_embed = F.normalize(ego_embed, p=2, dim=1)
            all_embed.append(norm_embed)

        # Equation (11)
        all_embed = torch.cat(all_embed, dim=1)         # (n_users + n_entities, concat_dim)
        return all_embed
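    # Note: concat_dim = sum(self.conv_dim_list); e.g. embed_dim 64 with per-layer
    # dims [32, 16] gives 64 + 32 + 16 = 112, since the initial embedding and every
    # layer output are concatenated (Equation 11).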


    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """
        user_ids:       (cf_batch_size)
        item_pos_ids:   (cf_batch_size)
        item_neg_ids:   (cf_batch_size)
        """
        all_embed = self.calc_cf_embeddings()                       # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                            # (cf_batch_size, concat_dim)
        item_pos_embed = all_embed[item_pos_ids]                    # (cf_batch_size, concat_dim)
        item_neg_embed = all_embed[item_neg_ids]                    # (cf_batch_size, concat_dim)

        # Equation (12)
        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)   # (cf_batch_size)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)   # (cf_batch_size)

        # Equation (13)
        # cf_loss = F.softplus(neg_score - pos_score)
        cf_loss = (-1.0) * F.logsigmoid(pos_score - neg_score)
        cf_loss = torch.mean(cf_loss)

        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        loss = cf_loss + self.cf_l2loss_lambda * l2_loss
        return loss


    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """
        h:      (kg_batch_size)
        r:      (kg_batch_size)
        pos_t:  (kg_batch_size)
        neg_t:  (kg_batch_size)
        """
        r_embed = self.relation_embed(r)                                                # (kg_batch_size, relation_dim)
        W_r = self.trans_M[r]                                                           # (kg_batch_size, embed_dim, relation_dim)

        h_embed = self.entity_user_embed(h)                                             # (kg_batch_size, embed_dim)
        pos_t_embed = self.entity_user_embed(pos_t)                                     # (kg_batch_size, embed_dim)
        neg_t_embed = self.entity_user_embed(neg_t)                                     # (kg_batch_size, embed_dim)

        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)                       # (kg_batch_size, relation_dim)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)

        # Equation (1)
        pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1)     # (kg_batch_size)
        neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1)     # (kg_batch_size)

        # Equation (2)
        # kg_loss = F.softplus(pos_score - neg_score)
        kg_loss = (-1.0) * F.logsigmoid(neg_score - pos_score)
        kg_loss = torch.mean(kg_loss)

        l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
        loss = kg_loss + self.kg_l2loss_lambda * l2_loss
        return loss


    def update_attention_batch(self, h_list, t_list, r_idx):
        """
        Recompute the unnormalized attention scores (Equation 4) for every edge of relation r_idx.
        """
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]

        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]

        # Equation (4)
        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        v_list = torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)
        return v_list


    def update_attention(self, h_list, t_list, r_list, relations):
        device = self.A_in.device

        rows = []
        cols = []
        values = []

        for r_idx in relations:
            index_list = torch.where(r_list == r_idx)
            batch_h_list = h_list[index_list]
            batch_t_list = t_list[index_list]

            batch_v_list = self.update_attention_batch(batch_h_list, batch_t_list, r_idx)
            rows.append(batch_h_list)
            cols.append(batch_t_list)
            values.append(batch_v_list)

        rows = torch.cat(rows)
        cols = torch.cat(cols)
        values = torch.cat(values)

        indices = torch.stack([rows, cols])
        shape = self.A_in.shape
        A_in = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

        # Equation (5)
        A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
        self.A_in.data = A_in.to(device)


    def calc_score(self, user_ids, item_ids):
        """
        user_ids:  (n_users)
        item_ids:  (n_items)
        Compute user-item matching scores (Equation 12).
        """
        all_embed = self.calc_cf_embeddings()           # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                # (n_users, concat_dim)
        item_embed = all_embed[item_ids]                # (n_items, concat_dim)

        # Equation (12)
        cf_score = torch.matmul(user_embed, item_embed.transpose(0, 1))    # (n_users, n_items)
        return cf_score
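    # Example (hypothetical ids): scoring two users against all items after training.
    #   user_ids = torch.LongTensor([10024, 10025])    # remapped: original id + n_entities
    #   item_ids = torch.arange(data.n_items)          # items keep ids 0..n_items-1
    #   with torch.no_grad():
    #       scores = model(user_ids, item_ids, mode='predict')   # (2, n_items)
    #   top10 = scores.topk(10, dim=1).indices                   # top-10 items per user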


    def forward(self, *input, mode):
        if mode == 'train_cf':
            return self.calc_cf_loss(*input)
        if mode == 'train_kg':
            return self.calc_kg_loss(*input)
        if mode == 'update_att':
            return self.update_attention(*input)
        if mode == 'predict':
            return self.calc_score(*input)


4.4 Model Training

Training alternates between the CF and KGC tasks, and refreshes the message-passing (attention) weights after each alternating round.

main_kgat.py

def train(args):
    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_save_id = create_log_id(args.save_dir)
    logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
    logging.info(args)

    # GPU / CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    data = DataLoaderKGAT(args, logging)
    if args.use_pretrain == 1:
        user_pre_embed = torch.tensor(data.user_pre_embed)
        item_pre_embed = torch.tensor(data.item_pre_embed)
    else:
        user_pre_embed, item_pre_embed = None, None

    # construct model & optimizer
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in, user_pre_embed, item_pre_embed)
    if args.use_pretrain == 2:
        model = load_model(model, args.pretrain_model_path)

    model.to(device)
    logging.info(model)

    cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
    kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # initialize metrics
    best_epoch = -1
    best_recall = 0

    Ks = eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    epoch_list = []
    metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}

    # train model
    for epoch in range(1, args.n_epoch + 1):
        time0 = time()
        model.train()

        # train cf
        time1 = time()
        cf_total_loss = 0
        n_cf_batch = data.n_cf_train // data.cf_batch_size + 1
        # alternate the two tasks each epoch: CF batches first, then KG batches
        for iter in range(1, n_cf_batch + 1):
            time2 = time()
            # sample cf_batch_size users, with one positive and one negative item each
            cf_batch_user, cf_batch_pos_item, cf_batch_neg_item = data.generate_cf_batch(data.train_user_dict, data.cf_batch_size)
            cf_batch_user = cf_batch_user.to(device)
            cf_batch_pos_item = cf_batch_pos_item.to(device)
            cf_batch_neg_item = cf_batch_neg_item.to(device)


            cf_batch_loss = model(cf_batch_user, cf_batch_pos_item, cf_batch_neg_item, mode='train_cf')

            if np.isnan(cf_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (CF Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_cf_batch))
                sys.exit()

            cf_batch_loss.backward()
            cf_optimizer.step()
            cf_optimizer.zero_grad()
            cf_total_loss += cf_batch_loss.item()

            if (iter % args.cf_print_every) == 0:
                logging.info('CF Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_cf_batch, time() - time2, cf_batch_loss.item(), cf_total_loss / iter))
        logging.info('CF Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_cf_batch, time() - time1, cf_total_loss / n_cf_batch))

        # train kg
        time3 = time()
        kg_total_loss = 0
        n_kg_batch = data.n_kg_train // data.kg_batch_size + 1

        for iter in range(1, n_kg_batch + 1):
            time4 = time()
            kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail = data.generate_kg_batch(data.train_kg_dict, data.kg_batch_size, data.n_users_entities)
            kg_batch_head = kg_batch_head.to(device)
            kg_batch_relation = kg_batch_relation.to(device)
            kg_batch_pos_tail = kg_batch_pos_tail.to(device)
            kg_batch_neg_tail = kg_batch_neg_tail.to(device)

            kg_batch_loss = model(kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail, mode='train_kg')

            if np.isnan(kg_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (KG Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_kg_batch))
                sys.exit()

            kg_batch_loss.backward()
            kg_optimizer.step()
            kg_optimizer.zero_grad()
            kg_total_loss += kg_batch_loss.item()

            if (iter % args.kg_print_every) == 0:
                logging.info('KG Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_kg_batch, time() - time4, kg_batch_loss.item(), kg_total_loss / iter))
        logging.info('KG Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_kg_batch, time() - time3, kg_total_loss / n_kg_batch))
        # after each round of CF + KG training, refresh the attention weights
        # update attention
        time5 = time()
        # h_list/t_list/r_list enumerate all head, tail, and relation ids of the training CKG
        h_list = data.h_list.to(device)
        t_list = data.t_list.to(device)
        r_list = data.r_list.to(device)
        relations = list(data.laplacian_dict.keys())
        model(h_list, t_list, r_list, relations, mode='update_att')
        logging.info('Update Attention: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time5))

        logging.info('CF + KG Training: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time0))

        # evaluate cf
        if (epoch % args.evaluate_every) == 0 or epoch == args.n_epoch:
            time6 = time()
            _, metrics_dict = evaluate(model, data, Ks, device)
            logging.info('CF Evaluation: Epoch {:04d} | Total Time {:.1f}s | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
                epoch, time() - time6, metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

            epoch_list.append(epoch)
            for k in Ks:
                for m in ['precision', 'recall', 'ndcg']:
                    metrics_list[k][m].append(metrics_dict[k][m])
            best_recall, should_stop = early_stopping(metrics_list[k_min]['recall'], args.stopping_steps)

            if should_stop:
                break

            if metrics_list[k_min]['recall'].index(best_recall) == len(epoch_list) - 1:
                save_model(model, args.save_dir, epoch, best_epoch)
                logging.info('Save model on epoch {:04d}!'.format(epoch))
                best_epoch = epoch

    # save metrics
    metrics_df = [epoch_list]
    metrics_cols = ['epoch_idx']
    for k in Ks:
        for m in ['precision', 'recall', 'ndcg']:
            metrics_df.append(metrics_list[k][m])
            metrics_cols.append('{}@{}'.format(m, k))
    metrics_df = pd.DataFrame(metrics_df).transpose()
    metrics_df.columns = metrics_cols
    metrics_df.to_csv(args.save_dir + '/metrics.tsv', sep='\t', index=False)

    # print best metrics
    best_metrics = metrics_df.loc[metrics_df['epoch_idx'] == best_epoch].iloc[0].to_dict()
    logging.info('Best CF Evaluation: Epoch {:04d} | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        int(best_metrics['epoch_idx']), best_metrics['precision@{}'.format(k_min)], best_metrics['precision@{}'.format(k_max)], best_metrics['recall@{}'.format(k_min)], best_metrics['recall@{}'.format(k_max)], best_metrics['ndcg@{}'.format(k_min)], best_metrics['ndcg@{}'.format(k_max)]))
