Skip to content

Latest commit

 

History

History
74 lines (59 loc) · 1.9 KB

File metadata and controls

74 lines (59 loc) · 1.9 KB

StringIndexer

  StringIndexer将标签列的字符串编码为标签索引。这些索引是[0,numLabels),通过标签频率排序,所以频率最高的标签的索引为0。 如果输入列是数字,我们把它强转为字符串然后在编码。

例子

  假设我们有下面的DataFrame,它的列名是idcategory

 id | category
----|----------
 0  | a
 1  | b
 2  | c
 3  | a
 4  | a
 5  | c

  category是字符串列,拥有三个标签a,b,c。把category作为输入列,categoryIndex作为输出列,使用StringIndexer我们可以得到下面的结果。

 id | category | categoryIndex
----|----------|---------------
 0  | a        | 0.0
 1  | b        | 2.0
 2  | c        | 1.0
 3  | a        | 0.0
 4  | a        | 0.0
 5  | c        | 1.0

  a的索引号为0是因为它的频率最高,c次之,b最后。

  另外,StringIndexer处理未出现的标签的策略有两个:

  • 抛出一个异常(默认情况)
  • 跳过出现该标签的行

  让我们回到上面的例子,但是这次我们重用上面的StringIndexer到下面的数据集。

 id | category
----|----------
 0  | a
 1  | b
 2  | c
 3  | d

  如果我们没有为StringIndexer设置怎么处理未见过的标签或者设置为error,它将抛出异常,否则若设置为skip,它将得到下面的结果。

id  | category | categoryIndex
----|----------|---------------
 0  | a        | 0.0
 1  | b        | 2.0
 2  | c        | 1.0

  下面是程序调用的例子。

import org.apache.spark.ml.feature.StringIndexer

val df = spark.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val indexed = indexer.fit(df).transform(df)
indexed.show()