How to limit the number of threads created and make the main thread wait until any one thread finds the answer?

This is the code to find the first pair of numbers (other than 1) whose LCM and HCF sum to the given number.

import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

class PerfectPartition {
    static long gcd(long a, long b) {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    // method to return LCM of two numbers
    static long lcm(long a, long b) {
        return (a / gcd(a, b)) * b;
    }

    long[] getPartition(long n) {
        var ref = new Object() {
            long x;
            long y;
            long[] ret = null;
        };

        Thread mainThread = Thread.currentThread();
        ThreadGroup t = new ThreadGroup("InnerLoop");

        for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
            if (t.activeCount() < 256) {

                new Thread(t, () -> {
                    for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                        long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);
                        if (z == n) {
                            ref.ret = new long[]{ref.x, ref.y};

                            t.interrupt();
                            break;
                        }
                    }
                }, "Thread_" + ref.x).start();

                if (ref.ret != null) {
                    return ref.ret;
                }
            } else {
                ref.x--;
            }
        }//return new long[]{1, n - 2};

        return Objects.requireNonNullElseGet(ref.ret, () -> new long[]{1, n - 2});
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long[] partition = new PerfectPartition().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }
}

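I want to stop the code execution as soon as the first pair is found. Instead, the main thread just keeps running and prints 1 and n-1.
What could be an optimal solution to limit the number of threads (<256, as the range of n is 2 to the max of long)?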

Expected output (n=4): 2 2
Expected output (n=8): 4 4

You can use a thread pool. Something like:

ExecutorService executor = Executors.newFixedThreadPool(256);

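Then schedule tasks (or Runnables) into it.

Once you are done, stop adding tasks and shut down the thread pool (shutting it down also prevents new tasks from being added).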
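
As a rough sketch of that lifecycle: the class name PoolSketch, the method runAll, and the trivial counting task are illustrative assumptions, not part of the question's code. Tasks are handed to a fixed-size pool, the pool is shut down so it accepts nothing new, and the caller waits for the submitted work to drain.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PoolSketch {
    // Runs taskCount trivial jobs on at most poolSize threads
    // and returns how many of them completed.
    static int runAll(int poolSize, int taskCount) {
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        AtomicInteger completed = new AtomicInteger();
        for (int i = 0; i < taskCount; i++) {
            executor.execute(completed::incrementAndGet);
        }
        executor.shutdown(); // no new tasks accepted; queued tasks still run
        try {
            executor.awaitTermination(1, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
        }
        return completed.get();
    }
}
```

No matter how many tasks are queued, the pool never runs more than poolSize of them at once, which is the limit the question asks for.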

First, you missed calling "start" on the thread.

new Thread(t, () -> {
    ...
    ...
}, "Thread_" + ref.x).start();

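As for your question, to limit the number of threads you can use a thread pool, for example Executors.newFixedThreadPool(int nThreads).

To stop execution, you can have the main thread wait on a single-count CountDownLatch, count the latch down in a worker thread when a match is found, and shut down the thread pool once the main thread's wait on the latch completes.

As you asked, here is sample code that uses a thread pool and a CountDownLatch: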

import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class LcmHcmSum {

    static long gcd(long a, long b) {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    // method to return LCM of two numbers
    static long lcm(long a, long b) {
        return (a / gcd(a, b)) * b;
    }
    
    long[] getPartition(long n) {
        singleThreadJobSubmitter.execute(() -> {
            for (int x = 2; x < (n + 2) / 2; x++) {
                    submitjob(n, x);
                    if(numberPair != null) break;  // match found, exit the loop
            }
            try {
                jobsExecutor.shutdown();  // process the already submitted jobs
                jobsExecutor.awaitTermination(10, TimeUnit.SECONDS);  // wait for the completion of the jobs
                
                if(numberPair == null) {  // no match found, all jobs processed, nothing more to do, count down the latch 
                    latch.countDown();
                }
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        
        try {
            latch.await();
            singleThreadJobSubmitter.shutdownNow();
            jobsExecutor.shutdownNow();
            
        } catch (InterruptedException e1) {
            e1.printStackTrace();
        }
        return Objects.requireNonNullElseGet(numberPair, () -> new long[]{1, n - 2});
    }

    private Future<?> submitjob(long n, long x) {
        return jobsExecutor.submit(() -> {
            for (int y = 2; y < (n + 2) / 2; y++) {
                long z = lcm(x, y) + gcd(x, y);
                if (z == n) {
                    synchronized(LcmHcmSum.class) {  numberPair = new long[]{x, y}; }
                    latch.countDown();
                    break;
                }
            }
        });
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long[] partition = new LcmHcmSum().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }
    
    private static CountDownLatch latch = new CountDownLatch(1);
    private static ExecutorService jobsExecutor = Executors.newFixedThreadPool(4);
    private static volatile long[] numberPair = null;
    private static ExecutorService singleThreadJobSubmitter = Executors.newSingleThreadExecutor();      
    

}

What could be an optimal solution to limit the no. of threads (<256 as the range of n is 2 to max of long)?

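First, you should consider the hardware on which the code will run (e.g., the number of cores) and the type of algorithm you are parallelizing, i.e., is it CPU-bound, memory-bound, IO-bound, and so on.

Your code is CPU-bound, so performance-wise it typically does not pay off to have more threads running than there are cores available in the system. As always, profile as much as you can.

Second, you need to distribute the work among the threads in a way that justifies the parallelism. In your case: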

  for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
        if (t.activeCount() < 256) {

            new Thread(t, () -> {
                for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                    long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);
                    if (z == n) {
                        ref.ret = new long[]{ref.x, ref.y};

                        t.interrupt();
                        break;
                    }
                }
            }, "Thread_" + ref.x).start();

            if (ref.ret != null) {
                return ref.ret;
            }
        } else {
            ref.x--;
        }
    }//return new long[]{1, n - 2};

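You kind of did that, but IMO in a convoluted way. It is easier to parallelize the loop explicitly, i.e., split its iterations among the threads, and remove all the ThreadGroup-related logic.

Third, look out for race conditions, for example: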

var ref = new Object() {
    long x;
    long y;
    long[] ret = null;
};

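This object is shared among the threads and updated by them, which leads to race conditions. As we are about to see, you do not actually need such a shared object.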
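
To illustrate the point in isolation, here is a minimal sketch in which each thread captures its own final copy of the loop variable instead of reading a shared mutable field. The class CaptureSketch and the method seenValues are hypothetical names used only for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

class CaptureSketch {
    // Starts one thread per x in [2, limit) and records the x each thread saw.
    static long[] seenValues(int limit) {
        long[] seen = new long[limit];
        List<Thread> threads = new ArrayList<>();
        for (long x = 2; x < limit; x++) {
            final long myX = x; // private copy: later loop iterations cannot change it
            Thread t = new Thread(() -> seen[(int) myX] = myX); // each thread writes its own slot
            threads.add(t);
            t.start();
        }
        for (Thread t : threads) {
            try {
                t.join(); // join makes the thread's write visible to the caller
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return seen;
    }
}
```

With the shared ref.x from the question, a thread may start its inner loop after the outer loop has already advanced x, so it can test a different x than the one it was created for.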

So let us do this step by step:

First, find out the number of threads you should run the code with, i.e., the same number of threads as cores:

int cores = Runtime.getRuntime().availableProcessors();

Define the parallel work (this is one possible example of a loop distribution):

public void run() {
    for (int x = 2; x < (n + 2) / 2; x++) {
        for (int y = 2 + threadID; y < (n + 2) / 2; y += total_threads) {
            long z = lcm(x, y) + gcd(x, y);
            if (z == n) {
                // do something 
            }
        }
    }
}

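The code above splits the work to be done in parallel among the threads in a round-robin fashion: the thread with ID threadID handles y = 2 + threadID, then 2 + threadID + total_threads, and so on.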
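
To make the round-robin split concrete, the following small sketch lists which y values a given thread would visit. RoundRobinSketch and valuesFor are hypothetical helper names, and upperExclusive stands in for (n + 2) / 2.

```java
import java.util.ArrayList;
import java.util.List;

class RoundRobinSketch {
    // Returns the y values that thread threadID visits when the range
    // [2, upperExclusive) is dealt out round-robin to totalThreads threads.
    static List<Long> valuesFor(int threadID, int totalThreads, long upperExclusive) {
        List<Long> ys = new ArrayList<>();
        for (long y = 2 + threadID; y < upperExclusive; y += totalThreads) {
            ys.add(y);
        }
        return ys;
    }
}
```

For example, with 3 threads and an upper bound of 10, thread 0 gets 2, 5, 8; thread 1 gets 3, 6, 9; and thread 2 gets 4, 7. Every value is covered exactly once and no two threads share one.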

I want to stop the code execution as soon as the first pair is found.

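There are several ways of achieving this. I will provide the simplest one IMO, albeit not the most sophisticated. You can use a variable to signal to the threads that the result has already been found, for example: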

final AtomicBoolean found;

Every thread shares the same AtomicBoolean variable, so a change made in one of them is also visible to the others:

@Override
public void run() {
    for (int x = 2 ; !found.get() && x < (n + 2) / 2; x ++) {
        for (int y = 2 + threadID; y < (n + 2) / 2; y += total_threads)  {
            long z = lcm(x, y) + gcd(x, y);
            if (z == n) {
                synchronized (found) {
                    if(!found.get()) {
                        rest[0] = x;
                        rest[1] = y;
                        found.set(true);
                    }
                    return;
                }
            }
        }
    }
}
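
A possible variation on the same idea, sketched here as an assumption rather than as part of the answer: an AtomicReference can hold the result and act as the "found" flag at the same time, so compareAndSet replaces the synchronized block. The class FirstResultSketch and its method search are hypothetical names.

```java
import java.util.concurrent.atomic.AtomicReference;

class FirstResultSketch {
    static long gcd(long a, long b) {
        while (b != 0) { long t = a % b; a = b; b = t; }
        return a;
    }

    // Returns some pair {x, y} with lcm(x, y) + gcd(x, y) == n, or null if none exists
    // in the searched range. With several matches, which one wins is not deterministic.
    static long[] search(long n, int totalThreads) {
        AtomicReference<long[]> result = new AtomicReference<>();
        long limit = (n + 2) / 2;
        Thread[] threads = new Thread[totalThreads];
        for (int i = 0; i < totalThreads; i++) {
            final int threadID = i;
            threads[i] = new Thread(() -> {
                for (long x = 2; result.get() == null && x < limit; x++) {
                    for (long y = 2 + threadID; result.get() == null && y < limit; y += totalThreads) {
                        long g = gcd(x, y);
                        if ((x / g) * y + g == n) {
                            // Only the first thread to get here publishes its pair.
                            result.compareAndSet(null, new long[]{x, y});
                            return;
                        }
                    }
                }
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return result.get();
    }
}
```

The other threads notice the non-null reference on their next loop check and stop on their own, so no thread needs to be interrupted.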

Since you asked for a code snippet, here is a simple, non-bulletproof (and not properly tested) running code example:

import java.util.Scanner;
import java.util.concurrent.atomic.AtomicBoolean;

class ThreadWork implements Runnable{

    final long[] rest;
    final AtomicBoolean found;
    final int threadID;
    final int total_threads;
    final long n;

    ThreadWork(long[] rest, AtomicBoolean found, int threadID, int total_threads, long n) {
        this.rest = rest;
        this.found = found;
        this.threadID = threadID;
        this.total_threads = total_threads;
        this.n = n;
    }

    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    static long lcm(long a, long b, long gcd) {
        return (a / gcd) * b;
    }

    @Override
    public void run() {
        // long counters: n is a long, so an int counter could overflow before reaching the bound
        for (long x = 2; !found.get() && x < (n + 2) / 2; x ++) {
            for (long y = 2 + threadID; !found.get() && y < (n + 2) / 2; y += total_threads) {
                long result = gcd(x, y);
                long z = lcm(x, y, result) + result;
                if (z == n) {
                    synchronized (found) {
                        if(!found.get()) {
                            rest[0] = x;
                            rest[1] = y;
                            found.set(true);
                        }
                        return;
                    }
                }
            }
        }
    }
}

class PerfectPartition {

    public static void main(String[] args) throws InterruptedException {
        Scanner sc = new Scanner(System.in);
        final long n = sc.nextLong();
        final int total_threads = Runtime.getRuntime().availableProcessors();

        long[] rest = new long[2];
        AtomicBoolean found = new AtomicBoolean();

        double startTime = System.nanoTime();
        Thread[] threads = new Thread[total_threads];
        for(int i = 0; i < total_threads; i++){
            ThreadWork task = new ThreadWork(rest, found, i, total_threads, n);
            threads[i] = new Thread(task);
            threads[i].start();
        }

        for(int i = 0; i < total_threads; i++){
            threads[i].join();
        }

        double estimatedTime = System.nanoTime() - startTime;
        System.out.println(rest[0] + " " + rest[1]);


        double elapsedTimeInSecond = estimatedTime / 1_000_000_000;
        System.out.println(elapsedTimeInSecond + " seconds");
    }
}

Output:

4 -> 2 2
8 -> 4 4

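Take this code as inspiration to come up with the solution that best fits your requirements. Once you fully understand these basics, try to improve the approach with more sophisticated Java features such as Executors, Futures, and CountDownLatch.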
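
As one possible direction, and only a sketch under stated assumptions, ExecutorService.invokeAny returns the result of the first task that completes without throwing and cancels the rest, which matches "wait until any one thread finds the answer". The class InvokeAnySketch, the method search, and the choice to signal "nothing in my slice" with a NoSuchElementException are assumptions made for this illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class InvokeAnySketch {
    static long gcd(long a, long b) {
        while (b != 0) { long t = a % b; a = b; b = t; }
        return a;
    }

    // Returns a pair {x, y} with lcm(x, y) + gcd(x, y) == n, or null if no task finds one.
    static long[] search(long n) {
        int workers = Runtime.getRuntime().availableProcessors();
        long limit = (n + 2) / 2;
        List<Callable<long[]>> tasks = new ArrayList<>();
        for (int i = 0; i < workers; i++) {
            final int id = i;
            tasks.add(() -> {
                for (long x = 2; x < limit; x++) {
                    for (long y = 2 + id; y < limit; y += workers) {
                        // invokeAny cancels the losing tasks by interrupting them.
                        if (Thread.currentThread().isInterrupted()) {
                            throw new InterruptedException();
                        }
                        long g = gcd(x, y);
                        if ((x / g) * y + g == n) {
                            return new long[]{x, y};
                        }
                    }
                }
                // Failing tells invokeAny to keep waiting for another task.
                throw new NoSuchElementException("no pair in this slice");
            });
        }
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            return pool.invokeAny(tasks); // blocks until one task succeeds
        } catch (ExecutionException e) {
            return null; // every task failed, so no pair exists in the range
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        } finally {
            pool.shutdownNow();
        }
    }
}
```

The caller blocks inside invokeAny, so no latch or shared flag is needed; the trade-off is that each task has to check the interrupt status itself to stop promptly.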


NEW UPDATE: Sequential optimization

Looking at the gcd method:

  static long gcd(long a, long b) {
        return (a == 0)? b : gcd(b % a, a);
  }

the lcm method:

static long lcm(long a, long b) {
    return (a / gcd(a, b)) * b;
}

and how they are being used:

long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);

You can optimize your sequential code by not calling gcd(a, b) again inside the lcm method. So change the lcm method to:

static long lcm(long a, long b, long gcd) {
    return (a / gcd) * b;
}

and change

long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);

to

long result = gcd(ref.x, ref.y);
long z = lcm(ref.x, ref.y, result) + result;

The code I have provided in this answer already reflects these changes.
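
Put together, the optimized sequential step might look like the sketch below. The wrapper class GcdOnceSketch and the helper lcmPlusGcd are hypothetical names added for illustration; gcd and the three-argument lcm are the methods shown above.

```java
class GcdOnceSketch {
    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    // Takes the already-computed gcd instead of recomputing it.
    static long lcm(long a, long b, long gcd) {
        return (a / gcd) * b;
    }

    // Computes lcm(x, y) + gcd(x, y) with a single gcd call.
    static long lcmPlusGcd(long x, long y) {
        long result = gcd(x, y);
        return lcm(x, y, result) + result;
    }
}
```

For the expected outputs in the question, lcmPlusGcd(2, 2) is 4 and lcmPlusGcd(4, 4) is 8.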