Tensorflow的random.truncated_normal使用相同的种子返回不同的结果



下面几行应该得到相同的结果:

print (tf.random.truncated_normal(shape=[2],seed=1234))
print (tf.random.truncated_normal(shape=[2],seed=1234))

但是我得到了:

tf.Tensor([-0.12297685 -0.76935077], shape=(2,), dtype=float32)
tf.Tensor([0.37034193 1.3367208 ], shape=(2,), dtype=float32)

为什么?

这似乎是有意为之,请参阅此处的文档。特别是"例子"。部分。

你需要的是stateless_truncated_normal:

print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))
print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))

给我

tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)
tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)

注意:种子在这里需要是两个数字,我真的不知道为什么(文档没有说)。

Tensorflow有两种类型的种子:全局种子和操作种子——这也是为什么你需要传递两个数字给stateless_truncated_normal,正如xdurch0在他的回答中描述的那样。Tensorflow将这两个种子组合起来生成一个新的种子。

tf.random.truncated_normal(shape=[2],seed=1234) # global seed #1 & operational 1234 -> Seed A
tf.random.truncated_normal(shape=[2],seed=1234) # global seed #2 & operational 1234 -> Seed B
有很多方法可以解决你的问题。预先设置两次全局种子。工作在@tf.functions这些类型重置全局种子,并有自己的操作计数器。或者用另一个答案中的stateless_truncated_normal

如前所述,在文档中也有描述。

最新更新