我为要读取的每个CSV文件都有一组自定义模式。我希望能够检测到任何额外的和丢失的列。
我写了一个脚本来比较数据帧中的列和列表中的列。然而,我没有任何必要的列列表,但我有自定义模式:
customSchema = StructType([
StructField('foo', StringType(), False),
StructField('bar', StringType(), True),
])
如何使用此自定义架构来获取所需列的列表?或者有更好的方法吗?比如如果架构不匹配,强制读取CSV返回错误?
现在我有这样的东西:
df = spark.read.csv(path, header=True, schema=customSchema)
# print with fake detect_missmatch function for the purpose of this example
print(detect_missmatch(df.columns, hardcoded_list_of_required_columns))
我希望避免有这个hardcoded_list_of_required_columns
列表,或者至少能够从我之前定义的自定义模式中生成它。
也许有更好的方法,但这应该足够
import json
required_columns = [c for c in json.loads(df.schema.json())['fields'] if not c['nullable']]