Apache Flink - 将过滤器用作终止条件

apache flink - filter as termination condition

我为 k-means 的终止条件定义了一个过滤器。但是运行应用程序时，它总是只执行一次迭代。

我认为问题出在这里:

// Close the iteration: newCentroids is fed back into the loop, and the second argument
// (the joined pairs of new and loop centroids that MyFilter keeps) is the termination condition.
DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));

或者过滤器函数:

/**
 * Filter over pairs of centroids; a pair is kept only when both centroids
 * are equal.
 */
public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    /**
     * Decides whether a centroid pair is kept.
     *
     * @param tuple pair of centroids to compare
     * @return {@code true} when {@code tuple.f0} equals {@code tuple.f1}, {@code false} otherwise
     */
    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        final GeoTimeDataCenter first = tuple.f0;
        final GeoTimeDataCenter second = tuple.f1;
        return first.equals(second);
    }
}

此致, 保罗

我的完整代码在这里:

/**
 * Builds and executes the k-means Flink job.
 *
 * Reads settings from {@code /config.properties}, iterates over the centroids
 * for at most {@code maxiterations} rounds, assigns every point to one of the
 * final centroids, and writes both results below the configured output path.
 * Exceptions raised while loading the configuration, loading the centroids or
 * executing the job are printed, not rethrown.
 */
public void run() {   
    //load properties
    Properties pro = new Properties();
    FileSystem fs = null;
    try {
        pro.load(FlinkMain.class.getResourceAsStream("/config.properties"));
        fs = FileSystem.get(new URI(pro.getProperty("hdfs.namenode")),new org.apache.hadoop.conf.Configuration());
    } catch (Exception e) {
        // NOTE(review): the failure is only printed; fs stays null and is dereferenced below.
        e.printStackTrace();
    }

    int maxIteration = Integer.parseInt(pro.getProperty("maxiterations"));
    String outputPath = fs.getHomeDirectory()+pro.getProperty("flink.output");
    // set up execution environment
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // get input points
    DataSet<GeoTimeDataTupel> points = getPointDataSet(env);
    DataSet<GeoTimeDataCenter> centroids = null;
    try {
        centroids = getCentroidDataSet(env);
    } catch (Exception e1) {
        // NOTE(review): on failure centroids stays null and centroids.iterate(...) below is still called.
        e1.printStackTrace();
    }
    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<GeoTimeDataCenter> loop = centroids.iterate(maxIteration);
    DataSet<GeoTimeDataCenter> newCentroids = points
        // compute closest centroid for each point
        .map(new SelectNearestCenter(this.getBenchmarkCounter())).withBroadcastSet(loop, "centroids")
        // count and sum point coordinates for each centroid
        .groupBy(0).reduceGroup(new CentroidAccumulator())
        // compute new centroids from point counts and coordinate sums
        .map(new CentroidAverager(this.getBenchmarkCounter()));
    // feed new centroids back into next iteration with termination condition
    // NOTE(review): the termination set is the joined (new, loop) centroid pairs that MyFilter keeps,
    // i.e. only equal pairs - confirm this matches the intended stop condition.
    DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));
    DataSet<Tuple2<Integer, GeoTimeDataTupel>> clusteredPoints = points
        // assign points to final clusters
        .map(new SelectNearestCenter(-1)).withBroadcastSet(finalCentroids, "centroids");
    // emit result
    clusteredPoints.writeAsCsv(outputPath+"/points", "\n", " ");
    finalCentroids.writeAsText(outputPath+"/centers");//print();
    // execute program
    try {
        env.execute("KMeans Flink");
    } catch (Exception e) {
        // NOTE(review): a failed job execution is only printed, not propagated to the caller.
        e.printStackTrace();
    }
}

/**
 * Filter over pairs of centroids; a pair is kept only when both centroids
 * are equal.
 */
public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    /**
     * Decides whether a centroid pair is kept.
     *
     * @param tuple pair of centroids to compare
     * @return {@code true} when {@code tuple.f0} equals {@code tuple.f1}, {@code false} otherwise
     */
    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        final GeoTimeDataCenter first = tuple.f0;
        final GeoTimeDataCenter second = tuple.f1;
        return first.equals(second);
    }
}

我认为问题出在过滤器函数上（抛开您未发布的代码不谈）。Flink 的终止条件按以下方式工作：如果提供的终止 DataSet 为空，则满足终止条件，迭代结束。否则，只要尚未超过最大迭代次数，就开始下一次迭代。

Flink 的 filter 函数只保留那些 FilterFunction 返回 true 的元素。因此，按照 MyFilter 的实现，您只保留了迭代前后相同的质心。这意味着如果所有质心都发生了变化，您将得到一个空的 DataSet，迭代因此终止。这显然与应有的终止条件相反。终止条件应该是：只要质心还在变化，就继续执行 k-means。

您可以使用 coGroup 函数来实现这一点：如果某个新质心在上一次迭代的质心 DataSet 中没有匹配项，就发出该元素。这类似于左外连接，只不过丢弃了存在匹配的记录。

/**
 * Demo entry point: co-groups a "new" data set with an "old" one on all
 * fields, applies {@code FilterCoGroup}, and prints the resulting data set.
 *
 * @param args command-line arguments (unused)
 * @throws Exception if building or running the Flink program fails
 */
public static void main(String[] args) throws Exception {
    // obtain the execution environment
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Element> previous = env.fromElements(new Element(1, "test"), new Element(2, "test"), new Element(3, "foobar"));
    DataSet<Element> current = env.fromElements(new Element(1, "test"), new Element(3, "foobar"), new Element(4, "test"));

    // group both inputs by the complete element and let FilterCoGroup decide what to emit
    DataSet<Element> unmatched = current.coGroup(previous).where("*").equalTo("*").with(new FilterCoGroup());

    unmatched.print();
}

/**
 * CoGroup function that emits every element of the first ("new") input for
 * which the second ("old") input of the same group holds no equal element.
 */
public static class FilterCoGroup implements CoGroupFunction<Element, Element, Element> {

    /**
     * Emits the new elements that have no equal counterpart among the old elements.
     *
     * @param newElements elements of the group coming from the first input
     * @param oldElements elements of the group coming from the second input
     * @param collector   receives every new element without an equal old element
     * @throws Exception declared by the implemented interface method
     */
    @Override
    public void coGroup(
            Iterable<Element> newElements,
            Iterable<Element> oldElements,
            Collector<Element> collector) throws Exception {

        // Materialize the old side once so it can be searched for every new element.
        List<Element> persistedElements = new ArrayList<Element>();
        for (Element element : oldElements) {
            persistedElements.add(element);
        }

        for (Element newElement : newElements) {
            // contains() relies on Element.equals and stops at the first match,
            // instead of scanning the remaining old elements after a hit.
            if (!persistedElements.contains(newElement)) {
                collector.collect(newElement);
            }
        }
    }
}

/**
 * Value type made of an integer id and a name. It defines equality, hashing
 * and ordering over both fields, and writes/reads itself through the
 * {@code write}/{@code read} methods of the implemented {@code Key} interface.
 */
public static class Element implements Key {
    private int id;
    private String name;

    /**
     * Creates an element.
     *
     * @param id   numeric identifier
     * @param name name; equals, hashCode, compareTo and write dereference it
     */
    public Element(int id, String name) {
        this.id = id;
        this.name = name;
    }

    /**
     * Creates an element with id {@code -1} and an empty name.
     * NOTE(review): presumably needed so an instance can exist before
     * {@code read} fills it - confirm against the framework's requirements.
     */
    public Element() {
        this(-1, "");
    }

    /** Hash over both fields, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return 31 + 7 * name.hashCode() + 11 * id;
    }

    /**
     * @return {@code true} if {@code obj} is an {@code Element} with the same id and an equal name
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Element)) {
            return false;
        }
        Element other = (Element) obj;
        return id == other.id && name.equals(other.name);
    }

    /**
     * Orders elements by id first and by name second.
     *
     * @param o the object to compare with; must be an {@code Element}
     * @return a negative value, zero or a positive value if this element is
     *         less than, equal to or greater than {@code o}
     * @throws RuntimeException if {@code o} is not an {@code Element}
     */
    @Override
    public int compareTo(Object o) {
        if (!(o instanceof Element)) {
            throw new RuntimeException("Comparing incompatible types.");
        }
        Element other = (Element) o;

        if (id == other.id) {
            return name.compareTo(other.name);
        }
        // Compare explicitly rather than subtracting: "id - other.id" overflows for
        // operands of opposite sign with large magnitude and would flip the ordering.
        return id < other.id ? -1 : 1;
    }

    /** Writes the id as an int followed by the name as a UTF string. */
    @Override
    public void write(DataOutputView dataOutputView) throws IOException {
        dataOutputView.writeInt(id);
        dataOutputView.writeUTF(name);
    }

    /** Reads the fields in the order {@link #write} emits them: id first, then name. */
    @Override
    public void read(DataInputView dataInputView) throws IOException {
        id = dataInputView.readInt();
        name = dataInputView.readUTF();
    }

    /** @return the element formatted as {@code (id; name)} */
    @Override
    public String toString() {
        return "(" + id + "; " + name + ")";
    }
}