How to sample consistently from several mixture distributions



I would like to draw consistent samples from several mixture distributions. For example, with this code:

import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np
rate = 0.5 
cat = Categorical(probs=[1-rate, rate])
f1 = Mixture(cat=cat, components=[Normal(loc=10., scale=1.), Deterministic(0.)])
f2 = Mixture(cat=cat, components=[Normal(loc=5., scale=1.), Deterministic(0.)])
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)
sess.run([cat.sample(), f1.sample(), f2.sample()])

I get:

[1, 10.4463625, 0.0]

This is not what I want. The three values are drawn independently, which makes sense once you look at the source code of the .sample() method.

My question: how can I draw samples so that the Categorical variable is evaluated first and then shared between f1 and f2?

There is no existing library code that does this (although there is an internal bug open along these lines).

For now, you can create a dummy distribution that always returns the same cached sample tensor:

import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np
class HackedCat(tf.distributions.Categorical):
  def __init__(self, *args, **kwargs):
    super(HackedCat, self).__init__(*args, **kwargs)
    self._cached_sample = self.sample(use_cached=False)
  def sample(self, *args, **kwargs):
    # Use cached sample by default or when explicitly asked to
    if 'use_cached' not in kwargs or kwargs['use_cached']:
      return self._cached_sample
    else:
      if 'use_cached' in kwargs:
        del kwargs['use_cached']
      return super(HackedCat, self).sample(*args, **kwargs)
def main():
  rate = 0.5 
  cat = HackedCat(probs=[1-rate, rate])
  f1 = Mixture(cat=cat,
               components=[Normal(loc=10., scale=1.),
                           Deterministic(0.)])
  f2 = Mixture(cat=cat,
               components=[Normal(loc=5., scale=1.),
                           Deterministic(0.)])
  with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([cat.sample(), f1.sample(), f2.sample()]))
if __name__ == '__main__':
  main()
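
The idea behind the workaround (draw the component index once, then reuse it for every mixture) can also be sketched outside TensorFlow. The following is only a NumPy illustration under my own assumptions, not part of the TensorFlow API: the helper name sample_shared_mixtures and its parameters are hypothetical, and each mixture is taken to be Normal(loc, scale) with probability 1 - rate and the constant 0.0 with probability rate, as in the question.

```python
import numpy as np


def sample_shared_mixtures(rate, locs, scale=1.0, n=5, seed=0):
    """Hypothetical helper: draw n joint samples from several mixtures
    that all share a single categorical draw per sample."""
    rng = np.random.default_rng(seed)
    # Draw the component index once: 0 -> Normal component, 1 -> point mass at 0.
    cat = rng.choice(2, size=n, p=[1.0 - rate, rate])
    samples = []
    for loc in locs:
        normal_draw = rng.normal(loc=loc, scale=scale, size=n)
        # Reuse the same index for every mixture, so all mixtures
        # pick the same component in each joint sample.
        samples.append(np.where(cat == 0, normal_draw, 0.0))
    return cat, samples


cat, (f1, f2) = sample_shared_mixtures(rate=0.5, locs=[10.0, 5.0])
print(cat, f1, f2)
```

Whenever cat is 1, both f1 and f2 are 0.0, and whenever cat is 0, both come from their Normal components, which is the consistency asked for in the question.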
