重要性加权自动编码器的性能比VAE差



我一直在caltech剪影数据集上实现VAE和IWAE模型,但遇到了一个问题,即VAE以适度的优势优于IWAE(VAE测试LL~120,IWAE测试~133!(。根据这里产生的理论和实验,我认为情况不应该是这样。

我希望有人能在我的实施过程中发现一些问题,这就是为什么会出现这种情况。

我用来近似qp的网络与上面论文附录中详细描述的网络相同。模型的计算部分如下:

data_k_vec = data.repeat_interleave(K,0) # Generate K samples (in my case K=50 is producing this behavior)
mu, log_std = model.encode(data_k_vec)
z = model.reparameterize(mu, log_std) # z = mu + torch.exp(log_std)*epsilon (epsilon ~ N(0,1))
decoded = model.decode(z) # this is the sigmoid output of the model
log_prior_z = torch.sum(-0.5 * z ** 2, 1)-.5*z.shape[1]*T.log(torch.tensor(2*np.pi))
log_q_z = compute_log_probability_gaussian(z, mu, log_std) # Definitions below
log_p_x = compute_log_probability_bernoulli(decoded,data_k_vec) 
if model_type == 'iwae':
log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, K)
elif model_type =='vae':
log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, 1)*1/K
log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)
ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)
loss = -torch.sum(ws_sum_per_datapoint) # value of loss that gets returned to training function. loss.backward() will get called on this value

以下是似然函数。为了在训练时不受伤,我不得不对伯努利LL大惊小怪

def compute_log_probability_gaussian(obs, mu, logstd, axis=1):
return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi)) 
def compute_log_probability_bernoulli(theta, obs, axis=1): # Add 1e-18 to avoid nan appearances in training
return torch.sum(obs*torch.log(theta+1e-18) + (1-obs)*torch.log(1-theta+1e-18), axis)

在这段代码中,使用了一个"快捷方式",即在model_type=='iwae'的情况下,每行K=50个样本计算逐行重要性权重,而在model_type=='vae'的情况下则为每行剩余的单个值计算重要性权重,因此最终只计算权重1。也许这就是问题所在?

任何和所有的帮助都是巨大的——我原以为解决nan问题会让我永远摆脱困境,但现在我遇到了这个新问题。

编辑:应补充的是,培训计划与上述文件中的培训计划相同。也就是说,对于i=0....7轮中的每一轮,以1e-4 * 10**(-i/7)的学习率训练2**i历元

K-样本重要性加权ELBO是

$$\textrm{IW-ELBO}(x,K(=\log\sum_{K=1}^K\frac{p(x\vert z_K(p(z_K

对于IWAE,每个数据点x都有K个样本,因此您希望通过摊销推理网络获得相同的潜在统计数据mu_z, Sigma_z,但每个x要采样多个zK次。

因此,计算data_k_vec = data.repeat_interleave(K,0)的正向通过在计算上是浪费的,你应该为每个原始数据点计算一次正向通过,然后重复推理网络输出的统计数据进行采样:

mu = torch.repeat_interleave(mu,K,0)
log_std = torch.repeat_interleave(log_std,K,0)

然后对CCD_ 15进行采样。现在重复您的数据点data_k_vec = data.repeat_interleave(K,0),并使用得到的张量来有效地评估每个重要样本z_k的条件p(x |z_k)

请注意,在计算IW-ELBO以获得数值稳定性时,您可能还需要使用logsumexp运算。我不太清楚你帖子中的log_w_matrix计算是怎么回事,但我会这么做:

log_pz = ...
log_qzCx = ....
log_pxCz = ...
log_iw = log_pxCz + log_pz - log_qzCx
log_iw = log_iw.reshape(-1, K)
iwelbo = torch.logsumexp(log_iw, dim=1) - np.log(K)

编辑:事实上,经过思考并使用得分函数恒等式,你可以将IWAE梯度解释为标准单样本梯度的重要性加权估计,因此OP中计算重要性权重的方法是等效的(如果有点浪费的话(,前提是你在归一化的重要性权重周围放置stop_gradient算子,您称之为w_norm。所以我的主要问题是缺少这个stop_gradient算子。

最新更新