如何"one hot encode"张量流数据集?



此处新建。。。我加载TF数据集如下:

dataset = tf.data.TFRecordDataset(files)
dataset.map(extract_fn)

该数据集包含一个";字符串列";具有一些值,并且我想";一个热的";对它们进行编码。如果我有索引和深度(到目前为止我只有一个String值(,我可以在extract_fn记录中逐个执行此操作。然而,有没有一个TF函数可以为我做到这一点?即

  • 计算不同值的数量
  • 将每个值映射到索引
  • 为其创建一个热编码列

我认为这符合您的要求:

import tensorflow as tf
def one_hot_any(a):
# Save original shape
s = tf.shape(a)
# Find unique values
values, idx = tf.unique(tf.reshape(a, [-1]))
# One-hot encoding
n = tf.size(values)
a_1h_flat = tf.one_hot(idx, n)
# Reshape to original shape
a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
return a_1h, values
# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
print(*sess.run([x_1h, x_vals]), sep='n')

输出:

[[[1. 0. 0. 0.]
[0. 1. 0. 0.]]
[[1. 0. 0. 0.]
[0. 0. 1. 0.]]
[[0. 0. 0. 1.]
[0. 0. 1. 0.]]
[[0. 1. 0. 0.]
[0. 0. 1. 0.]]]
[b'a' b'b' b'd' b'c']

然而,问题是不同的输入会产生不一致的输出,具有不同的值顺序,甚至不同的一个热点深度,所以我不确定它是否真的有用。

相关内容

最新更新