JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

这篇具有很好参考价值的文章主要介绍了JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学,深度学习

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学,深度学习

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学,深度学习

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进“随机状态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)

本文由 mdnice 多平台发布文章来源地址https://www.toymoban.com/news/detail-770333.html

到了这里,关于JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 最简单Anaconda+PyTorch深度学习环境配置教程

    深度学习小白从零开始学习配置环境,记录一下踩过的雷坑,做个学习笔记。 配置了好几次之后总结出来的最简单,试错成本最小的方案,分享给大家~ 安装顺序:Anaconda+CUDA+ CuDnn+Pytorch  Anaconda ,中文 大蟒蛇 ,是一个开源的Python发行版本,其包含了conda、Python等180多个科学

    2024年02月02日
    浏览(60)
  • 数据科学中的Python:NumPy和Pandas入门指南【第121篇—NumPy和Pandas】

    数据科学是当今数字时代中的一个重要领域,而Python是数据科学家们最喜爱的编程语言之一。在这篇博客中,我们将介绍Python中两个强大的库——NumPy和Pandas,它们在数据处理和分析中发挥着重要作用。 NumPy是用于科学计算的基础包,提供了高性能的多维数组对象( numpy.nda

    2024年03月15日
    浏览(87)
  • 深度学习:图像去雨网络实现Pytorch (二)一个简单实用的基准模型(PreNet)实现

            本文参考文献: Progressive Image Deraining Networks: A Better and Simpler Baseline Dongwei Ren1, Wangmeng Zuo 2 , Qinghua Hu ∗ 1 , Pengfei Zhu 1 , and Deyu Meng 31College of Computing and Intelligence, Tianjin University, Tianjin, China 2School of Computer Science and Technology, Harbin Institute of Technology, Harbin, China 3Xi’an

    2023年04月21日
    浏览(32)
  • python库,科学计算与数据可视化基础,知识笔记(numpy+matplotlib)

    这篇主要讲一下数据处理中科学计算部分的知识。 之前有一篇pandas处理数据的。 讲一下这几个库的区别。 Pandas主要用来处理类表格数据(excel,csv),提供了计算接口,可用Numpy或其它方式进行计算。 NumPy 主要用来处理数值数据(尤其是矩阵,向量为核心的),本质上是纯

    2024年02月02日
    浏览(42)
  • 在conda虚拟环境中配置cuda+cudnn+pytorch深度学习环境(新手必看!简单可行!)

    本人最近接触深度学习,想在服务器上配置深度学习的环境,看了很多资料后总结出来了对于新手比较友好的配置流程,创建了一个关于深度学习环境配置的专栏,包括从anaconda到cuda到pytorch的一系列操作,专栏中的另外两篇文章如下,如果有不对的地方欢迎大家批评指正!

    2023年04月15日
    浏览(48)
  • 【深度学习】Pytorch 系列教程(十二):PyTorch数据结构:4、数据集(Dataset)

             目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量(Tensor) 2、张量操作(Tensor Operations) 3、变量(Variable) 4、数据集(Dataset) 随机洗牌           ChatGPT:         PyTorch是一个开源的机器学习框架,广泛应用于深度学习领域。它提供了丰富

    2024年02月07日
    浏览(38)
  • 【Python】数据科学工具(Numpy Pandas np.array() 创建访问数组 向量与矩阵 Series DataFrame)

    1.Numpy numpy是Python中一个非常重要的科学计算库,其最基础的功能就是N维数组对象——ndarray。 1.1 数组的创建 1)np.array() 用 np.array() 函数可以将Python的序列对象(如列表、元组)转换为ndarray数组。 2)arange、linspace、logspace np.arange(start, stop, step) :创建一个一维数组,其中的值

    2024年02月10日
    浏览(36)
  • 深度学习-Pytorch数据集构造和分批加载

    pytorch 目前在深度学习具有重要的地位,比起早先的caffe,tensorflow,keras越来越受到欢迎,其他的深度学习框架越来越显得小众。 数据分析 数据分析-Pandas如何转换产生新列 数据分析-Pandas如何统计数据概况 数据分析-Pandas如何轻松处理时间序列数据 数据分析-Pandas如何选择数据

    2024年01月25日
    浏览(42)
  • Pytorch目标分类深度学习自定义数据集训练

    目录 一,Pytorch简介; 二,环境配置; 三,自定义数据集; 四,模型训练; 五,模型验证;         PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。PyTorch 基于 Python: PyTorch 以 Python 为中心或“pythonic”,旨在深度集成 Python 代码,而不是

    2024年02月07日
    浏览(48)
  • 深度学习基础知识-pytorch数据基本操作

    1.1.1 数据结构 机器学习和神经网络的主要数据结构,例如                 0维:叫标量,代表一个类别,如1.0                 1维:代表一个特征向量。如  [1.0,2,7,3.4]                 2维:就是矩阵,一个样本-特征矩阵,如: [[1.0,2,7,3.4 ]                   

    2024年02月11日
    浏览(38)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包