Haiku 遵循 JAX 的设计,生成的随机数是两个元素组成的列表。其中第一个元素是用于生成伪随机数的状态,第二个元素是用于分发密钥的子键。两个元素分别用于状态和子键,确保在分布式计算或并行计算中,多个随机数生成器的状态可以在一定程度上相互影响,从而提高随机性。
# 在实际使用中,你通常只需要整个列表或其中的一个元素即可。
rng_key[0]:状态;rng_key[1]:子键
Random Numbers
PRNGSequence(key_or_seed) |
Iterator of JAX random keys. |
next_rng_key() |
Returns a unique JAX random key split from the current global key. |
next_rng_keys(num) |
Returns one or more JAX random keys split from the current global key. |
maybe_next_rng_key() |
next_rng_key() if random numbers are available, else |
reserve_rng_keys(num) |
Pre-allocate some number of JAX RNG keys. |
with_rng(key) |
Provides a new sequence for next_rng_key() to draw from. |
maybe_get_rng_sequence_state() |
Returns the internal state of the PRNG sequence. |
replace_rng_sequence_state(state) |
Replaces the internal state of the PRNG sequence with the given state. |
import haiku as hk
import jax
### 1. class haiku.PRNGSequence(key_or_seed)
# Iterator of JAX random keys
seq = hk.PRNGSequence(42) # OR pass a jax.random.PRNGKey
print(seq)
print(type(seq)) # class 'haiku._src.base.PRNGSequence
key1 = next(seq)
print(key1)
key2 = next(seq)
print(key2)
assert key1 is not key2
### 2. hk.next_rng_key()
# 获取下一个随机数生成器密钥的函数。
key = hk.next_rng_key()
_ = jax.random.uniform(key, [])
### 3. haiku.next_rng_keys(num)
# returns one or more JAX random keys split from the current global key.
k1, k2 = hk.next_rng_keys(2)
assert (k1 != k2).all()
a = jax.random.uniform(k1, [])
b = jax.random.uniform(k2, [])
assert a != b
### 4. haiku.reserve_rng_keys(num)
# 预留一定数量的随机数生成器密钥,并返回一个包含这些密钥的列表。
hk.reserve_rng_keys(2) # Pre-allocate 2 keys for us to consume.
_ = hk.next_rng_key() # Takes the first pre-allocated key.
_ = hk.next_rng_key() # Takes the second pre-allocated key.
_ = hk.next_rng_key() # Splits a new key.
### 5. haiku.with_rng(key)
# 在指定的上下文中使用给定的随机数生成器密钥。
# 通过 with hk.with_rng(rng_key):,我们创建了一个上下文,
# 确保在这个上下文中执行的随机操作使用了相同的随机数生成器密钥,从而使得两次调用的结果是可重复的。
with hk.with_rng(jax.random.PRNGKey(428)):
s = jax.random.uniform(hk.next_rng_key(), ())
print("{:.1f}".format(s))
### 注:2,3,4,5 代码需在 hk.transform后的模块中使用
# hk.transform 是 Haiku 中的一个函数,用于将一个普通的 Python 函数(模块定义)转换为 Haiku 模块。通过转换,Haiku 将能够管理模块的参数、初始化和应用。
参考:文章来源:https://www.toymoban.com/news/detail-808919.html
https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=random#random-numbers文章来源地址https://www.toymoban.com/news/detail-808919.html
到了这里,关于haiku生成随机数的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!