如何确定TensorFlow embedding_column的适当维度



我是tensorflow的新手,试图理解embedding_column。它需要一个参数dimension,这对我来说不是完全有意义的。

在这个示例中(来自Google教程),dimension = 8

thal = tf.feature_column.categorical_column_with_vocabulary_list(
      'thal', ['fixed', 'normal', 'reversible'])
thal_embedding = tf.feature_column.embedding_column(thal, dimension=8)

我假设是2^3,因为有3种可能,每一种都可以是"on"。或"off".

但是在文档示例中:

video_id = categorical_column_with_identity(
    key='video_id', num_buckets=1000000, default_value=0)
columns = [embedding_column(video_id, 9),...]

我没有跟踪为什么dimension在这里是9。有人能解释一下规则是什么吗?

您提到的三种开-关可能性是One-Hot Encodings,其中值可以是0或1。嵌入是类似的,但有一个关键区别在于它们可以有任何值。这在类别数量大且单热编码不可行的地方很有用。所以我们可以使用比单热编码更低的维数。嵌入维数是一个需要调整的额外参数

示例:对于3个类别

    One-Hot:
    [ [1,0,0], [0,1,0], [0,0,1] ]
    Same with Embeddings and dimension 2
    [ [-0.426, 0.987], [0.657, 0.222], [0.398, -0.725] ]

相关内容

  • 没有找到相关文章

最新更新