Spark Java UDAF with Nested Struct Input

Date: 2022-07-23

Spark Java UDAF

Preface

First, to be clear: UDAFs are not limited to the agg() operator.

The official Spark 3.0.0 documentation [1] already describes Spark Java UDAFs and provides example code, so this article focuses on two problems encountered in actual development:

  1. Mixing the type-safe and untyped styles, which leads to errors
  2. Fields failing to match up when deserializing an entity

Each is described in turn below.

Implementing the UDAF

First, a note on the two ways to implement a Spark Java UDAF [2]. The first is to extend UserDefinedAggregateFunction and implement its 8 methods; this approach is marked as deprecated in Spark 3.0.0. The second is to extend Aggregator and implement its 6 methods.

We will implement a UDAF that counts how many times each street appears across AddressEntity records and sums the city values per street.

AddressEntity.java

// Lombok getters/setters (as in PersonAnalizeEntity below) are needed so that
// Encoders.bean() and the UDAF can access the fields
@Setter
@Getter
public class AddressEntity implements Serializable {
    private String city;
    private String street;
}

PersonAnalizeEntity.java

(Since the data volume is small, we use two maps to record the per-street counts and the accumulated city sums.)

package com.sogo.getimei.entity;

import lombok.Getter;
import lombok.Setter;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/8
 * @Time: 21:55
 * @des:
 */
@Setter
@Getter
public class PersonAnalizeEntity implements Serializable {
    // records how many times each street appears
    private Map<String, Integer> streetCountMap;
    // records the accumulated city sum per street
    private Map<String, Integer> streetSumMap;

    public PersonAnalizeEntity() {
        this.streetCountMap = new HashMap<>();
        this.streetSumMap = new HashMap<>();
    }
}

AddressAnaliseUdaf.java

The UDAF implementation:

package com.sogo.getimei.udf;

import com.sogo.getimei.entity.AddressEntity;
import com.sogo.getimei.entity.PersonAnalizeEntity;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

import java.util.Map;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/9
 * @Time: 14:45
 * @des:
 */
// Extend the Aggregator class
public class AddressAnaliseUdaf extends Aggregator<AddressEntity, PersonAnalizeEntity, PersonAnalizeEntity> {
    // Initialize the aggregation buffer
    @Override
    public PersonAnalizeEntity zero() {
        return new PersonAnalizeEntity();
    }
    // Per-partition aggregation
    @Override
    public PersonAnalizeEntity reduce(PersonAnalizeEntity b, AddressEntity addressEntity) {
        // If the street already exists, increment its count in streetCountMap and add the city value to streetSumMap
        if (b.getStreetCountMap().containsKey(addressEntity.getStreet())) {
            b.getStreetCountMap().put(addressEntity.getStreet(),
                    b.getStreetCountMap().get(addressEntity.getStreet()) + 1);
            b.getStreetSumMap().put(addressEntity.getStreet(),
                    b.getStreetSumMap().get(addressEntity.getStreet()) + Integer.valueOf(addressEntity.getCity()));
        } else {
            b.getStreetCountMap().put(addressEntity.getStreet(), 1);
            b.getStreetSumMap().put(addressEntity.getStreet(), Integer.valueOf(addressEntity.getCity()));
        }
        return b;
    }

    // Merge buffers from different partitions
    @Override
    public PersonAnalizeEntity merge(PersonAnalizeEntity b1, PersonAnalizeEntity b2) {
        for (Map.Entry<String, Integer> entry : b2.getStreetCountMap().entrySet()) {
            if (b1.getStreetCountMap().containsKey(entry.getKey())) {
                b1.getStreetCountMap().put(entry.getKey(),
                        entry.getValue() + b1.getStreetCountMap().get(entry.getKey()));
                b1.getStreetSumMap().put(entry.getKey(),
                        b2.getStreetSumMap().get(entry.getKey()) + b1.getStreetSumMap().get(entry.getKey()));
            } else {
                b1.getStreetCountMap().put(entry.getKey(), entry.getValue());
                b1.getStreetSumMap().put(entry.getKey(), b2.getStreetSumMap().get(entry.getKey()));
            }
        }
        return b1;
    }

    // Final output
    @Override
    public PersonAnalizeEntity finish(PersonAnalizeEntity reduction) {
        return reduction;
    }

    // Encoder for the intermediate buffer
    @Override
    public Encoder<PersonAnalizeEntity> bufferEncoder() {
        return Encoders.bean(PersonAnalizeEntity.class);
    }
   
    // Encoder for the final output
    @Override
    public Encoder<PersonAnalizeEntity> outputEncoder() {
        return Encoders.bean(PersonAnalizeEntity.class);
    }
}
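
Since zero(), reduce(), and merge() are ordinary Java methods, the buffer logic can be sanity-checked without a Spark session. A minimal sketch (not from the original article; it assumes the Lombok setters added to AddressEntity above):

AddressAnaliseUdaf udaf = new AddressAnaliseUdaf();
AddressEntity a = new AddressEntity();
a.setStreet("Chn");
a.setCity("99");
// One record reduced into an empty buffer, then merged with another single-record buffer
PersonAnalizeEntity buf = udaf.reduce(udaf.zero(), a);
buf = udaf.merge(buf, udaf.reduce(udaf.zero(), a));
System.out.println(buf.getStreetCountMap()); // expected {Chn=2}
System.out.println(buf.getStreetSumMap());   // expected {Chn=198}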

Invoking the UDAF

The type-safe and untyped styles both refer to how an Aggregator is invoked.

In short, the type-safe style works on a Dataset<Entity> and has compile-time type checking, while the untyped style works on a Dataset<Row> or is called from Spark SQL. Once their usage scenarios are clear, errors caused by mixing them can be avoided.

Test data

Sample data and schema of Dataset<Row> studyDs:

+----+---+----------------------------------+
|name|age|address                           |
+----+---+----------------------------------+
|liu1|90 |[[Chn, 99], [Math, 98], [Eng, 97]]|
|liu2|80 |[[Chn, 89], [Math, 88], [Eng, 87]]|
|liu3|70 |[[Chn, 79], [Math, 78], [Eng, 77]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
+----+---+----------------------------------+

root
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- address: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- street: string (nullable = true)
 |    |    |-- city: string (nullable = true)
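
The article does not show how studyDs is built; the following is a minimal sketch of one way to construct equivalent test data (an assumption, not the original code). The explicit schema mirrors the one printed above, and only the first two rows are written out; it uses org.apache.spark.sql.RowFactory and org.apache.spark.sql.types.*.

StructType addressType = new StructType()
        .add("street", DataTypes.StringType)
        .add("city", DataTypes.StringType);
StructType schema = new StructType()
        .add("name", DataTypes.StringType)
        .add("age", DataTypes.IntegerType)
        .add("address", DataTypes.createArrayType(addressType));
// Only the first two rows are shown; the remaining rows follow the same pattern
List<Row> rows = Arrays.asList(
        RowFactory.create("liu1", 90, new Row[]{RowFactory.create("Chn", "99"),
                RowFactory.create("Math", "98"), RowFactory.create("Eng", "97")}),
        RowFactory.create("liu2", 80, new Row[]{RowFactory.create("Chn", "89"),
                RowFactory.create("Math", "88"), RowFactory.create("Eng", "87")}));
Dataset<Row> studyDs = spark.createDataFrame(rows, schema);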

Type-Safe UDAFs

Code

See the inline comments for explanation.

Dataset<PersonAnalizeEntity> aFinal = studyDs
        // The UDAF input type is AddressEntity, so the List<AddressEntity> must be flattened first
        .selectExpr("explode(address) as address")
        // This step is crucial: the individual fields of AddressEntity must be projected out so they can be deserialized
        .selectExpr("address.city as city", "address.street as street")
        // Type-safe: convert to Dataset<AddressEntity>
        .as(Encoders.bean(AddressEntity.class))
        // Call toColumn() on the UDAF instance to perform the aggregation
        .select(new AddressAnaliseUdaf().toColumn());
aFinal.show(10,0);
aFinal.printSchema();

Result

The result matches expectations (for example, the Chn sum is 99 + 89 + 79 + 69 + 69 = 405, and each street appears in all five rows, hence a count of 5):

+-------------------------------+-------------------------------------+
|streetCountMap                 |streetSumMap                         |
+-------------------------------+-------------------------------------+
|[Chn -> 5, Math -> 5, Eng -> 5]|[Chn -> 405, Math -> 400, Eng -> 395]|
+-------------------------------+-------------------------------------+

root
 |-- streetCountMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = true)
 |-- streetSumMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = true)

Common pitfall

When deserializing to the bean, if the city and street subfields of the address struct are not projected out, the following error occurs:

org.apache.spark.sql.AnalysisException: cannot resolve '`city`' given input columns: [address]
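
For reference, a hypothetical reconstruction (not shown in the original article) of the variant that triggers this error, where the struct column is kept instead of being projected into its fields:

// Fails: the bean encoder looks for top-level city/street columns, but only `address` exists
Dataset<PersonAnalizeEntity> bad = studyDs
        .selectExpr("explode(address) as address")
        .as(Encoders.bean(AddressEntity.class))
        .select(new AddressAnaliseUdaf().toColumn());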

Untyped UDAFs

Code

Register the UDAF:

// udaf() arguments: 1. the UDAF instance, 2. the Encoder for the input type
spark.udf().register("AddressAnaliseUdaf", udaf(new AddressAnaliseUdaf(), Encoders.bean(AddressEntity.class)));
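
The udaf(), callUDF(), and col() helpers used in this section come from org.apache.spark.sql.functions; the snippets assume static imports along these lines:

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udaf;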

First invocation style: via callUDF (works)

Dataset<Row> agg = studyDs
        .selectExpr("name", "explode(address) as address")
        .selectExpr("name", "address.city as city", "address.street as street")
        // Also crucial: the input fields must be passed in the alphabetical order of their field names, which is the order the bean encoder expects
        .agg(callUDF("AddressAnaliseUdaf", col("city"), col("street")).alias("cal_result"))
        .selectExpr("cal_result.streetCountMap as streetCountMap", "cal_result.streetSumMap as streetSumMap");
agg.show(10, 0);
agg.printSchema();

The output matches expectations:

+-------------------------------+-------------------------------------+
|streetCountMap                 |streetSumMap                         |
+-------------------------------+-------------------------------------+
|[Chn -> 5, Math -> 5, Eng -> 5]|[Chn -> 405, Math -> 400, Eng -> 395]|
+-------------------------------+-------------------------------------+

root
 |-- streetCountMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = true)
 |-- streetSumMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = true)
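
The registered UDAF is not restricted to a global aggregation; as a side note, a hypothetical sketch (not from the original article) of applying it per name via groupBy:

// Per-name aggregation with the same registered UDAF
Dataset<Row> perName = studyDs
        .selectExpr("name", "explode(address) as address")
        .selectExpr("name", "address.city as city", "address.street as street")
        .groupBy("name")
        .agg(callUDF("AddressAnaliseUdaf", col("city"), col("street")).alias("cal_result"));
perName.show(10, 0);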

Second invocation style: from SQL

The demo in reference [1] covers only a simple flat structure; here we want the UDAF to handle a complex nested structure, which took some trial and error.

Attempt 1 (fails)

studyDs.selectExpr("explode(address) as address")
        .registerTempTable("study");
Dataset<Row> sqlRow = spark.sql("select AddressAnaliseUdaf(address) from study");

The error message:

Caused by: org.apache.spark.sql.AnalysisException: cannot resolve 'AddressAnaliseUdaf(address)' due to data type mismatch: argument 1 requires string type, however, 'study.`address`' is of struct<city:string,street:string> type.; line 1 pos 7

A rather confusing error: the UDAF was registered with Encoders.bean(AddressEntity.class), so it expects the bean's individual fields (two string columns) rather than a single struct column.

Attempt 2 (works)

studyDs.createOrReplaceTempView("study");
// As before, the individual fields of AddressEntity must be passed to the UDAF,
// ordered as declared in AddressEntity (which here matches the alphabetical order required by the bean encoder); the column names themselves can be changed freely
Dataset<Row> sqlRow = spark.sql("SELECT AddressAnaliseUdaf(address.city,address.street) FROM (SELECT explode(address) AS address FROM study)");
sqlRow.show(10, 0);
sqlRow.printSchema();

The output matches expectations:

+------------------------------------------------------------------------+
|addressanaliseudaf(address.city AS `city`, address.street AS `street`)  |
+------------------------------------------------------------------------+
|[[Chn -> 5, Math -> 5, Eng -> 5], [Chn -> 405, Math -> 400, Eng -> 395]]|
+------------------------------------------------------------------------+

root
 |-- addressanaliseudaf(address.city AS `city`, address.street AS `street`): struct (nullable = true)
 |    |-- streetCountMap: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: integer (valueContainsNull = true)
 |    |-- streetSumMap: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: integer (valueContainsNull = true)

Testing with renamed columns: the field names can indeed be changed:

Dataset<Row> sqlRow = spark.sql("SELECT AddressAnaliseUdaf(city1,street1) FROM (SELECT address.city as city1, address.street as street1 FROM (SELECT explode(address) AS address FROM study))");

The output matches expectations:

+------------------------------------------------------------------------+
|addressanaliseudaf(city1, street1)                                      |
+------------------------------------------------------------------------+
|[[Chn -> 5, Math -> 5, Eng -> 5], [Chn -> 405, Math -> 400, Eng -> 395]]|
+------------------------------------------------------------------------+

root
 |-- addressanaliseudaf(city1, street1): struct (nullable = true)
 |    |-- streetCountMap: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: integer (valueContainsNull = true)
 |    |-- streetSumMap: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: integer (valueContainsNull = true)

Summary

To implement a Spark Java UDAF, extend the Aggregator class and implement its methods. In the type-safe style, once the data has been deserialized into a Dataset of the entity type, aggregation is done by calling toColumn() on the UDAF instance. In the untyped style, whether invoking via callUDF or from SQL, pay attention to the order of the input fields. In both styles, the entity must be expanded into its individual fields before being passed to the UDAF.

References

[1] User Defined Aggregate Functions (UDAFs), Spark 3.0.0 documentation: https://spark.apache.org/docs/3.0.0/sql-ref-functions-udf-aggregate.html

[2] Two ways to implement a custom UDAF in Spark (in Chinese): https://blog.csdn.net/weixin_43861104/article/details/107358874