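I'm looking for a mechanism to save the random state used by tf.data.Dataset.shuffle. For context, I want to be able to reproduce training results when restarting.
I have a solution (shown below), but it isn't particularly elegant, and I'm fairly confident that the batch/unbatch will cause performance problems. Is there an equivalent approach that uses Dataset.shuffle?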
import tensorflow as tf
import numpy as np


class Shuffler(tf.Module):
    def __init__(self, buffer_size: int, seed: int = 0):
        self._buffer_size = buffer_size
        self._seed = seed
        # The generator keeps its state in a variable, so tf.train.Checkpoint can track it.
        self._rng = tf.random.Generator.from_seed(seed)

    def __call__(self, dataset: tf.data.Dataset):
        def map_fn(*args):
            # Sorting uniform samples yields a random permutation of the block.
            vals = self._rng.uniform((self._buffer_size,))
            i = tf.argsort(vals)
            if len(args) == 1:
                (args,) = args
            return tf.nest.map_structure(lambda x: tf.gather(x, i), args)

        # Shuffle within consecutive blocks of buffer_size elements.
        return dataset.batch(self._buffer_size).map(map_fn).unbatch()


def as_list(ds: tf.data.Dataset):
    return [x.numpy() for x in ds]


shuffler = Shuffler(5)
chkpt = tf.train.Checkpoint(shuffler=shuffler)
p0 = chkpt.save("/tmp/chkpt-0")
ds = tf.data.Dataset.range(5).apply(shuffler)
expected0 = as_list(ds)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(ds)

# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)

chkpt.restore(p0)
np.testing.assert_equal(as_list(ds), expected0)
np.testing.assert_equal(as_list(ds), expected1)

# mangle state by iterating over it again
as_list(ds)

# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(ds), expected1)
print("Passed!")
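It turns out the state is already managed by the iterator, so checkpointing the iterator itself is enough to save and restore the shuffle order.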
import tensorflow as tf
import numpy as np


def as_list(it: tf.data.Iterator, length: int = 5):
    # Pull a fixed number of elements; the dataset repeats indefinitely.
    return [next(it).numpy() for _ in range(length)]


ds = tf.data.Dataset.range(5).shuffle(5, seed=0).repeat()
it = iter(ds)
# Track the iterator (not the dataset) so its shuffle state is checkpointed.
chkpt = tf.train.Checkpoint(it=it)
p0 = chkpt.save("/tmp/chkpt-0")
expected0 = as_list(it)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(it)

# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)

chkpt.restore(p0)
np.testing.assert_equal(as_list(it), expected0)
np.testing.assert_equal(as_list(it), expected1)

# mangle state by iterating over it again
as_list(it)

# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(it), expected1)
print("Passed!")
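
For the restart scenario in the question, the same idea can probably be combined with tf.train.CheckpointManager so that a step counter and the iterator are saved together and restored into freshly built objects. The following is only a minimal sketch under some assumptions: the helper make_iterator, the step variable, and the temporary checkpoint directory are hypothetical names introduced for illustration, and the restart is simulated inside a single process.

```python
import tempfile

import tensorflow as tf


def make_iterator():
    # Hypothetical helper: rebuilds the same seeded shuffle pipeline as above.
    ds = tf.data.Dataset.range(5).shuffle(5, seed=0).repeat()
    return iter(ds)


# Assumed location for illustration; a real run would use a persistent path.
ckpt_dir = tempfile.mkdtemp()

# First run: consume a few elements, save, then record what comes next.
it = make_iterator()
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step, it=it)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
for _ in range(3):
    next(it)
    step.assign_add(1)
manager.save()
expected = [int(next(it).numpy()) for _ in range(7)]

# Simulated restart: build fresh objects and restore the latest checkpoint.
it2 = make_iterator()
step2 = tf.Variable(0, dtype=tf.int64)
ckpt2 = tf.train.Checkpoint(step=step2, it=it2)
manager2 = tf.train.CheckpointManager(ckpt2, ckpt_dir, max_to_keep=1)
ckpt2.restore(manager2.latest_checkpoint)
resumed = [int(next(it2).numpy()) for _ in range(7)]

# The resumed run should continue exactly where the saved one left off.
assert resumed == expected
```

The key point is that the Checkpoint tracks the iterator rather than the dataset, so the restored run picks up mid-shuffle instead of starting a new epoch.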