Flink Core API Advanced (Part 2)

These notes cover the advanced part of Flink's core APIs: custom Sources, custom partitioners, and custom Sinks, all of which come up frequently in practice, plus some of the more advanced Transformation operators. Starting from the overall hierarchy of Flink Functions, the idea is to hand custom logic over to Flink by implementing its top-level interfaces; more complex custom scenarios will be added to these notes as they come up.

Every Flink function class has a Rich variant. Unlike the regular function, a Rich function can access the runtime context and has lifecycle methods, so it can be used to implement more complex, richer functionality.

A Rich Function has a lifecycle; its typical lifecycle methods are:

  • open() is the initialization method of a rich function; it is called before the operator (e.g. map or filter) starts processing records
  • close() is the last method called in the lifecycle and is used for cleanup work
  • getRuntimeContext() exposes information from the function's RuntimeContext, such as the function's parallelism, the task name, and its state

For a RichMapFunction, map is called once for every record, while the number of open calls depends on the parallelism: with parallelism set to 2, open is called twice; if no parallelism is set, the default parallelism (typically the number of CPU cores) determines how many times it is called. See the sketch below.
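As a minimal sketch (not from the original notes; the class name and printed messages are made up for illustration), a RichMapFunction that touches all three lifecycle hooks might look like this:

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.Configuration;

// Hypothetical example: upper-cases each record and tags it with the subtask index.
public class UpperCaseRichMap extends RichMapFunction<String, String> {

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        // Called once per parallel instance, before the first map() call.
        System.out.println("open, subtask = " + getRuntimeContext().getIndexOfThisSubtask());
    }

    @Override
    public String map(String value) throws Exception {
        // Called once for every record.
        return value.toUpperCase() + "-" + getRuntimeContext().getIndexOfThisSubtask();
    }

    @Override
    public void close() throws Exception {
        super.close();
        // Called once per parallel instance, after the last record has been processed.
        System.out.println("close");
    }
}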

Custom Source

A custom Source makes it easy to generate mock data, which is very useful for testing. A source implemented with the plain SourceFunction interface has a parallelism of 1, and it can only ever be 1; DataStreamSource checks the parallelism, as the following source code shows:

@Override
public DataStreamSource<T> setParallelism(int parallelism) {
    OperatorValidationUtils.validateParallelism(parallelism, isParallel);
    super.setParallelism(parallelism);
    return this;
}

// OperatorValidationUtils.validateParallelism
public static void validateParallelism(int parallelism, boolean canBeParallel) {
    Preconditions.checkArgument(canBeParallel || parallelism == 1,
        "The parallelism of non parallel operator must be 1.");
    Preconditions.checkArgument(parallelism > 0 || parallelism == -1,
        "The parallelism of an operator must be at least 1, or ExecutionConfig.PARALLELISM_DEFAULT (use system default).");
}

Let's implement our own mock-data source, AccessLogSource:

public class AccessLogSource implements SourceFunction<AccessLog> {
    boolean running = true;

    @Override
    public void run(SourceContext<AccessLog> ctx) throws Exception {
        String[] domains = {"cn.tim", "tim.com", "pk.com"};
        Random random = new Random();
        while (running) {
            for (int i = 0; i < 10; i++) {
                AccessLog accessLog = new AccessLog();
                accessLog.setTime(1234567L);
                accessLog.setDomain(domains[random.nextInt(domains.length)]);
                accessLog.setTraffic(random.nextDouble() + 1_000);
                ctx.collect(accessLog);
            }
            Thread.sleep(5_000);
        }
    }

    @Override
    public void cancel() {
        running = false;
    }
}


public class AccessLogSourceV2 implements ParallelSourceFunction<AccessLog> {
    ......
}

private static void test_source_01(StreamExecutionEnvironment env) {
    // DataStreamSource<AccessLog> source = env.addSource(new AccessLogSource());
    // Error: a non-parallel SourceFunction can only run with parallelism 1
    // DataStreamSource<AccessLog> source = env.addSource(new AccessLogSource()).setParallelism(2);
    DataStreamSource<AccessLog> source = env.addSource(new AccessLogSourceV2()).setParallelism(2);
    System.out.println(source.getParallelism()); // 2
    source.print();
}

When we set any parallelism other than 1 on AccessLogSource, the parallelism check in DataStreamSource raises an error. When a DataStreamSource with parallelism greater than 1 is needed, implement ParallelSourceFunction instead (the full AccessLogSourceV2 appears in the Custom Partitioner section below).
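For reference, the AccessLog bean is not shown in the original notes; judging from the setters above and the constructor used later, it is roughly a POJO along these lines (the exact class may differ):

// Sketch of the AccessLog POJO inferred from its usage; the real class may differ.
public class AccessLog {
    private Long time;
    private String domain;
    private Double traffic;

    public AccessLog() { }

    public AccessLog(Long time, String domain, Double traffic) {
        this.time = time;
        this.domain = domain;
        this.traffic = traffic;
    }

    public Long getTime() { return time; }
    public void setTime(Long time) { this.time = time; }
    public String getDomain() { return domain; }
    public void setDomain(String domain) { this.domain = domain; }
    public Double getTraffic() { return traffic; }
    public void setTraffic(Double traffic) { this.traffic = traffic; }

    @Override
    public String toString() {
        return "AccessLog{time=" + time + ", domain='" + domain + "', traffic=" + traffic + "}";
    }
}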

Now let's use MySQL as a DataSource. Here are the table structure and the data:

use pk_flink_01;

SET FOREIGN_KEY_CHECKS=0;

-- ----------------------------
-- Table structure for student
-- ----------------------------
DROP TABLE IF EXISTS `student`;
CREATE TABLE `student` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `name` varchar(255) NOT NULL,
  `age` int(11) NOT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;

-- ----------------------------
-- Records of student
-- ----------------------------
INSERT INTO `student` VALUES ('1', 'flink', '22');
INSERT INTO `student` VALUES ('2', 'hello', '23');
INSERT INTO `student` VALUES ('3', 'world', '38');

Define the data bean:

public class Student {
    private Integer id;
    private String name;
    private Integer age;
    ...
}

Define MySQLUtils for obtaining and closing MySQL connections:

public class MySQLUtils {
    public static Connection getConnection() {
        try {
            Class.forName("com.mysql.cj.jdbc.Driver");
            return DriverManager.getConnection(
                    "jdbc:mysql://192.168.31.86:3307/pk_flink_01",
                    "root", "123456");
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    public static void closeConnection(Connection connection,
                                       PreparedStatement preparedStatement) {
        if(preparedStatement != null) {
            try {
                preparedStatement.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if(connection != null){
            try {
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

Define StudentSource as the data source:

public class StudentSource extends RichSourceFunction<Student> {
    Connection connection;
    PreparedStatement preparedStatement;

    @Override
    public void run(SourceContext<Student> ctx) throws Exception {
        ResultSet resultSet = preparedStatement.executeQuery();
        while (resultSet.next()){
            int id = resultSet.getInt("id");
            String name = resultSet.getString("name");
            int age = resultSet.getInt("age");
            ctx.collect(new Student(id, name, age));
        }
    }

    @Override
    public void cancel() {

    }

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        connection = MySQLUtils.getConnection();
        if(connection != null) {
            preparedStatement = connection.prepareStatement("select * from student");
        }
    }

    @Override
    public void close() throws Exception {
        super.close();
        MySQLUtils.closeConnection(connection, preparedStatement);
    }
}

Now test the custom MySQL source with StudentSource:

private static void test_source_mysql(StreamExecutionEnvironment env) {
    DataStreamSource<Student> source = env.addSource(new StudentSource());
    System.out.println(source.getParallelism());
    source.print().setParallelism(1);
}

Custom Sources for other databases, files, message queues, or sockets follow exactly the same pattern.

Advanced Transformation Operators

Union

This operator is covered in the official documentation: union merges two or more data streams into a new stream that contains all elements of all input streams. If you union a data stream with itself, every element appears twice in the resulting stream:

Using the MySQL DataSource above as an example:

DataStreamSource<Student> source = env.addSource(new StudentSource());
source.union(source).print();

// ------- Output -----------
11> Student{name='world', id=3, age=38}
7> Student{name='flink', id=1, age=22}
8> Student{name='hello', id=2, age=23}
9> Student{name='hello', id=2, age=23}
10> Student{name='world', id=3, age=38}
6> Student{name='flink', id=1, age=22}

Process finished with exit code 0

Note that when union merges multiple streams, they must all have the same data type. See the sketch below for merging more than two streams at once.
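As a quick hedged illustration of merging more than two streams of the same type (the streams and job name here are made up):

import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class UnionSketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        // union accepts several streams at once; they must all carry the same type
        DataStreamSource<String> s1 = env.fromElements("a", "b");
        DataStreamSource<String> s2 = env.fromElements("c", "d");
        DataStreamSource<String> s3 = env.fromElements("e", "f");
        s1.union(s2, s3).print();
        env.execute("UnionSketch");
    }
}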

CoMap

CoMap and CoFlatMap are the connected-stream counterparts of map and flatMap on a DataStream: each of the two connected sources is handled by its own mapping method. For example, one source is the MySQL source above and the other is a socket:

public static void test_coMap(StreamExecutionEnvironment env){
    // String stream
    DataStreamSource<String> source1 = env.socketTextStream("192.168.31.86", 9527);
    // Student stream
    DataStreamSource<Student> source2 = env.addSource(new StudentSource());
    // Connect the two streams
    ConnectedStreams<String, Student> connect = source1.connect(source2);
    // Type of source1, type of source2, type of the result
    connect.map(new CoMapFunction<String, Student, String>() {
        // Business logic for the first stream
        @Override
        public String map1(String value) throws Exception {
            return value + "-CoMap";
        }
        // Business logic for the second stream
        @Override
        public String map2(Student value) throws Exception {
            return value.getName();
        }
    }).print();
}

// --------------------
4> world
2> flink
3> hello
2> hello-CoMap
3> 999 888-CoMap

CoFlatMap

CoFlatMap works the same way but with flatMap semantics: flatMap1 and flatMap2 each handle one side of the connected stream and may emit any number of elements:

public static void test_coFlatMap(StreamExecutionEnvironment env){
    DataStreamSource<String> source1 = env.fromElements("a b c", "d e f");
    DataStreamSource<String> source2 = env.fromElements("1,2,3", "4,5,6");
    ConnectedStreams<String, String> connect = source1.connect(source2);
    connect.flatMap(new CoFlatMapFunction<String, String, String>() {
        @Override
        public void flatMap1(String value, Collector<String> out) throws Exception {
            String[] split = value.split(" ");
            for(String s: split) {
                out.collect(s);
            }
        }

        @Override
        public void flatMap2(String value, Collector<String> out) throws Exception {
            String[] split = value.split(",");
            for(String s: split) {
                out.collect(s);
            }
        }
    }).print();
}

// ----------------------
8> d
6> 4
7> a
6> 5
8> e
5> 1
8> f
6> 6
7> b
5> 2
7> c
5> 3

Process finished with exit code 0

Custom Partitioner

The AccessLog objects produced by the AccessLogSourceV2 source carry one of the three domains "cn.tim", "tim.com", and "pk.com". Suppose we now need to partition these records by domain:

public class AccessLogSourceV2 implements ParallelSourceFunction<AccessLog> {
    boolean running = true;

    @Override
    public void run(SourceContext<AccessLog> ctx) throws Exception {
        String[] domains = {"cn.tim", "tim.com", "pk.com"};
        Random random = new Random();
        while (running) {
            for (int i = 0; i < 10; i++) {
                AccessLog accessLog = new AccessLog();
                accessLog.setTime(1234567L);
                accessLog.setDomain(domains[random.nextInt(domains.length)]);
                accessLog.setTraffic(random.nextDouble() + 1_000);
                ctx.collect(accessLog);
            }

            Thread.sleep(5_000);
        }
    }

    @Override
    public void cancel() {
        running = false;
    }
}

First, define the partitioner, PkPartitioner.java:

import org.apache.flink.api.common.functions.Partitioner;

public class PkPartitioner implements Partitioner<String> {
    // Split the records into three partitions based on the key
    @Override
    public int partition(String key, int numPartitions) {
        System.out.println("numPartitions = " + numPartitions);
        if("cn.tim".equals(key)){
            return 0;
        }else if("tim.com".equals(key)){
            return 1;
        }else {
            return 2;
        }
    }
}

Use the partitioner:

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(3);
DataStreamSource<AccessLog> source = env.addSource(new AccessLogSourceV2());

source.map(new MapFunction<AccessLog, Tuple2<String, AccessLog>>() {
    @Override
    public Tuple2<String, AccessLog> map(AccessLog value) throws Exception {
        return Tuple2.of(value.getDomain(), value);
    }
})
    .partitionCustom(new PkPartitioner(), 0)
    .map(new MapFunction<Tuple2<String, AccessLog>, AccessLog>() {
        @Override
        public AccessLog map(Tuple2<String, AccessLog> value) throws Exception {
            System.out.println("current thread id is:" + Thread.currentThread().getId() + ", value is " + value.f1);
            return value.f1;
        }
    }).print();
env.execute("PartitionerApp");

A custom partitioner lets us route records to specific subtasks based on some attribute of the data.

Custom Sink

Besides the Sinks that ship with Flink, we can also implement our own Sink to write data out, in this case back to MySQL. Sticking with the domain-traffic statistics, what we ultimately want to store is the total traffic for each domain:

-- ----------------------------
-- Table structure for pk_traffic
-- ----------------------------
DROP TABLE IF EXISTS `pk_traffic`;
CREATE TABLE `pk_traffic` (
  `domain` varchar(255) NOT NULL,
  `traffic` double NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8;

This gives us two cases: every time a new result is computed we have to check whether the domain already exists in the table; if it does, we update the row, otherwise we insert a new one.

PkMySQLSink.java is our custom Sink for writing the data back to MySQL:

/**
 * domain - traffic
 */
public class PkMySQLSink extends RichSinkFunction<Tuple2<String, Double>> {
    Connection connection;

    PreparedStatement insertPst;
    PreparedStatement updatePst;

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);

        connection = MySQLUtils.getConnection();
        if(connection == null) throw new RuntimeException("MySQL link failed!");
        insertPst = connection.prepareStatement("insert into pk_traffic(domain, traffic) values (?,?)");
        updatePst = connection.prepareStatement("update pk_traffic set traffic = ? where domain = ?");
    }

    @Override
    public void close() throws Exception {
        super.close();
        if(insertPst != null) insertPst.close();
        if(updatePst != null) updatePst.close();
        if(connection != null) connection.close();
    }

    @Override
    public void invoke(Tuple2<String, Double> value, Context context) throws Exception {
        System.out.println("===========invoke==========" + value.f0 + "--->" + value.f1);
        updatePst.setDouble(1, value.f1);
        updatePst.setString(2, value.f0);
        updatePst.execute();

        if(updatePst.getUpdateCount() == 0){
            insertPst.setString(1, value.f0);
            insertPst.setDouble(2, value.f1);
            insertPst.execute();
        }
    }
}

Use the Sink to write the results back to MySQL:

public static void toMySQL(StreamExecutionEnvironment env){
    DataStreamSource<String> source = env.readTextFile("data/access.log");
    SingleOutputStreamOperator<AccessLog> map = source.map((MapFunction<String, AccessLog>) s -> {
        String[] split = s.trim().split(",");
        if (split.length < 3) return null;
        Long time = Long.parseLong(split[0]);
        String domain = split[1];
        Double traffic = Double.parseDouble(split[2]);
        return new AccessLog(time, domain, traffic);
    });

    SingleOutputStreamOperator<AccessLog> traffic = map
            .keyBy((KeySelector<AccessLog, String>) AccessLog::getDomain)
            .sum("traffic");
    traffic.print();

    // Write the results back to MySQL
    traffic.map(new MapFunction<AccessLog, Tuple2<String, Double>>() {
        @Override
        public Tuple2<String, Double> map(AccessLog value) throws Exception {
            return Tuple2.of(value.getDomain(), value.getTraffic());
        }
    }).addSink(new PkMySQLSink());
}

invoke is called for every record, so every change in the computed result gets written back to MySQL. With larger data volumes we certainly cannot write every single update to MySQL; that is where the Window mechanism comes in, sketched below.
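A minimal sketch of that idea, not from the original project (the class and method names are made up, and a 10-second tumbling processing-time window is assumed): the per-domain running totals are folded inside each window and only the latest total is written, so PkMySQLSink is invoked once per domain per window instead of once per record.

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;

public class WindowedMySQLSinkSketch {
    // Batch per-domain traffic into 10-second tumbling windows before writing.
    public static void sinkWithWindow(SingleOutputStreamOperator<AccessLog> traffic) {
        traffic
            .map(new MapFunction<AccessLog, Tuple2<String, Double>>() {
                @Override
                public Tuple2<String, Double> map(AccessLog value) {
                    return Tuple2.of(value.getDomain(), value.getTraffic());
                }
            })
            .keyBy(new KeySelector<Tuple2<String, Double>, String>() {
                @Override
                public String getKey(Tuple2<String, Double> value) {
                    return value.f0;
                }
            })
            .window(TumblingProcessingTimeWindows.of(Time.seconds(10)))
            .reduce(new ReduceFunction<Tuple2<String, Double>>() {
                @Override
                public Tuple2<String, Double> reduce(Tuple2<String, Double> a, Tuple2<String, Double> b) {
                    // The upstream sum() already emits running totals, so keep the latest one.
                    return b;
                }
            })
            .addSink(new PkMySQLSink());
    }
}

Event-time windows with watermarks would work the same way; processing time is used here only to keep the sketch short.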

Pay attention to how often and when these methods are called: open is the rich function's initialization method, while invoke runs once for every record.

Next, let's look at writing the data to Redis; it is largely the same as MySQL.

https://bahir.apache.org/docs/flink/current/flink-streaming-redis/

First, add the dependency to pom.xml:

<dependency>
    <groupId>org.apache.bahir</groupId>
    <artifactId>flink-connector-redis_2.11</artifactId>
    <version>1.0</version>
</dependency>

Implement PkRedisSink.java, the RedisMapper that tells the Redis sink how to write each record (here HSET into a hash named pk-traffic):

public class PkRedisSink implements RedisMapper<Tuple2<String, Double>> {
    @Override
    public RedisCommandDescription getCommandDescription() {
        return new RedisCommandDescription(RedisCommand.HSET, "pk-traffic");
    }

    @Override
    public String getKeyFromData(Tuple2<String, Double> data) {
        return data.f0;
    }

    @Override
    public String getValueFromData(Tuple2<String, Double> data) {
        return data.f1 + "";
    }
}

Its usage is the same as PkMySQLSink: just swap PkMySQLSink for the RedisSink:

......
//.addSink(new PkMySQLSink());
.addSink(new RedisSink<Tuple2<String, Double>>(conf, new PkRedisSink()));
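The conf variable above is a FlinkJedisPoolConfig from the Bahir connector linked earlier. As a hedged sketch of the full wiring (the class name, host, and port are placeholders):

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.connectors.redis.RedisSink;
import org.apache.flink.streaming.connectors.redis.common.config.FlinkJedisPoolConfig;

public class RedisSinkWiring {
    // Wire the Bahir RedisSink with a Jedis pool config; host and port are placeholders.
    public static void toRedis(DataStream<Tuple2<String, Double>> trafficByDomain) {
        FlinkJedisPoolConfig conf = new FlinkJedisPoolConfig.Builder()
                .setHost("127.0.0.1")
                .setPort(6379)
                .build();
        trafficByDomain.addSink(new RedisSink<Tuple2<String, Double>>(conf, new PkRedisSink()));
    }
}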