Using pathlib paths with spark.read.parquet



Is it possible to use pathlib.Path objects with spark.read.parquet and the other pyspark.sql.DataFrameReader methods?

It does not work out of the box:

>>> from pathlib import Path
>>> basedir = Path("/data")
>>> spark.read.parquet(basedir / "name.parquet")
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-cec8ced1bc5d> in <module>
----> 1 spark.read.parquet(basedir / "name.parquet")
<... a long traceback ...>
/opt/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_command_part(parameter, python_proxy_pool)
296             command_part += ";" + interface
297     else:
--> 298         command_part = REFERENCE_TYPE + parameter._get_object_id()
299 
300         command_part += "\n"
AttributeError: 'PosixPath' object has no attribute '_get_object_id'
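The immediate workaround at each call site is to convert the path to a plain string before handing it to Spark; either str or os.fspath works. A quick sketch of the conversion itself (no SparkSession needed to see it):

```python
import os
from pathlib import Path

basedir = Path("/data")
# os.fspath is equivalent to str() for pathlib paths,
# but also rejects objects that are not path-like
path_str = os.fspath(basedir / "name.parquet")
print(type(path_str), path_str)
# spark.read.parquet(path_str) would now receive an ordinary str
```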

I tried writing a py4j type converter:

from pathlib import Path
from py4j.java_gateway import JavaClass
from py4j.protocol import register_input_converter

class PathConverter(object):
    def can_convert(self, object):
        return isinstance(object, Path)

    def convert(self, object, gateway_client):
        JavaString = JavaClass("java.lang.String", gateway_client)
        return JavaString(str(object))

register_input_converter(PathConverter())

But I seem to misunderstand some concept or detail of string conversion: jvm.java.lang.String("string") in py4j returns a Python str object, apparently because py4j automatically converts java.lang.String values coming back from the JVM into Python strings, so the converter does not yield a JVM object reference:

>>> spark.read.parquet(basedir / "name.parquet")
<... a long traceback ...>
/opt/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1306 
1307         for temp_arg in temp_args:
-> 1308             temp_arg._detach()
AttributeError: 'str' object has no attribute '_detach'

For now the only solution I have is an ugly one:

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index fa3e829a88..7441a8ba8c 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -298,7 +298,7 @@ class DataFrameReader(OptionUtils):
modifiedAfter=modifiedAfter, datetimeRebaseMode=datetimeRebaseMode,
int96RebaseMode=int96RebaseMode)

-        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
+        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths, converter=str)))

def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None,
recursiveFileLookup=None, modifiedBefore=None,

Also, judging by the readwriter.py source code, it looks safe to monkeypatch its version of _to_seq:

from functools import partial
from pathlib import PurePath

from pyspark.sql import readwriter

def converter(x):
    if isinstance(x, PurePath):
        return str(x)
    return x

readwriter._to_seq = partial(readwriter._to_seq, converter=converter)
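The converter function itself is plain Python, so it can be sanity-checked without a SparkSession (a minimal check with no PySpark imports involved):

```python
from pathlib import PurePath, PurePosixPath

def converter(x):
    # PurePath is the common base of PosixPath, WindowsPath, etc.
    if isinstance(x, PurePath):
        return str(x)
    return x

print(converter(PurePosixPath("/data/name.parquet")))  # now a str
print(converter("already-a-str"))  # strings pass through untouched
print(converter(42))               # so do arbitrary non-path values
```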

Or, a more correct and complete workaround might be to monkeypatch the reader/writer methods directly:

from functools import wraps

from pyspark.sql import readwriter

@wraps(readwriter.DataFrameWriter.parquet)
def parquet(self, path, mode=None, partitionBy=None, compression=None):
    return parquet.__wrapped__(self, str(path), mode=mode,
                               partitionBy=partitionBy,
                               compression=compression)

readwriter.DataFrameWriter.parquet = parquet
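The per-method wrappers could also be generalized. Below is a hypothetical accept_paths decorator (not part of PySpark) that stringifies every os.PathLike argument before delegating; applying it to the DataFrameReader/DataFrameWriter methods would cover them all with one helper. The fake_parquet function is only a stand-in to show the effect:

```python
import os
from functools import wraps
from pathlib import Path

def accept_paths(func):
    """Convert any os.PathLike positional or keyword argument to str
    before calling the wrapped function."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        args = [os.fspath(a) if isinstance(a, os.PathLike) else a
                for a in args]
        kwargs = {k: os.fspath(v) if isinstance(v, os.PathLike) else v
                  for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper

# Stand-in for e.g. readwriter.DataFrameWriter.parquet:
@accept_paths
def fake_parquet(path, mode=None):
    return path, mode

print(fake_parquet(Path("/data") / "name.parquet", mode="overwrite"))
```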
