使用 Keras 预测网格结构中的所有可能轨迹



我正在尝试预测2D坐标序列。但我不仅想要最可能的未来路径,还希望所有最可能的路径在网格地图中可视化它。 为此,我有由 40000 个序列组成的编辑数据。每个序列由 10 个 2D 坐标对作为输入和 6 个 2D 坐标对作为标签组成。 所有坐标都在固定值范围内。 预测所有可能路径的第一步是什么?为了获得所有可能的路径,我最终必须应用一个softmax,网格中的每个单元格都是一个类,对吗?但是如何处理数据以反映这种网格状结构呢?有什么想法吗?

softmax激活恐怕不会起作用;如果你有无限数量的组合,甚至是有限数量的组合,这些组合还没有出现在你的数据中,就没有办法把它变成一个多类分类问题(或者如果你这样做,你会失去普遍性)。

我能想到的唯一方法是使用变分编码的循环模型。首先,您有很多带注释的数据,这是个好消息;一个由序列X(10,2,)馈送的循环网络肯定能够预测序列Y(6,2,)。但是,由于您不仅想要一个序列,而且想要所有可能的序列,因此这是不够的。你在这里的隐含假设是,在你的序列后面隐藏着一些概率空间,这会影响它们随着时间的推移如何发挥作用;因此,要正确建模序列,您需要对该潜在概率空间进行建模。变分自动编码器(VAE)就是这样做的;它学习潜在空间,因此在推理过程中,输出预测取决于对该潜在空间的采样。然后,对同一输入的多个预测可以产生不同的输出,这意味着您最终可以对预测进行采样,以根据经验近似潜在输出的分布。

不幸的是,VAEs无法在堆栈溢出的一个段落中真正解释,即使他们可以,我也不会是最有资格尝试它的人。尝试在网上搜索LSTM-VAE,并耐心地武装自己;您可能需要做一些学习,但这绝对值得。研究一下 Pyro 或 Edward 也可能是个好主意,它们是 python 的概率网络库,比 Keras 更适合手头的任务。

相关内容

最新更新