I want to select a column where the value equals a given string. I'm doing this in Scala and having a bit of trouble.
My code:
df.select(df("state")==="TX").show()
This returns the state column as boolean values instead of just the rows with TX.
I've also tried
df.select(df("state")=="TX").show()
but I run into the same issue. The following syntax does work for me:
df.filter(df("state")==="TX").show()
I'm using Spark 1.6.
There is also a simple SQL-like option. It should work on Spark versions below 1.6 as well.
df.filter("state = 'TX'")
This is a newer way of specifying SQL-like filters. For a full list of supported operators, check out this class.
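For completeness, the same string condition can also be run through Spark SQL directly. A minimal sketch for Spark 1.6, assuming an existing SQLContext named sqlContext (on Spark 2.x you would use createOrReplaceTempView and spark.sql instead); the view name df_view is purely illustrative:
// register the DataFrame so it can be queried with SQL
df.registerTempTable("df_view")
sqlContext.sql("SELECT * FROM df_view WHERE state = 'TX'").show()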
You should use where; select is a projection that returns the output of the statement, which is why you get boolean values. where is a filter that keeps the structure of the DataFrame and keeps only the rows where the filter condition holds (a short sketch contrasting the two follows the snippet below).
In the same vein, per the documentation, you can write this in three different ways:
// The following are equivalent:
peopleDf.filter($"age" > 15)
peopleDf.where($"age" > 15)
peopleDf($"age" > 15)
To get the negation, do this…
df.filter(not( ..expression.. ))
e.g.
df.filter(not($"state" === "TX"))
df.filter($"state" like "T%%")
模式匹配
df.filter($"state" === "TX")
or df.filter("state = 'TX'")
for equality
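As a side note on the negation above: on Spark 2.0 and later the Column class also defines an inequality operator, so the not(...) wrapper can be avoided:
// Spark 2.0+: =!= is the negated counterpart of ===
df.filter($"state" =!= "TX")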
Works on Spark V2.*
import sqlContext.implicits._
df.filter($"state" === "TX")
If you need to compare against a variable (e.g. a val named myVar; note that var itself is a reserved word in Scala):
import sqlContext.implicits._
df.filter($"state" === var)
Note:
import sqlContext.implicits._
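A minimal runnable sketch of the variable comparison (the name myVar is purely illustrative):
import sqlContext.implicits._
val myVar = "TX" // hypothetical variable holding the value to match
df.filter($"state" === myVar).show()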
We can write multiple filter/where conditions in a DataFrame.
For example:
table1_df
  .filter($"Col_1_name" === "buddy") // check for equality with a string
  .filter($"Col_2_name" === "A")
  .filter(not($"Col_2_name".contains(" .sql"))) // filter out strings that are not relevant
  .filter("Col_2_name is not null") // not-null filter
  .take(5).foreach(println)
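The chained filter calls above are applied one after another. The same conditions can also be combined into a single filter using the Column operators && and ! (a sketch reusing the same hypothetical column names):
table1_df
  .filter(
    $"Col_1_name" === "buddy" &&
    $"Col_2_name" === "A" &&
    !$"Col_2_name".contains(" .sql") && // ! negates a Column expression
    $"Col_2_name".isNotNull            // Column equivalent of "is not null"
  )
  .take(5).foreach(println)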
Let's create a sample dataset and dig into exactly why the OP's code didn't work.
Here is our sample data:
val df = Seq(
  ("Rockets", 2, "TX"),
  ("Warriors", 6, "CA"),
  ("Spurs", 5, "TX"),
  ("Knicks", 2, "NY")
).toDF("team_name", "num_championships", "state")
We can print our dataset with the show() method:
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Warriors| 6| CA|
| Spurs| 5| TX|
| Knicks| 2| NY|
+---------+-----------------+-----+
Let's examine the result of df.select(df("state")==="TX").show():
+------------+
|(state = TX)|
+------------+
| true|
| false|
| true|
| false|
+------------+
This result is easier to understand if we simply append the comparison as a column with df.withColumn("is_state_tx", df("state")==="TX").show():
+---------+-----------------+-----+-----------+
|team_name|num_championships|state|is_state_tx|
+---------+-----------------+-----+-----------+
| Rockets| 2| TX| true|
| Warriors| 6| CA| false|
| Spurs| 5| TX| true|
| Knicks| 2| NY| false|
+---------+-----------------+-----+-----------+
The other code the OP tried (df.select(df("state")=="TX").show()) returns this error:
<console>:27: error: overloaded method value select with alternatives:
[U1](c1: org.apache.spark.sql.TypedColumn[org.apache.spark.sql.Row,U1])org.apache.spark.sql.Dataset[U1] <and>
(col: String,cols: String*)org.apache.spark.sql.DataFrame <and>
(cols: org.apache.spark.sql.Column*)org.apache.spark.sql.DataFrame
cannot be applied to (Boolean)
df.select(df("state")=="TX").show()
^
The === operator is defined in the Column class. The Column class doesn't define a == operator, which is why this code errors out.
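This is easy to verify in the REPL: == is Scala's universal equality inherited from Any, so it compares the Column object itself with the string and produces a plain Boolean, which select cannot accept:
// == compares a Column object with a String and yields a Boolean, not a Column:
val b = df("state") == "TX"  // b: Boolean = false
// === is a Column method and builds a column expression instead:
val c = df("state") === "TX" // c: org.apache.spark.sql.Column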
Here is the accepted answer:
df.filter(df("state")==="TX").show()
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Spurs| 5| TX|
+---------+-----------------+-----+
As other posters have mentioned, the === method takes an argument of type Any, so this isn't the only solution that works. This also works, for example:
df.filter(df("state") === lit("TX")).show
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Spurs| 5| TX|
+---------+-----------------+-----+
The Column equalTo method can also be used:
df.filter(df("state").equalTo("TX")).show()
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Spurs| 5| TX|
+---------+-----------------+-----+
This example is worth studying in detail. Scala syntax can sometimes look magical, especially when methods are invoked without dot notation. To the untrained eye, it can be hard to tell that === is a method defined in the Column class!
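Written with explicit dot notation, the magic disappears; the infix form is just sugar for an ordinary method call on Column:
// these two lines are equivalent:
df.filter(df("state") === "TX")
df.filter(df("state").===("TX"))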
Here is a complete example using Spark 2.2+ that reads the data in as JSON…
val myjson = """[{"name":"Alabama","abbreviation":"AL"},{"name":"Alaska","abbreviation":"AK"},{"name":"American Samoa","abbreviation":"AS"},{"name":"Arizona","abbreviation":"AZ"},{"name":"Arkansas","abbreviation":"AR"},{"name":"California","abbreviation":"CA"},{"name":"Colorado","abbreviation":"CO"},{"name":"Connecticut","abbreviation":"CT"},{"name":"Delaware","abbreviation":"DE"},{"name":"District Of Columbia","abbreviation":"DC"},{"name":"Federated States Of Micronesia","abbreviation":"FM"},{"name":"Florida","abbreviation":"FL"},{"name":"Georgia","abbreviation":"GA"},{"name":"Guam","abbreviation":"GU"},{"name":"Hawaii","abbreviation":"HI"},{"name":"Idaho","abbreviation":"ID"},{"name":"Illinois","abbreviation":"IL"},{"name":"Indiana","abbreviation":"IN"},{"name":"Iowa","abbreviation":"IA"},{"name":"Kansas","abbreviation":"KS"},{"name":"Kentucky","abbreviation":"KY"},{"name":"Louisiana","abbreviation":"LA"},{"name":"Maine","abbreviation":"ME"},{"name":"Marshall Islands","abbreviation":"MH"},{"name":"Maryland","abbreviation":"MD"},{"name":"Massachusetts","abbreviation":"MA"},{"name":"Michigan","abbreviation":"MI"},{"name":"Minnesota","abbreviation":"MN"},{"name":"Mississippi","abbreviation":"MS"},{"name":"Missouri","abbreviation":"MO"},{"name":"Montana","abbreviation":"MT"},{"name":"Nebraska","abbreviation":"NE"},{"name":"Nevada","abbreviation":"NV"},{"name":"New Hampshire","abbreviation":"NH"},{"name":"New Jersey","abbreviation":"NJ"},{"name":"New Mexico","abbreviation":"NM"},{"name":"New York","abbreviation":"NY"},{"name":"North Carolina","abbreviation":"NC"},{"name":"North Dakota","abbreviation":"ND"},{"name":"Northern Mariana Islands","abbreviation":"MP"},{"name":"Ohio","abbreviation":"OH"},{"name":"Oklahoma","abbreviation":"OK"},{"name":"Oregon","abbreviation":"OR"},{"name":"Palau","abbreviation":"PW"},{"name":"Pennsylvania","abbreviation":"PA"},{"name":"Puerto Rico","abbreviation":"PR"},{"name":"Rhode Island","abbreviation":"RI"},{"name":"South Carolina","abbreviation":"SC"},{"name":"South Dakota","abbreviation":"SD"},{"name":"Tennessee","abbreviation":"TN"},{"name":"Texas","abbreviation":"TX"},{"name":"Utah","abbreviation":"UT"},{"name":"Vermont","abbreviation":"VT"},{"name":"Virgin Islands","abbreviation":"VI"},{"name":"Virginia","abbreviation":"VA"},{"name":"Washington","abbreviation":"WA"},{"name":"West Virginia","abbreviation":"WV"},{"name":"Wisconsin","abbreviation":"WI"},{"name":"Wyoming","abbreviation":"WY"}]"""
import spark.implicits._
val df = spark.read.json(Seq(myjson).toDS)
df.show
scala> df.show
+------------+--------------------+
|abbreviation| name|
+------------+--------------------+
| AL| Alabama|
| AK| Alaska|
| AS| American Samoa|
| AZ| Arizona|
| AR| Arkansas|
| CA| California|
| CO| Colorado|
| CT| Connecticut|
| DE| Delaware|
| DC|District Of Columbia|
| FM|Federated States ...|
| FL| Florida|
| GA| Georgia|
| GU| Guam|
| HI| Hawaii|
| ID| Idaho|
| IL| Illinois|
| IN| Indiana|
| IA| Iowa|
| KS| Kansas|
+------------+--------------------+
// equals matching
scala> df.filter(df("abbreviation") === "TX").show
+------------+-----+
|abbreviation| name|
+------------+-----+
| TX|Texas|
+------------+-----+
// or using lit
scala> df.filter(df("abbreviation") === lit("TX")).show
+------------+-----+
|abbreviation| name|
+------------+-----+
| TX|Texas|
+------------+-----+
//not expression
scala> df.filter(not(df("abbreviation") === "TX")).show
+------------+--------------------+
|abbreviation| name|
+------------+--------------------+
| AL| Alabama|
| AK| Alaska|
| AS| American Samoa|
| AZ| Arizona|
| AR| Arkansas|
| CA| California|
| CO| Colorado|
| CT| Connecticut|
| DE| Delaware|
| DC|District Of Columbia|
| FM|Federated States ...|
| FL| Florida|
| GA| Georgia|
| GU| Guam|
| HI| Hawaii|
| ID| Idaho|
| IL| Illinois|
| IN| Indiana|
| IA| Iowa|
| KS| Kansas|
+------------+--------------------+
only showing top 20 rows
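The like pattern matching shown earlier works on this dataset too. For instance, matching every state whose name starts with "New" should give:
// like pattern matching: % matches any sequence of characters
scala> df.filter($"name" like "New%").show
+------------+-------------+
|abbreviation|         name|
+------------+-------------+
|          NH|New Hampshire|
|          NJ|   New Jersey|
|          NM|   New Mexico|
|          NY|     New York|
+------------+-------------+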
Spark 2.4
Comparing with one value:
df.filter(lower(trim($"col_name")) === "<value>").show()
Comparing with a collection of values:
df.filter($"col_name".isInCollection(new HashSet<>(Arrays.asList("value1", "value2")))).show()