我正在尝试使用此tfx.components.Transform函数获取"tf.transform编码字典"。
transform = Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=os.path.abspath(_taxi_transform_module_file),
instance_name="taxi")
context.run(transform)
我需要这样的字典:"您加载的数据的字典({feature_name:feature_value}(。
如上所述的转换给了我一个 TfRecord 文件。如何正确解码?
任何帮助将不胜感激。
import tensorflow_transform as tft
def preprocessing_fn(inputs):
NUMERIC_FEATURE_KEYS = ['PetalLengthCm', 'PetalWidthCm',
'SepalLengthCm', 'SepalWidthCm']
TARGET_FEATURES = "Species"
outputs = inputs.copy()
del outputs['Id']
for key in NUMERIC_FEATURE_KEYS:
outputs[key] = tft.scale_to_0_1(outputs[key])
return outputs
编写这样的模块,我已经为鸢尾花数据集编写了一个模块,对于您的数据集来说很容易理解,您也可以这样做,它将被保存为 tfrecord 数据集