PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)

这篇具有很好参考价值的文章主要介绍了PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

model.py

import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
class gat_cls(nn.Module):
    def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):
        super(gat_cls,self).__init__()
        self.conv1 = GATConv(in_dim,hid_dim)
        self.conv2 = GATConv(hid_dim,hid_dim)
        self.fc = nn.Linear(hid_dim,out_dim)
        self.relu  = nn.ReLU()
        self.dropout_size = dropout_size
    def forward(self,x,edge_index):
        x = self.conv1(x,edge_index)
        x = F.dropout(x,p=self.dropout_size,training=self.training)
        x = self.relu(x)
        x = self.conv2(x,edge_index)
        x = self.relu(x)
        x = self.fc(x)
        return x

main.py

import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gat_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7


net = gat_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    out = net(cora_data.x,cora_data.edge_index)
    optimizer.zero_grad()
    loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])
    loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])
    loss_train.backward()
    print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))
    optimizer.step()

net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

运行图

PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类),图神经网路学习记录,Pytorch学习记录,分类,深度学习,pytorch,python,机器学习,人工智能文章来源地址https://www.toymoban.com/news/detail-731224.html

到了这里,关于PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyG基于Node2Vec实现节点分类及其可视化

    大家好,我是阿光。 本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。 正在更新中~ ✨ 🚨 我的项目环境: 平台:Windows10 语言环

    2024年02月02日
    浏览(55)
  • 使用PyG(PyTorch Geometric)实现基于图卷积神经网络(GCN)的节点分类任务

    PyG(PyTorch Geometric)是一个基于PyTorch的库,可以轻松编写和训练图神经网络(GNN),用于与结构化数据相关的广泛应用。 它包括从各种已发表的论文中对图和其他不规则结构进行深度学习的各种方法,也称为几何深度学习。此外,它还包括易于使用的迷你批处理加载程序,用

    2023年04月20日
    浏览(45)
  • 【深度学习&图神经网络】Node2Vec +GAT 完成 节点分类任务(含代码) | 附:其它生成节点特征向量的算法:DeepWalk、LINE(具体实现细节)、SDNE、MMDW

      “我从来没有在哪次分离中流过眼泪,因为我觉得,与还健在的人的离别是世界上第二浪漫的事,因为我们从此离别以后 每一次相遇都是重逢,而重逢是世界上第一浪漫的事情。”     🎯作者主页: 追光者♂🔥          🌸个人简介:   💖[1] 计算机专业硕士研究生

    2024年02月07日
    浏览(58)
  • 图神经网络:(图的分类)在MUTAG数据集上动手实现图神经网络

    文章说明: 1)参考资料:PYG的文档。文档超链。 2)博主水平不高,如有错误,还望批评指正。 3)我在百度网盘上传这篇文章的jupyter notebook以及有关文献。提取码8848。 MUTAG数据集是一个分子图形数据集。每个分子包含一个二元标签表示该分子是否为一种类固醇化合物。我们的

    2024年02月05日
    浏览(34)
  • 深度学习推荐系统(八)AFM模型及其在Criteo数据集上的应用

    沿着特征工程自动化的思路,深度学习模型从 PNN ⼀路⾛来,经过了Wide&Deep、Deep&Cross、FNN、DeepFM、NFM等模型,进⾏了大量的、基于不同特征互操作思路的尝试。 但特征工程的思路走到这里几乎已经穷尽了可能的尝试,模型进⼀步提升的空间非常小,这也是这类模型的局限

    2024年02月09日
    浏览(52)
  • 深度学习推荐系统(二)Deep Crossing及其在Criteo数据集上的应用

    在2016年, 随着微软的Deep Crossing, 谷歌的WideDeep以及FNN、PNN等一大批优秀的深度学习模型被提出, 推荐系统全面进入了深度学习时代, 时至今日, 依然是主流。 推荐模型主要有下面两个进展: 与传统的机器学习模型相比, 深度学习模型的表达能力更强, 能够挖掘更多数据

    2024年02月10日
    浏览(44)
  • 深度学习推荐系统(三)NeuralCF及其在ml-1m电影数据集上的应用

    在2016年, 随着微软的Deep Crossing, 谷歌的WideDeep以及FNN、PNN等一大批优秀的深度学习模型被提出, 推荐系统全面进入了深度学习时代, 时至今日, 依然是主流。 推荐模型主要有下面两个进展: 与传统的机器学习模型相比, 深度学习模型的表达能力更强, 能够挖掘更多数据

    2024年02月10日
    浏览(60)
  • 深度学习推荐系统(五)Deep&Crossing模型及其在Criteo数据集上的应用

    在2016年, 随着微软的Deep Crossing, 谷歌的WideDeep以及FNN、PNN等一大批优秀的深度学习模型被提出, 推荐系统全面进入了深度学习时代, 时至今日, 依然是主流。 推荐模型主要有下面两个进展: 与传统的机器学习模型相比, 深度学习模型的表达能力更强, 能够挖掘更多数据

    2024年02月09日
    浏览(50)
  • 深度学习推荐系统(四)Wide&Deep模型及其在Criteo数据集上的应用

    在2016年, 随着微软的Deep Crossing, 谷歌的WideDeep以及FNN、PNN等一大批优秀的深度学习模型被提出, 推荐系统全面进入了深度学习时代, 时至今日, 依然是主流。 推荐模型主要有下面两个进展: 与传统的机器学习模型相比, 深度学习模型的表达能力更强, 能够挖掘更多数据

    2024年02月09日
    浏览(46)
  • 经典神经网络(7)DenseNet及其在Fashion-MNIST数据集上的应用

    DenseNet 不是通过更深或者更宽的结构,而是通过特征重用来提升网络的学习能力。 ResNet 的思想是:创建从“靠近输入的层” 到 “靠近输出的层” 的直连。而 DenseNet 做得更为彻底:将所有层以前馈的形式相连,这种网络因此称作 DenseNet 。 DenseNet 具有以下的优点: 缓解梯度

    2024年02月12日
    浏览(40)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包