使用Spark Dataframe列中的数据作为另一个列表达式的条件或输入



我有一个操作,我想在PySpark 2.0中执行,这很容易作为df.rdd.map执行,但是由于性能原因,我更愿意留在Dataframe执行引擎中,我想找到一种方法,只使用Dataframe操作来实现这一点。

在rdd风格中,操作是这样的:
def precision_formatter(row):
    formatter = "%.{}f".format(row.precision)
    return row + [formatter % row.amount_raw / 10 ** row.precision]
df = df.rdd.map(precision_formatter)

基本上,我有一个列告诉我,对于每一行,我的字符串格式化操作的精度应该是多少,我想有选择地将'amount_raw'列格式化为字符串,这取决于该精度。

我不知道如何使用一个或多个列的内容作为另一个Column操作的输入。我最接近的建议是使用Column.when和一组外部定义的布尔操作,这些操作对应于列中可能的布尔条件/情况集。

在这个特定的情况下,例如,如果您可以获得(或者更好的是,已经拥有)row.precision的所有可能值,那么您可以遍历该集合并对集合中的每个值应用Column.when操作。我相信这个集合可以用df.select('precision').distinct().collect()得到。

因为pyspark.sql.functions.whenColumn.when操作本身返回一个Column对象,你可以遍历集合中的元素(无论它是如何获得的),并继续以编程方式"追加"when操作,直到耗尽集合:

import pyspark.sql.functions as PSF
def format_amounts_with_precision(df, all_precisions_set):
    amt_col = PSF.when(df['precision'] == 0, df['amount_raw'].cast(StringType()))
    for precision in all_precisions_set:
        if precision != 0:  # this is a messy way of having a base case above
            fmt_str = '%.{}f'.format(precision)
            amt_col = amt_col.when(df['precision'] == precision,
                           PSF.format_string(fmt_str, df['amount_raw'] / 10 ** precision)
    return df.withColumn('amount', amt_col)

您可以使用python UDF来完成。它们可以接受尽可能多的输入值(来自一行的列的值)并输出单个输出值。它看起来像这样:

from pyspark.sql import types as T, functions as F
from pyspark.sql.function import udf, col
# Create example data frame
schema = T.StructType([
    T.StructField('precision', T.IntegerType(), False),
    T.StructField('value', T.FloatType(), False)
])
data = [
    (1, 0.123456),
    (2, 0.123456),
    (3, 0.123456)
]
rdd = sc.parallelize(data)
df = sqlContext.createDataFrame(rdd, schema)
# Define UDF and apply it
def format_func(precision, value):
    format_str = "{:." + str(precision) + "f}"
    return format_str.format(value)
format_udf = F.udf(format_func, T.StringType())
new_df = df.withColumn('formatted', format_udf('precision', 'value'))
new_df.show()

另外,如果您希望使用全局值而不是列精度值,那么您可以在像这样调用它时使用lit(..)函数:

new_df = df.withColumn('formatted', format_udf(F.lit(2), 'value'))

相关内容

  • 没有找到相关文章

最新更新