如何改进此算法以优化 运行 时间(分段查找点)
How can improve this algorithm to optimize the running time (find points in segments)
我得到了 2 个积分,第一个是段数 (Xi,Xj),第二个是可以或不能在这些段内的点数。
例如,输入可以是:
2 3
0 5
8 10
1 6 11
其中,在第一行中,2 表示“2 段”,3 表示“3 个点”。
这 2 个段是“0 到 5”和“8 到 10”,要查找的点是 1、6、11。
输出是
1 0 0
其中点 1 在段“0 到 5”中,点 6 和 11 不在任何段中。 如果一个点出现在多个段中,例如 3,则输出为 2。
原来的代码,只是一个双循环来搜索段之间的点。我使用了 Java 数组快速排序(经过修改,因此当它对段的端点进行排序时,也会对起点进行排序,因此 start[i] 和 end[i] 属于同一段 i)来提高双循环的速度,但它还不够。
下一个代码工作正常,但是当有太多段时它变得很慢:
public class PointsAndSegments {
private static int[] fastCountSegments(int[] starts, int[] ends, int[] points) {
sort(starts, ends);
int[] cnt2 = CountSegments(starts,ends,points);
return cnt2;
}
private static void dualPivotQuicksort(int[] a, int[] b, int left,int right, int div) {
int len = right - left;
if (len < 27) { // insertion sort for tiny array
for (int i = left + 1; i <= right; i++) {
for (int j = i; j > left && b[j] < b[j - 1]; j--) {
swap(a, b, j, j - 1);
}
}
return;
}
int third = len / div;
// "medians"
int m1 = left + third;
int m2 = right - third;
if (m1 <= left) {
m1 = left + 1;
}
if (m2 >= right) {
m2 = right - 1;
}
if (a[m1] < a[m2]) {
swap(a, b, m1, left);
swap(a, b, m2, right);
}
else {
swap(a, b, m1, right);
swap(a, b, m2, left);
}
// pivots
int pivot1 = b[left];
int pivot2 = b[right];
// pointers
int less = left + 1;
int great = right - 1;
// sorting
for (int k = less; k <= great; k++) {
if (b[k] < pivot1) {
swap(a, b, k, less++);
}
else if (b[k] > pivot2) {
while (k < great && b[great] > pivot2) {
great--;
}
swap(a, b, k, great--);
if (b[k] < pivot1) {
swap(a, b, k, less++);
}
}
}
// swaps
int dist = great - less;
if (dist < 13) {
div++;
}
swap(a, b, less - 1, left);
swap(a, b, great + 1, right);
// subarrays
dualPivotQuicksort(a, b, left, less - 2, div);
dualPivotQuicksort(a, b, great + 2, right, div);
// equal elements
if (dist > len - 13 && pivot1 != pivot2) {
for (int k = less; k <= great; k++) {
if (b[k] == pivot1) {
swap(a, b, k, less++);
}
else if (b[k] == pivot2) {
swap(a, b, k, great--);
if (b[k] == pivot1) {
swap(a, b, k, less++);
}
}
}
}
// subarray
if (pivot1 < pivot2) {
dualPivotQuicksort(a, b, less, great, div);
}
}
public static void sort(int[] a, int[] b) {
sort(a, b, 0, b.length);
}
public static void sort(int[] a, int[] b, int fromIndex, int toIndex) {
rangeCheck(a.length, fromIndex, toIndex);
dualPivotQuicksort(a, b, fromIndex, toIndex - 1, 3);
}
private static void rangeCheck(int length, int fromIndex, int toIndex) {
if (fromIndex > toIndex) {
throw new IllegalArgumentException("fromIndex > toIndex");
}
if (fromIndex < 0) {
throw new ArrayIndexOutOfBoundsException(fromIndex);
}
if (toIndex > length) {
throw new ArrayIndexOutOfBoundsException(toIndex);
}
}
private static void swap(int[] a, int[] b, int i, int j) {
int swap1 = a[i];
int swap2 = b[i];
a[i] = a[j];
b[i] = b[j];
a[j] = swap1;
b[j] = swap2;
}
private static int[] naiveCountSegments(int[] starts, int[] ends, int[] points) {
int[] cnt = new int[points.length];
for (int i = 0; i < points.length; i++) {
for (int j = 0; j < starts.length; j++) {
if (starts[j] <= points[i] && points[i] <= ends[j]) {
cnt[i]++;
}
}
}
return cnt;
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n, m;
n = scanner.nextInt();
m = scanner.nextInt();
int[] starts = new int[n];
int[] ends = new int[n];
int[] points = new int[m];
for (int i = 0; i < n; i++) {
starts[i] = scanner.nextInt();
ends[i] = scanner.nextInt();
}
for (int i = 0; i < m; i++) {
points[i] = scanner.nextInt();
}
//use fastCountSegments
int[] cnt = fastCountSegments(starts, ends, points);
for (int x : cnt) {
System.out.print(x + " ");
}
}
我认为问题出在 CountSegments() 方法中,但我不确定是否有其他方法可以解决它。据说,我应该使用分而治之的算法,但 4 天后,我想出了任何解决方案。
我找到了 a similar problem in CodeForces 但输出不同,而且大多数解决方案都是用 C++ 编写的。由于我只有 3 个月才开始学习 java,我认为我已经达到了知识的极限。
鉴于OP的约束,令n
为段数,m
为要查询的点数,其中n,m
<= 5*10^4 ,我可以想出一个O(nlg(n) + mlg(n))
的解决方案(这应该足以通过大多数在线判断)
由于每个查询都是一个验证问题:点是否可以被一些区间覆盖,是或否,我们不需要找出该点被覆盖了哪个/多少个间隔。
算法概要:
- 首先按起点对所有区间进行排序,如果相等则按长度(最右边的终点)排序
- 尝试合并间隔以获得一些 不相交 重叠间隔。例如(0,5), (2,9), (3,7), (3,5), (12,15) ,你会得到 (0,9), (12,15)。由于间隔已排序,因此可以在
O(n)
中贪婪地完成此操作
- 以上是预计算,现在对于每个点,我们使用不相交的区间进行查询。如果任何区间包含这样的点,则简单地进行二进制搜索,每个查询是
O(lg(n))
,我们得到 m
个点,所以总共 O(m lg(n))
结合整个算法,我们将得到一个O(nlg(n) + mlg(n))
算法
这是一个类似于@Shole的想法的实现:
public class SegmentsAlgorithm {
private PriorityQueue<int[]> remainSegments = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[0], o1[0]));
private SegmentWeight[] arraySegments;
public void addSegment(int begin, int end) {
remainSegments.add(new int[]{begin, end});
}
public void prepareArrayCache() {
List<SegmentWeight> preCalculate = new ArrayList<>();
PriorityQueue<int[]> currentSegmentsByEnds = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[1], o1[1]));
int begin = remainSegments.peek()[0];
while (!remainSegments.isEmpty() && remainSegments.peek()[0] == begin) {
currentSegmentsByEnds.add(remainSegments.poll());
}
preCalculate.add(new SegmentWeight(begin, currentSegmentsByEnds.size()));
int next;
while (!remainSegments.isEmpty()) {
if (currentSegmentsByEnds.isEmpty()) {
next = remainSegments.peek()[0];
} else {
next = Math.min(currentSegmentsByEnds.peek()[1], remainSegments.peek()[0]);
}
while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
currentSegmentsByEnds.poll();
}
while (!remainSegments.isEmpty() && remainSegments.peek()[0] == next) {
currentSegmentsByEnds.add(remainSegments.poll());
}
preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
}
while (!currentSegmentsByEnds.isEmpty()) {
next = currentSegmentsByEnds.peek()[1];
while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
currentSegmentsByEnds.poll();
}
preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
}
SegmentWeight[] arraySearch = new SegmentWeight[preCalculate.size()];
int i = 0;
for (SegmentWeight l : preCalculate) {
arraySearch[i++] = l;
}
this.arraySegments = arraySearch;
}
public int searchPoint(int p) {
int result = 0;
if (arraySegments != null && arraySegments.length > 0 && arraySegments[0].begin <= p) {
int index = Arrays.binarySearch(arraySegments, new SegmentWeight(p, 0), (o0, o1) -> Integer.compare(o0.begin, o1.begin));
if (index < 0){ // Bug fixed
index = - 2 - index;
}
if (index >= 0 && index < arraySegments.length) { // Protection added
result = arraySegments[index].weight;
}
}
return result;
}
public static void main(String[] args) {
SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
int[][] segments = {{0, 5},{3, 10},{8, 9},{14, 20},{12, 28}};
for (int[] segment : segments) {
algorithm.addSegment(segment[0], segment[1]);
}
algorithm.prepareArrayCache();
int[] points = {-1, 2, 4, 6, 11, 28};
for (int point: points) {
System.out.println(point + ": " + algorithm.searchPoint(point));
}
}
public static class SegmentWeight {
int begin;
int weight;
public SegmentWeight(int begin, int weight) {
this.begin = begin;
this.weight = weight;
}
}
}
它打印:
-1: 0
2: 1
4: 2
6: 1
11: 2
28: 0
已编辑:
public static void main(String[] args) {
SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int m = scanner.nextInt();
for (int i = 0; i < n; i++) {
algorithm.addSegment(scanner.nextInt(), scanner.nextInt());
}
algorithm.prepareArrayCache();
for (int i = 0; i < m; i++) {
System.out.print(algorithm.searchPoint(scanner.nextInt())+ " ");
}
System.out.println();
}
我得到了 2 个积分,第一个是段数 (Xi,Xj),第二个是可以或不能在这些段内的点数。
例如,输入可以是:
2 3
0 5
8 10
1 6 11
其中,在第一行中,2 表示“2 段”,3 表示“3 个点”。 这 2 个段是“0 到 5”和“8 到 10”,要查找的点是 1、6、11。 输出是
1 0 0
其中点 1 在段“0 到 5”中,点 6 和 11 不在任何段中。 如果一个点出现在多个段中,例如 3,则输出为 2。
原来的代码,只是一个双循环来搜索段之间的点。我使用了 Java 数组快速排序(经过修改,因此当它对段的端点进行排序时,也会对起点进行排序,因此 start[i] 和 end[i] 属于同一段 i)来提高双循环的速度,但它还不够。
下一个代码工作正常,但是当有太多段时它变得很慢:
public class PointsAndSegments {
private static int[] fastCountSegments(int[] starts, int[] ends, int[] points) {
sort(starts, ends);
int[] cnt2 = CountSegments(starts,ends,points);
return cnt2;
}
private static void dualPivotQuicksort(int[] a, int[] b, int left,int right, int div) {
int len = right - left;
if (len < 27) { // insertion sort for tiny array
for (int i = left + 1; i <= right; i++) {
for (int j = i; j > left && b[j] < b[j - 1]; j--) {
swap(a, b, j, j - 1);
}
}
return;
}
int third = len / div;
// "medians"
int m1 = left + third;
int m2 = right - third;
if (m1 <= left) {
m1 = left + 1;
}
if (m2 >= right) {
m2 = right - 1;
}
if (a[m1] < a[m2]) {
swap(a, b, m1, left);
swap(a, b, m2, right);
}
else {
swap(a, b, m1, right);
swap(a, b, m2, left);
}
// pivots
int pivot1 = b[left];
int pivot2 = b[right];
// pointers
int less = left + 1;
int great = right - 1;
// sorting
for (int k = less; k <= great; k++) {
if (b[k] < pivot1) {
swap(a, b, k, less++);
}
else if (b[k] > pivot2) {
while (k < great && b[great] > pivot2) {
great--;
}
swap(a, b, k, great--);
if (b[k] < pivot1) {
swap(a, b, k, less++);
}
}
}
// swaps
int dist = great - less;
if (dist < 13) {
div++;
}
swap(a, b, less - 1, left);
swap(a, b, great + 1, right);
// subarrays
dualPivotQuicksort(a, b, left, less - 2, div);
dualPivotQuicksort(a, b, great + 2, right, div);
// equal elements
if (dist > len - 13 && pivot1 != pivot2) {
for (int k = less; k <= great; k++) {
if (b[k] == pivot1) {
swap(a, b, k, less++);
}
else if (b[k] == pivot2) {
swap(a, b, k, great--);
if (b[k] == pivot1) {
swap(a, b, k, less++);
}
}
}
}
// subarray
if (pivot1 < pivot2) {
dualPivotQuicksort(a, b, less, great, div);
}
}
public static void sort(int[] a, int[] b) {
sort(a, b, 0, b.length);
}
public static void sort(int[] a, int[] b, int fromIndex, int toIndex) {
rangeCheck(a.length, fromIndex, toIndex);
dualPivotQuicksort(a, b, fromIndex, toIndex - 1, 3);
}
private static void rangeCheck(int length, int fromIndex, int toIndex) {
if (fromIndex > toIndex) {
throw new IllegalArgumentException("fromIndex > toIndex");
}
if (fromIndex < 0) {
throw new ArrayIndexOutOfBoundsException(fromIndex);
}
if (toIndex > length) {
throw new ArrayIndexOutOfBoundsException(toIndex);
}
}
private static void swap(int[] a, int[] b, int i, int j) {
int swap1 = a[i];
int swap2 = b[i];
a[i] = a[j];
b[i] = b[j];
a[j] = swap1;
b[j] = swap2;
}
private static int[] naiveCountSegments(int[] starts, int[] ends, int[] points) {
int[] cnt = new int[points.length];
for (int i = 0; i < points.length; i++) {
for (int j = 0; j < starts.length; j++) {
if (starts[j] <= points[i] && points[i] <= ends[j]) {
cnt[i]++;
}
}
}
return cnt;
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n, m;
n = scanner.nextInt();
m = scanner.nextInt();
int[] starts = new int[n];
int[] ends = new int[n];
int[] points = new int[m];
for (int i = 0; i < n; i++) {
starts[i] = scanner.nextInt();
ends[i] = scanner.nextInt();
}
for (int i = 0; i < m; i++) {
points[i] = scanner.nextInt();
}
//use fastCountSegments
int[] cnt = fastCountSegments(starts, ends, points);
for (int x : cnt) {
System.out.print(x + " ");
}
}
我认为问题出在 CountSegments() 方法中,但我不确定是否有其他方法可以解决它。据说,我应该使用分而治之的算法,但 4 天后,我想出了任何解决方案。 我找到了 a similar problem in CodeForces 但输出不同,而且大多数解决方案都是用 C++ 编写的。由于我只有 3 个月才开始学习 java,我认为我已经达到了知识的极限。
鉴于OP的约束,令n
为段数,m
为要查询的点数,其中n,m
<= 5*10^4 ,我可以想出一个O(nlg(n) + mlg(n))
的解决方案(这应该足以通过大多数在线判断)
由于每个查询都是一个验证问题:点是否可以被一些区间覆盖,是或否,我们不需要找出该点被覆盖了哪个/多少个间隔。
算法概要:
- 首先按起点对所有区间进行排序,如果相等则按长度(最右边的终点)排序
- 尝试合并间隔以获得一些 不相交 重叠间隔。例如(0,5), (2,9), (3,7), (3,5), (12,15) ,你会得到 (0,9), (12,15)。由于间隔已排序,因此可以在
O(n)
中贪婪地完成此操作
- 以上是预计算,现在对于每个点,我们使用不相交的区间进行查询。如果任何区间包含这样的点,则简单地进行二进制搜索,每个查询是
O(lg(n))
,我们得到m
个点,所以总共O(m lg(n))
结合整个算法,我们将得到一个O(nlg(n) + mlg(n))
算法
这是一个类似于@Shole的想法的实现:
public class SegmentsAlgorithm {
private PriorityQueue<int[]> remainSegments = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[0], o1[0]));
private SegmentWeight[] arraySegments;
public void addSegment(int begin, int end) {
remainSegments.add(new int[]{begin, end});
}
public void prepareArrayCache() {
List<SegmentWeight> preCalculate = new ArrayList<>();
PriorityQueue<int[]> currentSegmentsByEnds = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[1], o1[1]));
int begin = remainSegments.peek()[0];
while (!remainSegments.isEmpty() && remainSegments.peek()[0] == begin) {
currentSegmentsByEnds.add(remainSegments.poll());
}
preCalculate.add(new SegmentWeight(begin, currentSegmentsByEnds.size()));
int next;
while (!remainSegments.isEmpty()) {
if (currentSegmentsByEnds.isEmpty()) {
next = remainSegments.peek()[0];
} else {
next = Math.min(currentSegmentsByEnds.peek()[1], remainSegments.peek()[0]);
}
while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
currentSegmentsByEnds.poll();
}
while (!remainSegments.isEmpty() && remainSegments.peek()[0] == next) {
currentSegmentsByEnds.add(remainSegments.poll());
}
preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
}
while (!currentSegmentsByEnds.isEmpty()) {
next = currentSegmentsByEnds.peek()[1];
while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
currentSegmentsByEnds.poll();
}
preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
}
SegmentWeight[] arraySearch = new SegmentWeight[preCalculate.size()];
int i = 0;
for (SegmentWeight l : preCalculate) {
arraySearch[i++] = l;
}
this.arraySegments = arraySearch;
}
public int searchPoint(int p) {
int result = 0;
if (arraySegments != null && arraySegments.length > 0 && arraySegments[0].begin <= p) {
int index = Arrays.binarySearch(arraySegments, new SegmentWeight(p, 0), (o0, o1) -> Integer.compare(o0.begin, o1.begin));
if (index < 0){ // Bug fixed
index = - 2 - index;
}
if (index >= 0 && index < arraySegments.length) { // Protection added
result = arraySegments[index].weight;
}
}
return result;
}
public static void main(String[] args) {
SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
int[][] segments = {{0, 5},{3, 10},{8, 9},{14, 20},{12, 28}};
for (int[] segment : segments) {
algorithm.addSegment(segment[0], segment[1]);
}
algorithm.prepareArrayCache();
int[] points = {-1, 2, 4, 6, 11, 28};
for (int point: points) {
System.out.println(point + ": " + algorithm.searchPoint(point));
}
}
public static class SegmentWeight {
int begin;
int weight;
public SegmentWeight(int begin, int weight) {
this.begin = begin;
this.weight = weight;
}
}
}
它打印:
-1: 0
2: 1
4: 2
6: 1
11: 2
28: 0
已编辑:
public static void main(String[] args) {
SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int m = scanner.nextInt();
for (int i = 0; i < n; i++) {
algorithm.addSegment(scanner.nextInt(), scanner.nextInt());
}
algorithm.prepareArrayCache();
for (int i = 0; i < m; i++) {
System.out.print(algorithm.searchPoint(scanner.nextInt())+ " ");
}
System.out.println();
}