从分布中随机采样的高效算法,同时允许更新



这是我不久前在面试中被问到的问题,我找不到答案。

给定一些样本S1、S2。。。Sn和它们的概率分布(或权重,不管它叫什么)P1,P2。。Pn,设计算法,在考虑其概率的情况下随机选择样本。我得到的解决方案如下:

  1. 构建权重Ci的累积数组,例如

    C0=0;Ci=C[i-1]+Pi。

同时计算T=P1+P2+。。。Pn。需要O(n)时间

  1. 生成一致随机数R=T*随机[0.1]
  2. 使用二进制搜索算法,返回最小i,这样Ci>=R。结果是Si。它需要O(logN)时间

现在实际的问题是:假设我想更改其中一个初始权重Pj。如何在比O(n)更好的时间内做到这一点?其他数据结构也是可以接受的,但随机采样算法不应该变得比O(logN)差。

解决此问题的一种方法是重新思考如何构建包含累积总数的二进制搜索树。与其构建二进制搜索树,不如考虑将每个节点解释如下:

  • 每个节点存储一系列专用于节点本身的值
  • 左子树中的节点表示从概率分布到该范围左侧的采样
  • 右侧子树中的节点表示该范围右侧的概率分布中的采样

例如,假设我们对事件A、B、C、D、E、F和G的权重为3、2、2、1、2

               D
             /   
           B       F
          /      / 
         A   C   E   G

现在,我们用概率来注释树。由于A、C、E和G都是叶子,我们给它们每个的概率质量一个:

               D
             /   
           B       F
          /      / 
         A   C   E   G
         1   1   1   1

现在,看看B的树。B的权重为2,A的权重为3,C的概率为2。如果我们将它们归一化到[0,1)的范围,那么A占概率的3/7,B和C各占2/7s。因此,我们让B的节点说,范围[0,3/7)中的任何东西都到左子树,范围[3/7,5/7)中的所有东西都映射到B,范围[5/7,1)中的一切都映射到右子树:

                   D
                 /   
           B              F
 [0, 3/7) /   [5/7, 1)  / 
         A   C          E   G
         1   1          1   1

类似地,让我们处理F.E具有被选择的权重2,而F和G各自具有被选择概率权重1。因此,这里E的子树占概率质量的1/2,节点F占1/4,G的子树占1/4。这意味着我们可以将概率分配为

                       D
                     /   
           B                        F
 [0, 3/7) /   [5/7, 1)   [0, 1/2) /   [3/4, 1)
         A   C                    E   G
         1   1                    1   1

最后,让我们看看根源。左子树的组合权重为3+2+2=7。右子树的组合权重为2+1+1=4。D本身的重量是2。因此,左子树具有被选取的概率7/13,D具有被选取概率2/13,并且右子树具有被挑选概率4/13。因此,我们可以将概率最终确定为

                       D
           [0, 7/13) /    [9/13, 1)
           B                        F
 [0, 3/7) /   [5/7, 1)   [0, 1/2) /   [3/4, 1)
         A   C                    E   G
         1   1                    1   1

要生成一个随机值,您需要重复以下操作:

  • 从根开始:
    • 选择一个范围为[0,1)的统一随机值
    • 如果它在左子树的范围内,则下降到它中
    • 如果它在右子树的范围内,则下降到该范围内
    • 否则,返回当前节点对应的值

概率本身可以在构建树时递归确定:

  • 对于任何叶节点,左概率和右概率都是0
  • 如果内部节点本身具有权重W,其左树具有总权重WL,并且其右树具有总重量WR,则左概率为(W>L/WR)/(W+WR+WR),右概率为(WR

这种重新表述之所以有用,是因为它为我们提供了一种在每次更新概率的O(logn)时间内更新概率的方法。特别是,让我们考虑一下,如果我们更新某个特定节点的权重,不变量会发生什么变化。为了简单起见,让我们假设节点现在是一个叶子。当我们更新叶节点的权重时,叶节点的概率仍然是正确的,但它上面的节点的概率是不正确的,因为该节点的子树之一的权重已经改变。因此,我们可以(在O(1)时间内)通过使用与上述相同的公式来重新计算父节点的概率。但是,该节点的父节点不再具有正确的值,因为它的一个子树权重发生了变化,所以我们也可以在那里重新计算概率。这个过程一直重复到树的根部,我们每个级别进行O(1)计算,以校正分配给每条边的权重。假设树是平衡的,因此我们必须做O(logn)的总工作来更新一个概率。如果节点不是叶节点,则逻辑是相同的;我们只是从树上的某个地方开始。

简而言之,这就提供了

  • O(n)构造树的时间(使用自下而上的方法)
  • O(logn)生成随机值的时间,以及
  • O(logn)更新任何一个值的时间

希望这能有所帮助!

将搜索存储为平衡的二叉树,而不是数组。树的每个节点都应该存储它所包含的元素的总重量。根据R的值,搜索过程要么返回当前节点,要么搜索左子树或右子树。

当元素的权重改变时,搜索结构的更新就是调整从元素到树根的路径上的权重。

由于树是平衡的,所以搜索和权重更新操作都是O(log N)。

对于那些想要一些代码的人,这里有一个python实现:

import numpy

class DynamicProbDistribution(object):
  """ Given a set of weighted items, randomly samples an item with probability
  proportional to its weight. This class also supports fast modification of the
  distribution, so that changing an item's weight requires O(log N) time. 
  Sampling requires O(log N) time. """
  def __init__(self, weights):
    self.num_weights = len(weights)
    self.weights = numpy.empty((1+len(weights),), 'float32')
    self.weights[0] = 0 # Not necessary but easier to read after printing
    self.weights[1:] = weights
    self.weight_tree = numpy.zeros((1+len(weights),), 'float32')
    self.populate_weight_tree()
  def populate_weight_tree(self):
    """ The value of every node in the weight tree is equal to the sum of all 
    weights in the subtree rooted at that node. """
    i = self.num_weights
    while i > 0:
      weight_sum = self.weights[i]
      twoi = 2*i
      if twoi < self.num_weights:
        weight_sum += self.weight_tree[twoi] + self.weight_tree[twoi+1]
      elif twoi == self.num_weights:
        weight_sum += self.weights[twoi]
      self.weight_tree[i] = weight_sum
      i -= 1
  def set_weight(self, item_idx, weight):
    """ Changes the weight of the given item. """
    i = item_idx + 1
    self.weights[i] = weight
    while i > 0:
      weight_sum = self.weights[i]
      twoi = 2*i
      if twoi < self.num_weights:
        weight_sum += self.weight_tree[twoi] + self.weight_tree[twoi+1]
      elif twoi == self.num_weights:
        weight_sum += self.weights[twoi]
      self.weight_tree[i] = weight_sum
      i /= 2 # Only need to modify the parents of this node
  def sample(self):
    """ Returns an item index sampled from the distribution. """
    i = 1
    while True:
      twoi = 2*i
      if twoi < self.num_weights:
        # Two children
        val = numpy.random.random() * self.weight_tree[i]
        if val < self.weights[i]:
          # all indices are offset by 1 for fast traversal of the
          # internal binary tree
          return i-1
        elif val < self.weights[i] + self.weight_tree[twoi]:
          i = twoi # descend into the subtree
        else:
          i = twoi + 1
      elif twoi == self.num_weights:
        # One child
        val = numpy.random.random() * self.weight_tree[i]
        if val < self.weights[i]:
          return i-1
        else:
          i = twoi
      else:
        # No children
        return i-1

def validate_distribution_results(dpd, weights, samples_per_item=1000):
  import time
  bins = numpy.zeros((len(weights),), 'float32')
  num_samples = samples_per_item * numpy.sum(weights)
  start = time.time()
  for i in xrange(num_samples):
    bins[dpd.sample()] += 1
  duration = time.time() - start
  bins *= numpy.sum(weights)
  bins /= num_samples
  print "Time to make %s samples: %s" % (num_samples, duration)
  # These should be very close to each other
  print "nWeights:n", weights
  print "nBins:n", bins
  sdev_tolerance = 10 # very unlikely to be exceeded
  tolerance = float(sdev_tolerance) / numpy.sqrt(samples_per_item)
  print "nTolerance:n", tolerance
  error = numpy.abs(weights - bins)
  print "nError:n", error
  assert (error < tolerance).all()

#@test
def test_DynamicProbDistribution():
  # First test that the initial distribution generates valid samples.
  weights = [2,5,4, 8,3,6, 6,1,3, 4,7,9]
  dpd = DynamicProbDistribution(weights)
  validate_distribution_results(dpd, weights)
  # Now test that we can change the weights and still sample from the 
  # distribution.
  print "nChanging weights..."
  dpd.set_weight(4, 10)
  weights[4] = 10
  dpd.set_weight(9, 2)
  weights[9] = 2
  dpd.set_weight(5, 4)
  weights[5] = 4
  dpd.set_weight(11, 3)
  weights[11] = 3
  validate_distribution_results(dpd, weights)
  print "nTest passed"

if __name__ == '__main__':
  test_DynamicProbDistribution()

我已经实现了一个与Ken的代码相关的版本,但对于最坏的情况O(logn)操作,它与红/黑树相平衡。这是可用的weightedDict.py在:https://github.com/google/weighted-dict

(我本想把这作为对Ken回答的评论,但我没有这样做的名声!)

最新更新