使用 RxJava 等待完成多个任务
Wait for completion of multiple tasks using RxJava
我有这个 class(非 Rx),它同时启动 1000 个生产者线程和 1000 个消费者线程,然后等待它们通过阻塞队列的简单实现交换预定义数量的消息。此过程完成后,观察者收到结果通知:
public class ProducerConsumerBenchmarkUseCase extends BaseObservable<ProducerConsumerBenchmarkUseCase.Listener> {
public static interface Listener {
void onBenchmarkCompleted(Result result);
}
public static class Result {
private final long mExecutionTime;
private final int mNumOfReceivedMessages;
public Result(long executionTime, int numOfReceivedMessages) {
mExecutionTime = executionTime;
mNumOfReceivedMessages = numOfReceivedMessages;
}
public long getExecutionTime() {
return mExecutionTime;
}
public int getNumOfReceivedMessages() {
return mNumOfReceivedMessages;
}
}
private static final int NUM_OF_MESSAGES = 1000;
private static final int BLOCKING_QUEUE_CAPACITY = 5;
private final Object LOCK = new Object();
private final Handler mUiHandler = new Handler(Looper.getMainLooper());
private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);
private int mNumOfFinishedConsumers;
private int mNumOfReceivedMessages;
private long mStartTimestamp;
public void startBenchmarkAndNotify() {
synchronized (LOCK) {
mNumOfReceivedMessages = 0;
mNumOfFinishedConsumers = 0;
mStartTimestamp = System.currentTimeMillis();
}
// watcher-reporter thread
new Thread(() -> {
synchronized (LOCK) {
while (mNumOfFinishedConsumers < NUM_OF_MESSAGES) {
try {
LOCK.wait();
} catch (InterruptedException e) {
return;
}
}
}
notifySuccess();
}).start();
// producers init thread
new Thread(() -> {
for (int i = 0; i < NUM_OF_MESSAGES; i++) {
startNewProducer(i);
}
}).start();
// consumers init thread
new Thread(() -> {
for (int i = 0; i < NUM_OF_MESSAGES; i++) {
startNewConsumer();
}
}).start();
}
private void startNewProducer(final int index) {
new Thread(() -> mBlockingQueue.put(index)).start();
}
private void startNewConsumer() {
new Thread(() -> {
int message = mBlockingQueue.take();
synchronized (LOCK) {
if (message != -1) {
mNumOfReceivedMessages++;
}
mNumOfFinishedConsumers++;
LOCK.notifyAll();
}
}).start();
}
private void notifySuccess() {
mUiHandler.post(() -> {
Result result;
synchronized (LOCK) {
result =
new Result(
System.currentTimeMillis() - mStartTimestamp,
mNumOfReceivedMessages
);
}
for (Listener listener : getListeners()) {
listener.onBenchmarkCompleted(result);
}
});
}
}
现在我想把它重构为Rx。到目前为止,我设法做到了这一点:
public class ProducerConsumerBenchmarkUseCase {
public static class Result {
private final long mExecutionTime;
private final int mNumOfReceivedMessages;
public Result(long executionTime, int numOfReceivedMessages) {
mExecutionTime = executionTime;
mNumOfReceivedMessages = numOfReceivedMessages;
}
public long getExecutionTime() {
return mExecutionTime;
}
public int getNumOfReceivedMessages() {
return mNumOfReceivedMessages;
}
}
private static final int NUM_OF_MESSAGES = 1000;
private static final int BLOCKING_QUEUE_CAPACITY = 5;
private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);
private final AtomicInteger mNumOfFinishedConsumers = new AtomicInteger(0);
private final AtomicInteger mNumOfReceivedMessages = new AtomicInteger(0);
private volatile long mStartTimestamp;
public Maybe<Result> startBenchmark() {
return Maybe.fromCallable(new Callable<Result>() {
@Override
public Result call() {
mNumOfReceivedMessages.set(0);
mNumOfFinishedConsumers.set(0);
mStartTimestamp = System.currentTimeMillis();
Observable.range(0, NUM_OF_MESSAGES)
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.io())
.forEach(
index -> newProducer(index).subscribe()
);
Observable.range(0, NUM_OF_MESSAGES)
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.io())
.map(index -> newConsumer())
.doOnNext(completable -> completable.subscribe())
.flatMap(completable -> { return Observable.just(completable); })
.toList()
.blockingGet();
return new Result(
System.currentTimeMillis() - mStartTimestamp,
mNumOfReceivedMessages.get()
);
}
});
}
private Completable newProducer(final int index) {
return Completable
.fromAction(() -> mBlockingQueue.put(index))
.subscribeOn(Schedulers.io());
}
private Completable newConsumer() {
return Completable
.fromAction(() -> {
int message = mBlockingQueue.take();
if (message != -1) {
mNumOfReceivedMessages.incrementAndGet();
}
})
.subscribeOn(Schedulers.io());
}
}
此代码编译、运行甚至完成,但结果交换的消息数小于 1000,这意味着这里存在某种问题。
我做错了什么?
我不完全明白你为什么要进行这种类型的处理。此外,我认为不需要消息计数,因为您应该生成 1000 条消息。请注意,使用此类基准测试,您很可能会衡量系统创建 2000 个线程的速度。
public Maybe<Result> startBenchmark() {
return
Flowable.range(0, NUM_OF_MESSAGES)
.flatMap(id ->
Flowable.fromCallable(() -> id) // <-- generate message
.subscribeOn(Schedulers.io())
)
.parallel(NUM_OF_MESSAGES)
.runOn(Schedulers.io())
.doOnNext(msg -> { }) // <-- process message
.sequential()
.count()
.doOnSubscribe(s -> { mStartTimestamp = System.currentTimeMillis(); })
.map(cnt -> new Result(System.currentTimeMillis() - mStartTimestamp, cnt))
.toMaybe();
}
我有这个 class(非 Rx),它同时启动 1000 个生产者线程和 1000 个消费者线程,然后等待它们通过阻塞队列的简单实现交换预定义数量的消息。此过程完成后,观察者收到结果通知:
public class ProducerConsumerBenchmarkUseCase extends BaseObservable<ProducerConsumerBenchmarkUseCase.Listener> {
public static interface Listener {
void onBenchmarkCompleted(Result result);
}
public static class Result {
private final long mExecutionTime;
private final int mNumOfReceivedMessages;
public Result(long executionTime, int numOfReceivedMessages) {
mExecutionTime = executionTime;
mNumOfReceivedMessages = numOfReceivedMessages;
}
public long getExecutionTime() {
return mExecutionTime;
}
public int getNumOfReceivedMessages() {
return mNumOfReceivedMessages;
}
}
private static final int NUM_OF_MESSAGES = 1000;
private static final int BLOCKING_QUEUE_CAPACITY = 5;
private final Object LOCK = new Object();
private final Handler mUiHandler = new Handler(Looper.getMainLooper());
private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);
private int mNumOfFinishedConsumers;
private int mNumOfReceivedMessages;
private long mStartTimestamp;
public void startBenchmarkAndNotify() {
synchronized (LOCK) {
mNumOfReceivedMessages = 0;
mNumOfFinishedConsumers = 0;
mStartTimestamp = System.currentTimeMillis();
}
// watcher-reporter thread
new Thread(() -> {
synchronized (LOCK) {
while (mNumOfFinishedConsumers < NUM_OF_MESSAGES) {
try {
LOCK.wait();
} catch (InterruptedException e) {
return;
}
}
}
notifySuccess();
}).start();
// producers init thread
new Thread(() -> {
for (int i = 0; i < NUM_OF_MESSAGES; i++) {
startNewProducer(i);
}
}).start();
// consumers init thread
new Thread(() -> {
for (int i = 0; i < NUM_OF_MESSAGES; i++) {
startNewConsumer();
}
}).start();
}
private void startNewProducer(final int index) {
new Thread(() -> mBlockingQueue.put(index)).start();
}
private void startNewConsumer() {
new Thread(() -> {
int message = mBlockingQueue.take();
synchronized (LOCK) {
if (message != -1) {
mNumOfReceivedMessages++;
}
mNumOfFinishedConsumers++;
LOCK.notifyAll();
}
}).start();
}
private void notifySuccess() {
mUiHandler.post(() -> {
Result result;
synchronized (LOCK) {
result =
new Result(
System.currentTimeMillis() - mStartTimestamp,
mNumOfReceivedMessages
);
}
for (Listener listener : getListeners()) {
listener.onBenchmarkCompleted(result);
}
});
}
}
现在我想把它重构为Rx。到目前为止,我设法做到了这一点:
public class ProducerConsumerBenchmarkUseCase {
public static class Result {
private final long mExecutionTime;
private final int mNumOfReceivedMessages;
public Result(long executionTime, int numOfReceivedMessages) {
mExecutionTime = executionTime;
mNumOfReceivedMessages = numOfReceivedMessages;
}
public long getExecutionTime() {
return mExecutionTime;
}
public int getNumOfReceivedMessages() {
return mNumOfReceivedMessages;
}
}
private static final int NUM_OF_MESSAGES = 1000;
private static final int BLOCKING_QUEUE_CAPACITY = 5;
private final MyBlockingQueue mBlockingQueue = new MyBlockingQueue(BLOCKING_QUEUE_CAPACITY);
private final AtomicInteger mNumOfFinishedConsumers = new AtomicInteger(0);
private final AtomicInteger mNumOfReceivedMessages = new AtomicInteger(0);
private volatile long mStartTimestamp;
public Maybe<Result> startBenchmark() {
return Maybe.fromCallable(new Callable<Result>() {
@Override
public Result call() {
mNumOfReceivedMessages.set(0);
mNumOfFinishedConsumers.set(0);
mStartTimestamp = System.currentTimeMillis();
Observable.range(0, NUM_OF_MESSAGES)
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.io())
.forEach(
index -> newProducer(index).subscribe()
);
Observable.range(0, NUM_OF_MESSAGES)
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.io())
.map(index -> newConsumer())
.doOnNext(completable -> completable.subscribe())
.flatMap(completable -> { return Observable.just(completable); })
.toList()
.blockingGet();
return new Result(
System.currentTimeMillis() - mStartTimestamp,
mNumOfReceivedMessages.get()
);
}
});
}
private Completable newProducer(final int index) {
return Completable
.fromAction(() -> mBlockingQueue.put(index))
.subscribeOn(Schedulers.io());
}
private Completable newConsumer() {
return Completable
.fromAction(() -> {
int message = mBlockingQueue.take();
if (message != -1) {
mNumOfReceivedMessages.incrementAndGet();
}
})
.subscribeOn(Schedulers.io());
}
}
此代码编译、运行甚至完成,但结果交换的消息数小于 1000,这意味着这里存在某种问题。
我做错了什么?
我不完全明白你为什么要进行这种类型的处理。此外,我认为不需要消息计数,因为您应该生成 1000 条消息。请注意,使用此类基准测试,您很可能会衡量系统创建 2000 个线程的速度。
public Maybe<Result> startBenchmark() {
return
Flowable.range(0, NUM_OF_MESSAGES)
.flatMap(id ->
Flowable.fromCallable(() -> id) // <-- generate message
.subscribeOn(Schedulers.io())
)
.parallel(NUM_OF_MESSAGES)
.runOn(Schedulers.io())
.doOnNext(msg -> { }) // <-- process message
.sequential()
.count()
.doOnSubscribe(s -> { mStartTimestamp = System.currentTimeMillis(); })
.map(cnt -> new Result(System.currentTimeMillis() - mStartTimestamp, cnt))
.toMaybe();
}