1 列内行(双精度)的参数最大值<array>



我有以下情况。

+--------------------+
|                   p|
+--------------------+
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
+--------------------+

这是 Row(( 对象的列表。

[Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06])]

我正在尝试将此列过滤为一个名为"maxClass"的新列,该列为所有行返回了 np.argmax(row([0]。以下是我最好的尝试,但我根本无法获得使用此软件包的语言学。

def f(row):
return np.argmax(np.array(row.p))[0]
results=probs.rdd.map(lambda x:f(x))
results

为了完整起见,正如 pault 在这里建议的那样,这是一个不使用 UDF 和 numpy 的解决方案。而是使用array_positionarray_max

import pyspark.sql.functions as f
df = spark.createDataFrame([
([0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],),
([0.9999841641213134, 0.99999, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],),
([0.9999841641213135, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],)]) 
.toDF("p")
df.select(
f.expr('array_position(cast(p as array<decimal(16, 16)>), cast(array_max(p) as decimal(16, 16))) - 1').alias("max_indx")
).show()
# +--------+
# |max_indx|
# +--------+
# |       0|
# |       1|
# |       0|
# +--------+

最新更新