PyTorch Geometric基本教程

这篇具有很好参考价值的文章主要介绍了PyTorch Geometric基本教程。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

PyG官方文档文章来源地址https://www.toymoban.com/news/detail-667285.html


# Install torch geometric
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
!pip install -q torch-geometric

import torch
import networkx as nx
import matplotlib.pyplot as plt

1.内置数据集(以KarateClub为例)

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
# 图的数量
print(f'Number of graphs: {len(dataset)}')
# 每个节点的特征尺寸
print(f'Number of features: {dataset.num_features}')
# 节点的类别数量
print(f'Number of classes: {dataset.num_classes}')
# 获取具体的图
data = dataset[0]
print(data)
print('==============================================================')

# 获取图的属性
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {(2*data.num_edges) / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
print(f'Contains self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
# 取出的图的数据对象为Data类型,包含以下属性
# 1. edge_index 每条边的两个端点的索引组成的元组
# 2. x 节点特征[节点数量,特征维数]
# 3. y 节点标签(类别),每个节点只分配一个类别
# 4. train_mask 
Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])
print(data)
# 获取所有的边
print(data.edge_idx.T)

2.可视化

def visualize(h, color, epoch=None, loss=None, accuracy=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    
    if torch.is_tensor(h):
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None and accuracy['train'] is not None and accuracy['val'] is not None:
            plt.xlabel((f'Epoch: {epoch}, Loss: {loss.item():.4f} \n'
                       f'Training Accuracy: {accuracy["train"]*100:.2f}% \n'
                       f' Validation Accuracy: {accuracy["val"]*100:.2f}%'),
                       fontsize=16)
    else:
        # networkx的draw_networkx
        nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=False, node_color=color, cmap="Set2")   
    
    plt.show()
from torch_geometric.utils import to_networkx
# 将Data类型转换成networkx
G = to_networkx(data, to_undirected=True)
# 将图可视化,节点颜色为节点的类型
visualize(G, color=data.y)

3.搭建GNN(以GCN为例)

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()
        out = self.classifier(h)
        return out, h

model = GCN()
print(model)
# 节点分类
model = GCN()

out, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize(h, color=data.y)

4.在KarateClub数据集上训练

import time
model = GCN()

# 交叉熵损失,Adam优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

def train(data):
    optimizer.zero_grad()
    out, h  = model(data.x, data.edge_index)
    # 只对train_mask的节点计算loss
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    accuracy = {}
    # torch.argmax 取置信度最大的一类
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
    
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
    
    return loss, h, accuracy
for epoch in range(500):
    loss, h, accuracy = train(data)
    if epoch % 10 == 0:
        visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy)
        time.sleep(0.3)

到了这里,关于PyTorch Geometric基本教程的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【强化学习】——Q-learning算法为例入门Pytorch强化学习

    🤵‍♂️ 个人主页:@Lingxw_w的个人主页 ✍🏻作者简介:计算机研究生在读,研究方向复杂网络和数据挖掘,阿里云专家博主,华为云云享专家,CSDN专家博主、人工智能领域优质创作者,安徽省优秀毕业生 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话

    2024年02月10日
    浏览(72)
  • PyTorch翻译官网教程-DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK

    Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation 在本教程中,我们将使用Flask部署PyTorch模型,并开放用于模型推断的REST API。特别是,我们将部署一个预训练的DenseNet 121模型来检测图像。 这是关于在生产环境中部署PyTorch模型的系列教程中的第一篇

    2024年02月16日
    浏览(44)
  • python pytorch教程-带你从入门到实战(代码全部可运行)

    其实这个教程以前博主写过一次,不过,这回再写一次,打算内容写的多一点,由浅入深,然后加入一些实践案例。 下面是我们的内容目录: 1.先从数据类型谈起 1.1 如何生成pytorch的各种数据类型? 1.2 pytorch的各种数据类型有哪些属性? 1.3 pytorch的各种数据类型有哪些函数操

    2024年02月13日
    浏览(45)
  • Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

    本专栏重点介绍强化学习技术的数学原理,并且 采用Pytorch框架对常见的强化学习算法、案例进行实现 ,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。 🚀详情:

    2024年02月04日
    浏览(61)
  • pytorch+Anaconda+python3.10+parcharm+win10安装简化教程

    Pytorch+Anaconda+Python3.10+parcharm+WIN10安装简化教程 1、首先登陆pycharm官网,https://www.jetbrains.com/pycharm/download/ 2、下载community版本 3、下载完成后,双击安装,一直点下一步。 1、首先登录anaconda官网,https://www.anaconda.com/ 2、点击Download下载安装包 3、双击安装包安装,选择Just Me 4、更

    2024年02月02日
    浏览(56)
  • 在anaconda下安装pytorch + python3.8+GPU/CPU版本 详细教程

    没安装Anaconda的同学可以参考以下安装链接: https://blog.csdn.net/qq_45281807/article/details/112442577 按照安装CPU版本和GPU两个版本进行分类,一般运行程序建议使用CPU版本的,安装更方便。 注意!如果切换镜像后当出现下载不了的情况,就先切换默认源,然后再修改另一个可以使用的

    2024年01月19日
    浏览(75)
  • Windows 系统从零配置 Python 环境,安装CUDA、CUDNN、PyTorch 详细教程

    进入anaconda官网:https://www.anaconda.com/ 点击 download 下载文件,我这里是 Anaconda3-2022.10-Windows-x86_64.exe (后续更新版本exe文件会有差别) 下载后打开 .exe 文件下载 anaconda: 选择安装路径(用默认的路径也可以): 这里两个都选: 然后安装就可以了。 打开 cmd,输入 conda(如果是

    2024年02月03日
    浏览(102)
  • 如何用conda安装PyTorch(windows、GPU)最全安装教程(cudatoolkit、python、PyTorch、Anaconda版本对应问题)(完美解决安装CPU而不是GPU的问题)

            安装PyTorch的开发环境:Anaconda+CUDA+cuDNN+PyCharm Community 1.1 版本选择 第一步就是最关键的版本对应问题(这决定你能否成功安装PyTorch,以及能否成功安装GPU版本的关键问题),可以这么说,版本不能对应好,后面有很大的问题,因此,我们要先确定版本的对应关系。(

    2024年02月07日
    浏览(57)
  • PyTorch基本操作练习

    实现了一些PyTorch基本操作,原理可参考《神经网络与深度学习》《动手学深度学习》中的内容。个人练习,切勿与任何作业和考试挂钩。代码运行在Python 3.9.7版本以及Pytorch 1.10版本中。 使用Tensor初始化一个1×3的矩阵M和一个2×1的矩阵N,对两矩阵进行减法操作(三种不同形式

    2024年02月17日
    浏览(37)
  • Pytorch基本使用—激活函数

    激活函数是神经网络中的一种数学函数 ,它被应用于神经元的输出,以决定神经元是否应该被激活并传递信号给下一层。 常见的激活函数包括Sigmoid函数、ReLU函数、Tanh函数等 。 激活函数是神经网络中的一种重要组件,它的作用是引入非线性特性,使得神经网络能够学习和表

    2024年02月15日
    浏览(43)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包