Python：在 TensorFlow 数据集的 filter 中实现 set.issubset() 判断



我有一个tensorflow数据集:

def fake_sequence():
    """Return a random 100-character string over "A".."D" with rare substitutions.

    Every position is drawn from "A".."D"; with probability 0.001 it is instead
    taken from an independently drawn letter in "E".."H". Draws come from
    NumPy's global random state.
    """
    def draw(alphabet):
        # 100 independent uniform picks from `alphabet`.
        return [np.random.choice(alphabet) for _ in range(100)]

    base = draw(["A", "B", "C", "D"])
    substitutes = draw(["E", "F", "G", "H"])
    keep_base = np.random.choice(a=[True, False], size=100, p=[0.999, 0.001])
    return "".join(np.where(keep_base, base, substitutes))

# 100 synthetic sequences; each one becomes a scalar-string dataset element.
seqs = list(map(lambda _: fake_sequence(), range(100)))
ds = tf.data.Dataset.from_tensor_slices(seqs)

我想用下面的Python函数过滤它:

def python_filter(x):
    """Return True if every element of ``x`` is one of "A", "B", "C", "D".

    Pure-Python reference predicate; an empty ``x`` yields True.
    """
    allowed = {"A", "B", "C", "D"}
    return set(x) <= allowed

不幸的是，用 @tf.function 装饰它并不起作用。有哪位高手能帮帮我吗？下面是我目前写的代码：

def filter(x):
    """Return a scalar bool tensor: True iff every character of ``x`` is A, B, C or D.

    Graph-compatible equivalent of ``set(x).issubset({"A", "B", "C", "D"})``,
    suitable as a ``Dataset.filter`` predicate (which requires a scalar bool,
    not the tensor of unique characters). An empty string yields True, like
    Python's ``set().issubset(...)``.
    """
    x = tf.strings.bytes_split(x)
    x = tf.unique(x)[0]
    allowed = tf.constant(["A", "B", "C", "D"])
    # Compare each distinct character with every allowed letter
    # (shape [num_chars, 4]); the sequence passes when every row has a match.
    matches = tf.equal(tf.expand_dims(x, 1), tf.expand_dims(allowed, 0))
    return tf.reduce_all(tf.reduce_any(matches, axis=1))

ds = ds.filter(filter)

您可以结合使用 tf.lookup.StaticHashTable 和 tf.cond 来解决这个问题：

import tensorflow as tf
import numpy as np
def fake_sequence():
    """Build a 100-letter sequence over "A".."D" with occasional mutations.

    All 100 base letters are drawn first, then 100 candidate mutations from
    "E".."H", then a keep-mask (True with probability 0.999). Positions where
    the mask is False take the mutated letter. Uses NumPy's global RNG.
    """
    length = 100
    seq = []
    for _ in range(length):
        seq.append(np.random.choice(["A", "B", "C", "D"]))
    mutate = []
    for _ in range(length):
        mutate.append(np.random.choice(["E", "F", "G", "H"]))
    mask = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    chars = np.where(mask, seq, mutate)
    return "".join(chars)

# Dataset of 100 random sequences, one scalar string per element.
seqs = list(map(lambda _: fake_sequence(), range(100)))
ds = tf.data.Dataset.from_tensor_slices(seqs)

# Integer id for each letter fake_sequence can emit: "A" -> 1, ..., "H" -> 8.
keys_tensor = tf.constant(list("ABCDEFGH"))
vals_tensor = tf.range(1, 9)
# Characters missing from the table look up to -1.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(keys=keys_tensor, values=vals_tensor),
    default_value=-1,
)
def filter(x):
    """Return a scalar bool tensor: True iff every character of ``x`` is A, B, C or D.

    TensorFlow equivalent of ``set(x).issubset({"A", "B", "C", "D"})``; ``x`` is
    one element of ``ds`` (a scalar string). An empty string yields True, like
    Python's ``set().issubset(...)``. Unlike comparing sorted id vectors with
    ``tf.equal``, this accepts proper subsets (e.g. "ABAB") instead of failing
    on mismatched shapes.
    """
    subset = tf.constant(["A", "B", "C", "D"])
    # Distinct characters of the sequence, as a 1-D string tensor.
    chars = tf.unique(tf.strings.bytes_split(x))[0]
    # Map to integer ids via the module-level `table`; characters not in the
    # table become -1, which never equals an id of `subset`.
    x_ids, subset_ids = table.lookup(chars), table.lookup(subset)
    return tf.cond(
        # More distinct characters than allowed letters can never be a subset.
        tf.shape(x_ids)[0] > tf.shape(subset_ids)[0],
        lambda: tf.constant(False),
        # Otherwise every character id must equal at least one allowed id.
        lambda: tf.reduce_all(
            tf.reduce_any(
                tf.equal(tf.expand_dims(x_ids, 1), tf.expand_dims(subset_ids, 0)),
                axis=1,
            )
        ),
    )
# Turn each sequence into its boolean "only A/B/C/D" flag and print the first
# five flags. (ds.filter(filter) would instead keep only matching sequences.)
ds = ds.map(map_func=filter)
for x in ds.take(count=5):
    print(x)

tf.lookup.StaticHashTable只是将所有字母映射为整数值,这样比较起来更容易。

最新更新