是否可以将 `pathlib.Path` 对象与 `spark.read.parquet` 以及其他 `pyspark.sql.DataFrameReader` 方法一起使用?
默认情况下不起作用:
>>> from pathlib import Path
>>> basedir = Path("/data")
>>> spark.read.parquet(basedir / "name.parquet")
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-5-cec8ced1bc5d> in <module>
----> 1 spark.read.parquet(basedir / "name.parquet")
<... a long traceback ...>
/opt/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_command_part(parameter, python_proxy_pool)
296 command_part += ";" + interface
297 else:
--> 298 command_part = REFERENCE_TYPE + parameter._get_object_id()
299
300 command_part += "n"
AttributeError: 'PosixPath' object has no attribute '_get_object_id'
我尝试编写了一个 py4j 输入转换器(input converter):
class PathConverter(object):
    """py4j input converter that tries to send ``pathlib.Path`` arguments as Java strings.

    NOTE(review): this is the attempted approach and it does not work as
    written. py4j hands ``java.lang.String`` values back to Python as plain
    ``str`` objects, so ``convert`` returns a ``str``; py4j then calls
    ``_detach()`` on every converted argument, which ``str`` does not have.
    """

    def can_convert(self, object):
        """Return True when ``object`` is a ``Path`` instance."""
        return isinstance(object, Path)

    def convert(self, object, gateway_client):
        """Build a ``java.lang.String`` from ``str(object)`` through the gateway."""
        JavaString = JavaClass("java.lang.String", gateway_client)
        # The value returned here ends up being a Python ``str``, not a
        # Java object reference (see the class-level note).
        return JavaString(str(object))
register_input_converter(PathConverter())
但我似乎误解了与字符串转换相关的某些概念/细节,因为在 py4j 中 `jvm.java.lang.String("string")` 返回的是 Python 的 `str` 对象,而不是 Java 对象的引用:
>>> spark.read.parquet(basedir / "name.parquet")
<... a long traceback ...>
/opt/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1306
1307 for temp_arg in temp_args:
-> 1308 temp_arg._detach()
AttributeError: 'str' object has no attribute '_detach'
我现在只有一个丑陋的解决方案:
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index fa3e829a88..7441a8ba8c 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -298,7 +298,7 @@ class DataFrameReader(OptionUtils):
modifiedAfter=modifiedAfter, datetimeRebaseMode=datetimeRebaseMode,
int96RebaseMode=int96RebaseMode)
- return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
+ return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths, converter=str)))
def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None,
recursiveFileLookup=None, modifiedBefore=None,
此外,通过查看 `readwriter.py` 的源代码可以发现,对该模块所使用的 `_to_seq` 进行 monkeypatch 是安全的:
from pyspark.sql import readwriter
def converter(x):
    """Return ``str(x)`` for ``PurePath`` instances; return anything else unchanged.

    Non-path arguments (for example plain strings) must pass through
    untouched, hence the explicit type check instead of calling ``str``
    on every value.
    """
    if isinstance(x, PurePath):
        return str(x)
    return x
readwriter._to_seq = partial(readwriter._to_seq, converter=converter)
或者,更正确、更完整的解决方法可能是直接对读取器/写入器的方法进行 monkeypatch:
# Replacement for ``DataFrameWriter.parquet`` that converts ``path`` with
# ``str()`` before delegating, so ``pathlib.Path`` objects can be passed.
# Comments are used instead of a docstring because ``wraps`` copies the
# wrapped method's ``__doc__`` onto this function.
@wraps(readwriter.DataFrameWriter.parquet)
def parquet(self, path, mode=None, partitionBy=None, compression=None):
    # ``wraps`` stores the original method on ``parquet.__wrapped__``; the
    # lookup goes through the module-level name ``parquet``, so that name
    # must keep pointing at this wrapper.
    return parquet.__wrapped__(self, str(path), mode=mode,
                               partitionBy=partitionBy,
                               compression=compression)
readwriter.DataFrameWriter.parquet = parquet