我是tensorflow
的新手,试图理解embedding_column
。它需要一个参数dimension
,这对我来说不是完全有意义的。
在这个示例中(来自Google教程),dimension
= 8
thal = tf.feature_column.categorical_column_with_vocabulary_list(
'thal', ['fixed', 'normal', 'reversible'])
thal_embedding = tf.feature_column.embedding_column(thal, dimension=8)
我假设是2^3,因为有3种可能,每一种都可以是"on"。或"off".
但是在文档示例中:
video_id = categorical_column_with_identity(
key='video_id', num_buckets=1000000, default_value=0)
columns = [embedding_column(video_id, 9),...]
我没有跟踪为什么dimension
在这里是9。有人能解释一下规则是什么吗?
您提到的三种开-关可能性是One-Hot Encodings,其中值可以是0或1。嵌入是类似的,但有一个关键区别在于它们可以有任何值。这在类别数量大且单热编码不可行的地方很有用。所以我们可以使用比单热编码更低的维数。嵌入维数是一个需要调整的额外参数
示例:对于3个类别
One-Hot:
[ [1,0,0], [0,1,0], [0,0,1] ]
Same with Embeddings and dimension 2
[ [-0.426, 0.987], [0.657, 0.222], [0.398, -0.725] ]