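Hi, I'm trying to figure out whether I'm making a mistake with TFP shapes or whether this is a TFP bug. I can sample from this simple joint distribution, which uses a normal prior on the mean of a 3-dimensional multivariate normal and 3 draws of a half-Cauchy as the priors on the diagonal of the MVN's covariance.
I'm using tensorflow 2.2.0, and the error occurs on both tensorflow-probability 0.10.1 and the nightly build of tensorflow-probability 0.12.0-dev20200719.
This code should run on its own: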
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),
    tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag(  # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[..., tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])
# These work
joint_model.sample()
joint_model.sample(4)
joint_model.log_prob(joint_model.sample())
# This breaks
joint_model.log_prob(joint_model.sample(4))
Here is the error message:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-41-61bae7ab690a> in <module>
1 joint_model.log_prob(joint_model.sample())
2 # ERROR
----> 3 joint_model.log_prob(joint_model.sample(4))
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution.py in log_prob(self, *args, **kwargs)
479 model_flatten_fn=self._model_flatten,
480 model_unflatten_fn=self._model_unflatten)
--> 481 return self._call_log_prob(value, **unmatched_kwargs)
482
483 # Override the base method to capture *args and **kwargs, so we can
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
944 with self._name_and_control_scope(name, value, kwargs):
945 if hasattr(self, '_log_prob'):
--> 946 return self._log_prob(value, **kwargs)
947 if hasattr(self, '_prob'):
948 return tf.math.log(self._prob(value, **kwargs))
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution.py in _log_prob(self, value)
391 def _log_prob(self, value):
392 xs = self._map_measure_over_dists('log_prob', value)
--> 393 return sum(maybe_check_wont_broadcast(xs, self.validate_args))
394
395 @distribution_util.AppendDocstring(kwargs_dict={
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
982 with ops.name_scope(None, op_name, [x, y]) as name:
983 if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 984 return func(x, y, name=name)
985 elif not isinstance(y, sparse_tensor.SparseTensor):
986 try:
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in _add_dispatch(x, y, name)
1274 return gen_math_ops.add(x, y, name=name)
1275 else:
-> 1276 return gen_math_ops.add_v2(x, y, name=name)
1277
1278
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in add_v2(x, y, name)
478 pass # Add nodes to the TensorFlow graph.
479 except _core._NotOkStatusException as e:
--> 480 _ops.raise_from_not_ok_status(e, name)
481 # Add nodes to the TensorFlow graph.
482 _, _, _op, _outputs = _op_def_library._apply_op_helper(
~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
6651 message = e.message + (" name: " + name if name is not None else "")
6652 # pylint: disable=protected-access
-> 6653 six.raise_from(core._status_to_exception(e.code, message), None)
6654 # pylint: enable=protected-access
6655
~/miniconda3/envs/latent2/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Incompatible shapes: [4] vs. [4,3] [Op:AddV2]
This is driving me crazy, so any help is appreciated.
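Thanks to @Miles Turpin for sharing the reference to the solution. For the benefit of the community, I am posting here (in the answer section) the solution given by jeffpollock9 on GitHub.
The joint distribution mixes up batch_shape and event_shape. The bare HalfCauchy has batch_shape [3], so its log_prob on 4 samples has shape [4, 3], while the other two terms have shape [4]; summing them produces the "Incompatible shapes: [4] vs. [4,3]" error. Wrapping the half-Cauchy distribution in tfd.Independent moves that dimension into event_shape and resolves the problem.
See the working code below: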
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),
    tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
        reinterpreted_batch_ndims=1),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag(  # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[..., tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])
print(joint_model)
print(joint_model.log_prob(joint_model.sample(4)))
Output:
tfp.distributions.JointDistributionSequential("JointDistributionSequential", batch_shape=[[], [], []], event_shape=[[], [3], [3]], dtype=[float32, float32, float32])
tf.Tensor([-14.330933 -16.854149 -15.07704 -6.9233823], shape=(4,), dtype=float32)
For more background, please refer to the material on reasoning about shapes and probability distributions.