关于pyspark:将Spark Dataframe字符串列拆分为多个列

Split Spark Dataframe string column into multiple columns

我见过很多人建议Dataframe.explode是执行此操作的一种有用方法,但是它导致的行数比原始数据帧多,这根本不是我想要的。 我只想做非常简单的Dataframe等效项:

1
rdd.map(lambda row: row + [row.my_str_col.split('-')])

它看起来像:

1
2
3
4
col1 | my_str_col
-----+-----------
  18 |  856-yygrm
 201 |  777-psgdg

并将其转换为:

1
2
3
4
col1 | my_str_col | _col3 | _col4
-----+------------+-------+------
  18 |  856-yygrm |   856 | yygrm
 201 |  777-psgdg |   777 | psgdg

我知道pyspark.sql.functions.split(),但是它导致嵌套数组列而不是像我想要的两个顶级列。

理想情况下,我也希望这些新列也被命名。


pyspark.sql.functions.split()是正确的方法-您只需要将嵌套的ArrayType列展平为多个顶级列。在这种情况下,每个数组仅包含2个项目,这非常简单。您只需使用Column.getItem()即可将数组的每个部分作为列本身进行检索:

1
2
3
split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.withColumn('NAME1', split_col.getItem(0))
df = df.withColumn('NAME2', split_col.getItem(1))

结果将是:

1
2
3
4
col1 | my_str_col | NAME1 | NAME2
-----+------------+-------+------
  18 |  856-yygrm |   856 | yygrm
 201 |  777-psgdg |   777 | psgdg

我不确定在嵌套数组的行与行之间大小不相同的一般情况下如何解决此问题。


这是针对一般情况的解决方案,该解决方案无需使用collectudf s提前知道数组的长度。不幸的是,这仅适用于spark 2.1及更高版本,因为它需要posexplode功能。

假设您具有以下DataFrame:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
df = spark.createDataFrame(
    [
        [1, 'A, B, C, D'],
        [2, 'E, F, G'],
        [3, 'H, I'],
        [4, 'J']
    ]
    , ["num","letters"]
)
df.show()
#+---+----------+
#|num|   letters|
#+---+----------+
#|  1|A, B, C, D|
#|  2|   E, F, G|
#|  3|      H, I|
#|  4|         J|
#+---+----------+

拆分letters列,然后使用posexplode分解结果数组以及数组中的位置。接下来,使用pyspark.sql.functions.expr捕获此数组中索引为pos的元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import pyspark.sql.functions as f

df.select(
       "num",
        f.split("letters",",").alias("letters"),
        f.posexplode(f.split("letters",",")).alias("pos","val")
    )\
    .show()
#+---+------------+---+---+
#|num|     letters|pos|val|
#+---+------------+---+---+
#|  1|[A, B, C, D]|  0|  A|
#|  1|[A, B, C, D]|  1|  B|
#|  1|[A, B, C, D]|  2|  C|
#|  1|[A, B, C, D]|  3|  D|
#|  2|   [E, F, G]|  0|  E|
#|  2|   [E, F, G]|  1|  F|
#|  2|   [E, F, G]|  2|  G|
#|  3|      [H, I]|  0|  H|
#|  3|      [H, I]|  1|  I|
#|  4|         [J]|  0|  J|
#+---+------------+---+---+

现在,我们根据此结果创建两个新列。第一个是我们新列的名称,它将是letter与数组中的索引的串联。第二列将是数组中相应索引处的值。我们通过利用pyspark.sql.functions.expr的功能获得后者,该功能允许我们将列值用作参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
df.select(
       "num",
        f.split("letters",",").alias("letters"),
        f.posexplode(f.split("letters",",")).alias("pos","val")
    )\
    .drop("val")\
    .select(
       "num",
        f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
        f.expr("letters[pos]").alias("val")
    )\
    .show()
#+---+-------+---+
#|num|   name|val|
#+---+-------+---+
#|  1|letter0|  A|
#|  1|letter1|  B|
#|  1|letter2|  C|
#|  1|letter3|  D|
#|  2|letter0|  E|
#|  2|letter1|  F|
#|  2|letter2|  G|
#|  3|letter0|  H|
#|  3|letter1|  I|
#|  4|letter0|  J|
#+---+-------+---+

现在我们可以groupBynumpivot DataFrame。综上所述,我们得到:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
df.select(
       "num",
        f.split("letters",",").alias("letters"),
        f.posexplode(f.split("letters",",")).alias("pos","val")
    )\
    .drop("val")\
    .select(
       "num",
        f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
        f.expr("letters[pos]").alias("val")
    )\
    .groupBy("num").pivot("name").agg(f.first("val"))\
    .show()
#+---+-------+-------+-------+-------+
#|num|letter0|letter1|letter2|letter3|
#+---+-------+-------+-------+-------+
#|  1|      A|      B|      C|      D|
#|  3|      H|      I|   null|   null|
#|  2|      E|      F|      G|   null|
#|  4|      J|   null|   null|   null|
#+---+-------+-------+-------+-------+


这是另一种方法,以防您想用定界符分割字符串。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import pyspark.sql.functions as f

df = spark.createDataFrame([("1:a:2001",),("2:b:2002",),("3:c:2003",)],["value"])
df.show()
+--------+
|   value|
+--------+
|1:a:2001|
|2:b:2002|
|3:c:2003|
+--------+

df_split = df.select(f.split(df.value,":")).rdd.flatMap(
              lambda x: x).toDF(schema=["col1","col2","col3"])

df_split.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|2001|
|   2|   b|2002|
|   3|   c|2003|
+----+----+----+

我不认为这种向RDD的来回转换会减慢您的速度...
也不必担心最后一个架构规范:它是可选的,您可以避免将其推广到具有未知列大小的数据的解决方案。


我找到了针对一般不均匀情况的解决方案(或者当您获得通过.split()函数获得的嵌套列时):

1
2
3
4
5
6
7
8
9
10
import pyspark.sql.functions as f

@f.udf(StructType([StructField(col_3, StringType(), True),
                   StructField(col_4, StringType(), True)]))

 def splitCols(array):
    return array[0],  ''.join(array[1:len(array)])

 df = df.withColumn("name", splitCols(f.split(f.col("my_str_col"), '-')))\
        .select(df.columns+['name.*'])

基本上,您只需要选择所有前面的列+嵌套的'column_name。*',在这种情况下,您将它们作为两个顶级列。