使用 RxJava 等待完成多个任务

Wait for completion of multiple tasks using RxJava

我有下面这个类(非 Rx 实现),它同时启动 1000 个生产者线程和 1000 个消费者线程,然后等待它们通过一个简单的阻塞队列实现交换预定义数量的消息。该过程完成后,会把结果通知给观察者:

/**
 * Thread-based benchmark: starts {@code NUM_OF_MESSAGES} producer threads and the same
 * number of consumer threads, lets them exchange messages through a bounded
 * {@code MyBlockingQueue}, and reports the outcome to registered listeners once every
 * consumer has finished.
 */
public class ProducerConsumerBenchmarkUseCase extends BaseObservable<ProducerConsumerBenchmarkUseCase.Listener> {

    /** Callback through which the benchmark outcome is delivered. */
    public interface Listener {
        void onBenchmarkCompleted(Result result);
    }

    /** Immutable outcome of one benchmark run. */
    public static class Result {
        private final long mExecutionTime;
        private final int mNumOfReceivedMessages;

        /**
         * @param executionTime         elapsed time of the run, in milliseconds
         * @param numOfReceivedMessages number of messages the consumers counted as received
         */
        public Result(long executionTime, int numOfReceivedMessages) {
            mExecutionTime = executionTime;
            mNumOfReceivedMessages = numOfReceivedMessages;
        }

        /** @return elapsed time of the run, in milliseconds */
        public long getExecutionTime() {
            return mExecutionTime;
        }

        /** @return number of messages the consumers counted as received */
        public int getNumOfReceivedMessages() {
            return mNumOfReceivedMessages;
        }
    }

    private static final int NUM_OF_MESSAGES = 1000;
    private static final int BLOCKING_QUEUE_CAPACITY = 5;

    /** Guards the counters and the start timestamp; also used for wait/notify hand-off. */
    private final Object LOCK = new Object();

    /** Handler bound to the main looper; listener callbacks are posted through it. */
    private final Handler mUiHandler = new Handler(Looper.getMainLooper());

    private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);

    private int mNumOfFinishedConsumers;

    private int mNumOfReceivedMessages;

    private long mStartTimestamp;


    /**
     * Resets the counters, records the start time and launches three helper threads:
     * a watcher that waits for all consumers and then notifies listeners, one that spawns
     * the producers, and one that spawns the consumers. Returns immediately.
     */
    public void startBenchmarkAndNotify() {

        synchronized (LOCK) {
            mNumOfReceivedMessages = 0;
            mNumOfFinishedConsumers = 0;
            mStartTimestamp = System.currentTimeMillis();
        }

        // watcher-reporter thread
        new Thread(() -> {
            synchronized (LOCK) {
                // Loop guards against spurious wake-ups and against notifications that
                // arrive before the last consumer has finished.
                while (mNumOfFinishedConsumers < NUM_OF_MESSAGES) {
                    try {
                        LOCK.wait();
                    } catch (InterruptedException e) {
                        // Restore the interrupt flag so the interruption is not silently
                        // lost, then abandon the run without notifying listeners.
                        Thread.currentThread().interrupt();
                        return;
                    }
                }
            }
            notifySuccess();
        }).start();

        // producers init thread
        new Thread(() -> {
            for (int i = 0; i < NUM_OF_MESSAGES; i++) {
                startNewProducer(i);
            }
        }).start();

        // consumers init thread
        new Thread(() -> {
            for (int i = 0; i < NUM_OF_MESSAGES; i++) {
                startNewConsumer();
            }
        }).start();
    }


    /** Starts a thread that puts {@code index} into the queue as a single message. */
    private void startNewProducer(final int index) {
        new Thread(() -> mBlockingQueue.put(index)).start();
    }

    /**
     * Starts a thread that takes one message from the queue, updates the counters under
     * {@code LOCK} and wakes the watcher thread.
     */
    private void startNewConsumer() {
        new Thread(() -> {
            int message = mBlockingQueue.take();
            synchronized (LOCK) {
                // NOTE(review): -1 is presumably MyBlockingQueue's "no message" sentinel — confirm.
                if (message != -1) {
                    mNumOfReceivedMessages++;
                }
                mNumOfFinishedConsumers++;
                LOCK.notifyAll();
            }
        }).start();
    }

    /** Builds the result and delivers it to every listener via {@code mUiHandler}. */
    private void notifySuccess() {
        mUiHandler.post(() -> {
            Result result;
            synchronized (LOCK) {
                 result =
                        new Result(
                                System.currentTimeMillis() - mStartTimestamp,
                                mNumOfReceivedMessages
                        );
            }
            for (Listener listener : getListeners()) {
                listener.onBenchmarkCompleted(result);
            }
        });
    }


}

现在我想把它重构为Rx。到目前为止,我设法做到了这一点:

/**
 * RxJava-based benchmark: starts {@code NUM_OF_MESSAGES} producers and the same number of
 * consumers on the io scheduler, lets them exchange messages through a bounded
 * {@code MyBlockingQueue}, and emits a {@link Result} once every consumer has completed.
 */
public class ProducerConsumerBenchmarkUseCase {

    /** Immutable outcome of one benchmark run. */
    public static class Result {
        private final long mExecutionTime;
        private final int mNumOfReceivedMessages;

        /**
         * @param executionTime         elapsed time of the run, in milliseconds
         * @param numOfReceivedMessages number of messages the consumers counted as received
         */
        public Result(long executionTime, int numOfReceivedMessages) {
            mExecutionTime = executionTime;
            mNumOfReceivedMessages = numOfReceivedMessages;
        }

        /** @return elapsed time of the run, in milliseconds */
        public long getExecutionTime() {
            return mExecutionTime;
        }

        /** @return number of messages the consumers counted as received */
        public int getNumOfReceivedMessages() {
            return mNumOfReceivedMessages;
        }
    }

    private static final int NUM_OF_MESSAGES = 1000;
    private static final int BLOCKING_QUEUE_CAPACITY = 5;

    private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);

    private final AtomicInteger mNumOfReceivedMessages = new AtomicInteger(0);

    private volatile long mStartTimestamp;


    /**
     * Creates a {@code Maybe} that, when subscribed, runs the whole benchmark on the
     * subscribing thread (blocking it until all consumers are done) and then emits the result.
     *
     * @return a Maybe emitting the elapsed time and the number of received messages
     */
    public Maybe<Result> startBenchmark() {

        return Maybe.fromCallable(() -> {

            mNumOfReceivedMessages.set(0);
            mStartTimestamp = System.currentTimeMillis();

            // Producers are fire-and-forget: completion of the run is defined by the
            // consumers, each of which needs exactly one produced message.
            Observable.range(0, NUM_OF_MESSAGES)
                    .subscribeOn(Schedulers.io())
                    .observeOn(Schedulers.io())
                    .forEach(
                            index -> newProducer(index).subscribe()
                    );

            // flatMapCompletable subscribes to every consumer (each on its own io
            // scheduler worker) and completes only after ALL of them have completed.
            // Merely subscribing in doOnNext and re-emitting the Completable objects
            // would finish as soon as the consumers were started, so the counter would
            // be read too early.
            Observable.range(0, NUM_OF_MESSAGES)
                    .flatMapCompletable(index -> newConsumer())
                    .blockingAwait();

            return new Result(
                    System.currentTimeMillis() - mStartTimestamp,
                    mNumOfReceivedMessages.get()
            );
        });

    }

    /** Returns a Completable that puts {@code index} into the queue on the io scheduler. */
    private Completable newProducer(final int index) {
        return Completable
                .fromAction(() -> mBlockingQueue.put(index))
                .subscribeOn(Schedulers.io());
    }

    /**
     * Returns a Completable that takes one message from the queue on the io scheduler
     * and counts it as received.
     */
    private Completable newConsumer() {
        return Completable
                .fromAction(() -> {
                    int message = mBlockingQueue.take();
                    // NOTE(review): -1 is presumably MyBlockingQueue's "no message" sentinel — confirm.
                    if (message != -1) {
                        mNumOfReceivedMessages.incrementAndGet();
                    }
                })
                .subscribeOn(Schedulers.io());
    }

}

这段代码可以编译、运行,甚至能够正常结束,但结果中交换的消息数小于 1000,这说明代码存在某种问题。

我做错了什么?

我不完全明白你为什么要做这种处理。此外,我认为没有必要统计消息数,因为你本来就会生成 1000 条消息。请注意,这类基准测试衡量的很可能只是系统创建 2000 个线程的速度。

/**
 * Generates {@code NUM_OF_MESSAGES} messages on the io scheduler, processes them on
 * parallel rails, counts them, and emits a {@code Result} with the elapsed time and count.
 *
 * @return a Maybe emitting the benchmark result
 */
public Maybe<Result> startBenchmark() {
    return Flowable.range(0, NUM_OF_MESSAGES)
            .flatMap(id ->
                    Flowable.fromCallable(() -> id) // <-- generate message
                            .subscribeOn(Schedulers.io())
            )
            .parallel(NUM_OF_MESSAGES)
            .runOn(Schedulers.io())
            .doOnNext(msg -> { })  // <-- process message
            .sequential()
            .count()
            .doOnSubscribe(s -> { mStartTimestamp = System.currentTimeMillis(); })
            // count() emits a Long, while Result takes an int message count; the value
            // never exceeds NUM_OF_MESSAGES, so narrowing is safe.
            .map(cnt -> new Result(System.currentTimeMillis() - mStartTimestamp, cnt.intValue()))
            .toMaybe();
}