使用pyspark根据自定义模式验证csv读取列



我为要读取的每个CSV文件都有一组自定义模式。我希望能够检测到任何额外的和丢失的列。

我写了一个脚本来比较数据帧中的列和列表中的列。然而,我没有任何必要的列列表,但我有自定义模式:

customSchema = StructType([
StructField('foo', StringType(), False),
StructField('bar', StringType(), True),
])

如何使用此自定义架构来获取所需列的列表?或者有更好的方法吗?比如如果架构不匹配,强制读取CSV返回错误?

现在我有这样的东西:

df = spark.read.csv(path, header=True, schema=customSchema)
# print with fake detect_missmatch function for the purpose of this example
print(detect_missmatch(df.columns, hardcoded_list_of_required_columns))

我希望避免有这个hardcoded_list_of_required_columns列表,或者至少能够从我之前定义的自定义模式中生成它。

也许有更好的方法,但这应该足够

import json
required_columns = [c for c in json.loads(df.schema.json())['fields'] if not c['nullable']]

最新更新