我正在尝试构建一个MDN来学习P(y | x),其中y和x都具有维度D,其中K分量具有完全(非对角线)协方差。从 NN 隐藏层的输出中,我需要构造分量均值、权重和协方差。对于协方差,我想要一组较低的三角矩阵(即协方差的乔列斯基因子),即 [K, D, D] 张量,所以我可以利用这样一个事实,即对于正定矩阵,你只需要携带矩阵的一个三角形。
目前,参数化均值(locs)、权重(对数)和协方差(尺度)的 NN 如下所示:
def neural_network(X):
# 2 hidden layers with 15 hidden units
net = tf.layers.dense(X, 15, activation=tf.nn.relu)
net = tf.layers.dense(net, 15, activation=tf.nn.relu)
locs = tf.reshape(tf.layers.dense(net, K*D, activation=None), shape=(K, D))
logits = tf.layers.dense(net, K, activation=None)
scales = # some function of tf.layers.dense(net, K*D*(D+1)/2, activation=None) ?
return locs, scales, logits
问题是,对于尺度,将tf.layers.dense(net, K*D*(D-1)/2, activation=None)
转换为K DxD下三角矩阵张量的最有效方法是什么(对角线元素成指数以确保正定性)?
TL;DR:使用tf.contrib.distributions.fill_triangular
假设 X 是D
维的K
个元素的张量,让我们将其定义为占位符。
# batch of D-dimensional inputs
X = tf.placeholder(tf.float64, [None, D])
神经网络的定义就像OP一样。
# 2 hidden layers with 15 hidden units
net = tf.layers.dense(X, 15, activation=tf.nn.relu)
net = tf.layers.dense(net, 15, activation=tf.nn.relu)
多元高斯的均值只是先前隐藏层的线性密集层。输出的形状为(None, D)
,因此无需将尺寸乘以K
和整形。
# Parametrisation of the means
locs = tf.layers.dense(net, D, activation=None)
接下来,我们定义下三角形协方差矩阵。关键是在另一个线性密集层的输出上使用tf.contrib.distributions.fill_triangular。
# Parametrisation of the lower-triangular covariance matrix
covariance_weights = tf.layers.dense(net, D*(D+1)/2, activation=None)
lower_triangle = tf.contrib.distributions.fill_triangular(covariance_weights)
最后一件事:我们需要确保协方差矩阵是正半定的。通过将 softplus 激活函数应用于对角线元素,可以轻松实现。
# Diagonal elements must be positive
diag = tf.matrix_diag_part(lower_triangle)
diag_positive = tf.layers.dense(diag, D, activation=tf.nn.softplus)
covariance_matrix = lower_triangle - tf.matrix_diag(diag) + tf.matrix_diag(diag_positive)
就是这样,我们使用神经网络对多元正态分布进行了参数化。
奖励:可训练的多元正态分布
Tensorflow 概率包具有可训练的多元正态分布,具有较低的三角协方差矩阵:tfp.trainable_distributions.multivariate_normal_tril
它可以按如下方式使用:
mvn = tfp.trainable_distributions.multivariate_normal_tril(net, D)
它输出一个多元正态三角分布,方法与tfp.distributions.multivariateNormalTriL相同,包括mean
、covariance
、sample
等。
我建议使用它而不是构建自己的。