tfp.distributions.Categorical.log_ob在tensorflow图模式下的解决方案/回退值



如果输入的标签超出范围,是否有方法避免tfp.distributions.Categorical.log_prob引发错误?

我将一批样本传递给log_prob方法,其中一些样本的值为n_categories + 1,这是从全零概率分布中采样时得到的回退值。我的probs批中的一些概率分布都是零**。

dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action

在这种情况下,我不在乎我从log_prob得到什么值,因为以后我会屏蔽它而不使用它。不确定fallback值是否可以以某种方式实现。如果没有,是否有任何解决方法可以避免在图形模式下(使用@tf.function(执行时出现错误?

**这是因为我正在用可变长度的序列批次的RNN进行随机解码,这是一个seq到seq的任务。

如果可以屏蔽log_prob,也可以将问题屏蔽为1/n。注意,使用分类的logits参数化并放弃(可能(上游softmax激活在数值上更稳定。

相关内容

  • 没有找到相关文章

最新更新