Spark DataFrame to TensorFlow Dataset (tf.data API)



I am trying to convert a Spark DataFrame into TFRecords and then read them back from TensorFlow as a dataset to feed input to my model. It is not working.

Here is what I have tried:

1) Get a Spark session with the spark-tensorflow-connector jar:

from pyspark import SparkConf
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(conf=SparkConf().set("spark.jars", "path/to/spark-tensorflow-connector_2.11-1.6.0.jar")).getOrCreate()

2) Save the DataFrame as TFRecords (I use a toy dataset as an example):

df = spark.createDataFrame([(1, 120), (2, 130), (2, 140)], ['A', 'B'])
path='path/example.tfrecord'
df.write.format("tfrecords").mode("overwrite").option("recordType", "Example").save(path)

3) Load the TFRecord files into the tf.data API (for simplicity, I just take 'A' as a single feature):

import tensorflow as tf

path2 = "path/example.tfrecord/*"
dataset = tf.data.TFRecordDataset(tf.compat.v1.gfile.Glob(path2))

def parse_func(buff):
    features = {'A': tf.compat.v1.FixedLenFeature(shape=[5], dtype=tf.int64)}
    tensor_dict = tf.compat.v1.parse_single_example(buff, features)
    return tensor_dict['A']

train_dataset = dataset.map(parse_func).batch(1)

But when I try to print from the dataset iterator:

for x in train_dataset:
    print(x)

I get the following error:

2020-05-21 06:43:53.579843: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at iterator_ops.cc:941 : Data loss: corrupted record at 0
Traceback (most recent call last):
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1897, in execution_mode
2020-05-21 06:43:53.580090: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A.  Can't parse serialized Example.
2020-05-21 06:43:53.580567: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A.  Can't parse serialized Example.
    yield
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 659, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 2479, in iterator_get_next_sync
    _ops.raise_from_not_ok_status(e, name)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 6606, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNextSync]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/snap/pycharm-community/194/plugins/python-ce/helpers/pycharm/_jb_unittest_runner.py", line 35, in <module>
    sys.exit(main(argv=args, module=None, testRunner=unittestpy.TeamcityTestRunner, buffer=not JB_DISABLE_BUFFERING))
  File "/usr/lib/python3.6/unittest/main.py", line 94, in __init__
    self.parseArgs(argv)
  File "/usr/lib/python3.6/unittest/main.py", line 141, in parseArgs
    self.createTests()
  File "/usr/lib/python3.6/unittest/main.py", line 148, in createTests
    self.module)
  File "/usr/lib/python3.6/unittest/loader.py", line 219, in loadTestsFromNames
    suites = [self.loadTestsFromName(name, module) for name in names]
  File "/usr/lib/python3.6/unittest/loader.py", line 219, in <listcomp>
    suites = [self.loadTestsFromName(name, module) for name in names]
  File "/usr/lib/python3.6/unittest/loader.py", line 204, in loadTestsFromName
    test = obj()
  File "/home/patrizio/PycharmProjects/pyspark-config/tests/python/output/test_output.py", line 75, in test_TFRecord_new
    for x in train_dataset:
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 630, in __next__
    return self.next()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 674, in next
    return self._next_internal()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 665, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "/usr/lib/python3.6/contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1900, in execution_mode
    executor_new.wait()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/executor.py", line 67, in wait
    pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0

Does anyone know how to deal with this?

Many thanks in advance.

I hope this is still relevant.

Your glob expression is incorrect. Spark creates a _SUCCESS file when it saves the examples as TFRecords, and your bare /* pattern matches it as well, so TFRecordDataset trips over a non-TFRecord file. Include the extension in the pattern:

path2 = "path/example.tfrecord/*.tfrecord"

You can also check the list of files that Python is going to read by simply evaluating:

tf.io.gfile.glob(path2)
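
For context, Spark writes the output as a directory of part files plus bookkeeping files. A quick check of both patterns might look something like this (the part-file names below are illustrative; they depend on the number of partitions):

tf.io.gfile.glob("path/example.tfrecord/*")
# e.g. ['path/example.tfrecord/_SUCCESS', 'path/example.tfrecord/part-r-00000.tfrecord']
tf.io.gfile.glob("path/example.tfrecord/*.tfrecord")
# e.g. ['path/example.tfrecord/part-r-00000.tfrecord']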

I would use this API rather than the old compat.v1 one.
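
For reference, the tf.io equivalents of the compat.v1 symbols used in the question are:

tf.compat.v1.gfile.Glob            ->  tf.io.gfile.glob
tf.compat.v1.FixedLenFeature       ->  tf.io.FixedLenFeature
tf.compat.v1.parse_single_example  ->  tf.io.parse_single_example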

The shape passed to tf.io.FixedLenFeature is also wrong. Each value is a scalar, not a vector of length 5. The correct shape is simply []:

def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_single_example(buff, features)
    return tensor_dict

train_dataset = dataset.map(parse_func).batch(3)
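
With the scalar shape, iteration works. For the three-row example above, the printed batch should look roughly like this (hypothetical output; the row order depends on how Spark partitioned the write):

for x in train_dataset:
    print(x)
# e.g. {'A': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 2])>}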

If you really want to get fancy, tf.io.parse_example is even better because it performs vectorized parsing. However, you then need to batch before parsing:

def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_example(buff, features)
    return tensor_dict

train_dataset = dataset.batch(3).map(parse_func)

One might see performance advantages by batching Example protos with parse_example instead of using this function directly. (source)
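
Putting the two fixes together, here is a minimal end-to-end sketch (it assumes, as above, that the connector wrote part files with a .tfrecord extension):

import tensorflow as tf

# Match only the TFRecord part files; this skips Spark's _SUCCESS marker.
filenames = tf.io.gfile.glob("path/example.tfrecord/*.tfrecord")
dataset = tf.data.TFRecordDataset(filenames)

def parse_func(buff):
    # Each stored value is a scalar, hence shape=[].
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    return tf.io.parse_example(buff, features)

# Batch first so that parse_example parses a whole batch in one vectorized call.
train_dataset = dataset.batch(3).map(parse_func)

for x in train_dataset:
    print(x)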