如何在Tensorflow中实现集合查找



在tensorflow数据集的预处理过程中,我需要检查某个值是否包含在不可变的集合中。如果不是,我需要用默认值替换它。本质上,它是关于审查/替换某些异常值

在python中,我会做这样的事情:

def map_id (value):
s = frozenset([1,2,3])
if value in s:
return value
else:
return 0 # default for all outliers

这个map_id函数将像这个一样被调用

def preprocess(item):
return (map_id(item["investment_id"]), item["features"]), item["target"]

preprocess函数将像这个一样被调用

def make_dataset(file_paths, batch_size=4096, mode="train"):
ds = tf.data.TFRecordDataset(file_paths)
ds = ds.map(decode_function)
ds = ds.map(preprocess)
if mode == "train":
ds = ds.shuffle(batch_size * 4)
ds = ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
return ds

如何在Tensorflow 2.x中编写这个map_id函数?

我不确定您的数据是什么样子的,但您应该能够使用一个简单的StaticHashTable作为您用例的Set替代方案,因为它将以图形模式运行:

import tensorflow as tf
data = {
"investment_id": [1, 2, 3, 4, 5], 
"features": [12, 912, 28, 90, 17],
"target": [1, 0, 1, 1, 1]
}
keys_tensor = tf.constant([1, 2, 3])
vals_tensor = tf.constant([1, 2, 3])
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=0)

ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.map(lambda item: (table.lookup(item['investment_id']), item['features'], item['target']))
for d in ds:
print(d)
(<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=12>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=912>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=28>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=90>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=17>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)

最新更新