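One column of my DataFrame looks like the sample below. Each row holds a list of floats. I am trying to normalize each list by subtracting its minimum and dividing by its maximum.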
X_Item_No
[0, 0, 0.09, 0.01, 0.013, 0.016, 0.018, 0.021]
[0, 0, 0.04, 0.31, 0.313, 0.216, 0.618, 0.028]
The code I tried:
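import numpy as np
import pyspark.sql.functions as f
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType
array_min = udf(lambda x: float(np.min(x)), FloatType())
array_max = udf(lambda x: float(np.max(x)), FloatType())
# Subtract the minimum, then divide by the maximum, for each element
def test(cols, min, max):
    return [(ele - min)/max for ele in cols]
df = (df.withColumn('X_col_min', array_min('X_Item_No'))
        .withColumn('X_col_max', array_max('X_Item_No')))
#display(df)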
When I apply the UDFs above and display the DataFrame, I get the error below. The traceback is a pile of Scala errors that does not help me debug.
PythonException: 'TypeError: float() argument must be a string or a number, not 'NoneType''
I then apply the test function to the array together with the minimum and maximum columns:
df2 = df.withColumn('X_Item_No_new', f.udf(test, ArrayType(FloatType()))(df['X_Item_No'], df['X_col_min'], df['X_col_max']))
The column X_Item_No appears to contain null values. np.min([None]) returns None, and passing None to float() raises exactly this exception.
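If you would rather keep a Python UDF, a null-safe version avoids the exception. The sketch below is a minimal illustration, not the original poster's code: safe_min and safe_max are hypothetical helper names, and they use Python's built-in min and max in place of np.min and np.max. They return None for a null array, or for an array with no non-null elements, so float() is never called on None.

```python
def safe_min(values):
    """Smallest non-null element as a float, or None if there is none."""
    if values is None:  # the whole array cell is null
        return None
    present = [v for v in values if v is not None]  # drop null elements
    return float(min(present)) if present else None


def safe_max(values):
    """Largest non-null element as a float, or None if there is none."""
    if values is None:
        return None
    present = [v for v in values if v is not None]
    return float(max(present)) if present else None


print(safe_min([0, 0, 0.09, 0.01]))
print(safe_max([0.04, None, 0.618]))
print(safe_min(None))
```

To use them in Spark, they could be wrapped the same way as in the question, for example udf(safe_min, FloatType()). A Python UDF that returns None produces a null in the resulting column.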
To get the minimum and maximum of an array, it is better to use Spark's built-in array_min and array_max functions than a UDF.
import pyspark.sql.functions as F
df = spark.createDataFrame([[[0., 0., 0.09, 0.01, 0.013, 0.016, 0.018, 0.021]],
                            [[0., 0., 0.04, 0.31, 0.313, 0.216, 0.618, 0.028]]],
                           schema="X_Item_No array<float>")
df = (df
      .withColumn('X_col_min', F.array_min('X_Item_No'))
      .withColumn('X_col_max', F.array_max('X_Item_No')))
df.show(truncate=False)
+--------------------------------------------------+---------+---------+
|X_Item_No                                         |X_col_min|X_col_max|
+--------------------------------------------------+---------+---------+
|[0.0, 0.0, 0.09, 0.01, 0.013, 0.016, 0.018, 0.021]|0.0      |0.09     |
|[0.0, 0.0, 0.04, 0.31, 0.313, 0.216, 0.618, 0.028]|0.0      |0.618    |
+--------------------------------------------------+---------+---------+
If you are on Spark 2.4 or later, you can use the SQL TRANSFORM function to apply an expression to each value in the array.
df = df.withColumn('X_Item_No_new', F.expr("TRANSFORM(X_Item_No, value -> (value - X_col_min) / X_col_max)"))
df.show(truncate=False)
+--------------------------------------------------+---------+---------+----------------------------------------------------------------------------------------------------------------------+
|X_Item_No                                         |X_col_min|X_col_max|X_Item_No_new                                                                                                         |
+--------------------------------------------------+---------+---------+----------------------------------------------------------------------------------------------------------------------+
|[0.0, 0.0, 0.09, 0.01, 0.013, 0.016, 0.018, 0.021]|0.0      |0.09     |[0.0, 0.0, 1.0, 0.11111110421242565, 0.14444444168497025, 0.17777777915751486, 0.1999999834431549, 0.2333333209156995]|
|[0.0, 0.0, 0.04, 0.31, 0.313, 0.216, 0.618, 0.028]|0.0      |0.618    |[0.0, 0.0, 0.06472492069350876, 0.5016181504446381, 0.5064725053309027, 0.349514588623286, 1.0, 0.045307446896647376] |
+--------------------------------------------------+---------+---------+----------------------------------------------------------------------------------------------------------------------+
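As a cross-check, the expression inside TRANSFORM can be reproduced in plain Python. This is only a sketch: normalize is a hypothetical helper, not part of Spark, and returning None when the maximum is 0 is an assumption made here to avoid a ZeroDivisionError. The results agree with the Spark output above to about six decimal places; the small differences are most likely because the column is declared as array<float> (32-bit) while Python floats are 64-bit.

```python
def normalize(values, lo, hi):
    """Apply (value - lo) / hi to every element of values.

    Returns None if any argument is None. An element becomes None
    when it is itself None or when hi is 0 (assumed behaviour).
    """
    if values is None or lo is None or hi is None:
        return None
    return [None if v is None or hi == 0 else (v - lo) / hi for v in values]


row = [0.0, 0.0, 0.09, 0.01, 0.013, 0.016, 0.018, 0.021]
print(normalize(row, min(row), max(row)))
```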