如果输入的标签超出范围,是否有方法避免tfp.distributions.Categorical.log_prob
引发错误?
我将一批样本传递给log_prob
方法,其中一些样本的值为n_categories + 1
,这是从全零概率分布中采样时得到的回退值。我的probs
批中的一些概率分布都是零**。
dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action
在这种情况下,我不在乎我从log_prob
得到什么值,因为以后我会屏蔽它而不使用它。不确定fallback
值是否可以以某种方式实现。如果没有,是否有任何解决方法可以避免在图形模式下(使用@tf.function
(执行时出现错误?
**这是因为我正在用可变长度的序列批次的RNN进行随机解码,这是一个seq到seq的任务。
如果可以屏蔽log_prob,也可以将问题屏蔽为1/n。注意,使用分类的logits参数化并放弃(可能(上游softmax激活在数值上更稳定。