如何使用线段树和二进制搜索解决加权 Activity 选择?
How to solve weighted Activity selection with use of Segment Trees and Binary search?
给定 N 个职位,其中每个职位都由以下三个元素表示。
1) 开始时间
2) 完成时间。
3) 利润或相关价值。
找到工作的最大利润子集,使得子集中没有两个工作重叠。
我知道一个复杂度为 O(N^2) 的动态规划解决方案(接近于 LIS,在 LIS 中我们只需要检查可以合并当前区间的先前元素,并采用合并后给出的区间最大值直到第 i 个元素)。这个解决方案可以进一步改进为 O(N*log N ) 使用二进制搜索和简单排序!
但是我的朋友告诉我,它甚至可以通过使用线段树和二分查找来解决!我不知道我将在哪里使用线段树以及如何使用。??
你能帮忙吗?
根据要求,抱歉没有评论
我正在做的是根据起始索引进行排序,通过合并之前的间隔和它们的最大可获得值,将最大可获得值存储到 DP[i] 中的 i !
void solve()
{
int n,i,j,k,high;
scanf("%d",&n);
pair < pair < int ,int>, int > arr[n+1];// first pair represents l,r and int alone shows cost
int dp[n+1];
memset(dp,0,sizeof(dp));
for(i=0;i<n;i++)
scanf("%d%d%d",&arr[i].first.first,&arr[i].first.second,&arr[i].second);
std::sort(arr,arr+n); // by default sorting on the basis of starting index
for(i=0;i<n;i++)
{
high=arr[i].second;
for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high , dp[j]+arr[i].second);
}
dp[i]=high;
}
for(i=0;i<n;i++)
dp[n-1]=std::max(dp[n-1],dp[i]);
printf("%d\n",dp[n-1]);
}
int main()
{solve();return 0;}
编辑:
我的工作代码最终花了我 3 个小时来调试它!此外,由于较大的常量和错误的实现,此代码比二进制搜索和排序代码慢:P(仅供参考)
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<cstring>
#include<iostream>
#include<climits>
#define lc(idx) (2*idx+1)
#define rc(idx) (2*idx+2)
#define mid(l,r) ((l+r)/2)
using namespace std;
int Tree[4*2*10000-1];
void update(int L,int R,int qe,int idx,int value)
{
if(value>Tree[0])
Tree[0]=value;
while(L<R)
{
if(qe<= mid(L,R))
{
idx=lc(idx);
R=mid(L,R);
}
else
{
idx=rc(idx);
L=mid(L,R)+1;
}
if(value>Tree[idx])
Tree[idx]=value;
}
return ;
}
int Get(int L,int R,int idx,int q)
{
if(q<L )
return 0;
if(R<=q)
return Tree[idx];
return max(Get(L,mid(L,R),lc(idx),q),Get(mid(L,R)+1,R,rc(idx),q));
}
bool cmp(pair < pair < int , int > , int > A,pair < pair < int , int > , int > B)
{
return A.first.second< B.first.second;
}
int main()
{
int N,i;
scanf("%d",&N);
pair < pair < int , int > , int > P[N];
vector < int > V;
for(i=0;i<N;i++)
{
scanf("%d%d%d",&P[i].first.first,&P[i].first.second,&P[i].second);
V.push_back(P[i].first.first);
V.push_back(P[i].first.second);
}
sort(V.begin(),V.end());
for(i=0;i<N;i++)
{
int &l=P[i].first.first,&r=P[i].first.second;
l=lower_bound(V.begin(),V.end(),l)-V.begin();
r=lower_bound(V.begin(),V.end(),r)-V.begin();
}
sort(P,P+N,cmp);
int ans=0;
memset(Tree,0,sizeof(Tree));
for(i=0;i<N;i++)
{
int aux=Get(0,2*N-1,0,P[i].first.first)+P[i].second;
if(aux>ans)
ans=aux;
update(0,2*N-1,P[i].first.second,0,ans);
}
printf("%d\n",ans);
return 0;
}
high=arr[i].second;
for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high, dp[j]+arr[i].second);
}
dp[i]=high;
这可以在 O(log n)
中用线段树完成。
首先,让我们重写一下。您取的最大值有点复杂,因为它取了涉及 i
和 j
的总和的最大值。但是i
这部分是常量,所以我们把它拿出来。
high=dp[0];
for(j=1;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high, dp[j]);
}
dp[i]=high + arr[i].second;
很好,现在我们已将问题简化为从满足您的 if
条件的值中确定 [0, i - 1]
中的最大值。
如果我们没有if
,那将是线段树的简单应用。
现在有两个选择。
1.为线段树处理 O(log V)
查询时间和 O(V)
内存
其中 V
是间隔端点的最大大小。
您可以构建一个线段树,在您移动 i
时将区间起点插入其中。然后查询值的范围。像这样,线段树被初始化为 -infinity
并且大小为 O(V)
.
Update(node, index, value):
if node.associated_interval == [index, index]:
node.max = value
return
if index in node.left.associated_interval:
Update(node.left, index, value)
else:
Update(node.right, index, value)
node.max = max(node.left.max, node.right.max)
Query(node, left, right):
if [left, right] does not intersect node.associated_interval:
return -infinity
if node.associated_interval included in [left, right]:
return node.max
return max(Query(node.left, left, right),
Query(node.right, left, right))
[...]
high=Query(tree, 0, arr[i].first.first)
dp[i]=high + arr[i].second;
Update(tree, arr[i].first.first, dp[i])
2。减少到 O(log n)
查询时间和 O(n)
段树的内存
由于间隔的数量可能明显少于它们的长度,因此有理由认为我们可以以某种方式更好地对它们进行编码,因此它们的长度也是 O(n)
。确实,我们可以。
这涉及在 [1, 2*n]
范围内标准化您的间隔。考虑以下区间
8 100
3 50
90 92
让我们把它们画成一条直线。它们看起来像这样:
3 8 50 90 92 100
现在用它们的索引替换它们中的每一个:
1 2 3 4 5 6
3 8 50 90 92 100
并写下你的新间隔:
2 6
1 3
4 5
请注意,它们保留了初始间隔的属性:相同的重叠,相同的相互包含等。
这可以通过排序来完成。您现在可以应用相同的线段树算法,除了您为大小 2*n
.
声明线段树
给定 N 个职位,其中每个职位都由以下三个元素表示。
1) 开始时间
2) 完成时间。
3) 利润或相关价值。
找到工作的最大利润子集,使得子集中没有两个工作重叠。
我知道一个复杂度为 O(N^2) 的动态规划解决方案(接近于 LIS,在 LIS 中我们只需要检查可以合并当前区间的先前元素,并采用合并后给出的区间最大值直到第 i 个元素)。这个解决方案可以进一步改进为 O(N*log N ) 使用二进制搜索和简单排序!
但是我的朋友告诉我,它甚至可以通过使用线段树和二分查找来解决!我不知道我将在哪里使用线段树以及如何使用。??
你能帮忙吗?
根据要求,抱歉没有评论
我正在做的是根据起始索引进行排序,通过合并之前的间隔和它们的最大可获得值,将最大可获得值存储到 DP[i] 中的 i !
void solve()
{
int n,i,j,k,high;
scanf("%d",&n);
pair < pair < int ,int>, int > arr[n+1];// first pair represents l,r and int alone shows cost
int dp[n+1];
memset(dp,0,sizeof(dp));
for(i=0;i<n;i++)
scanf("%d%d%d",&arr[i].first.first,&arr[i].first.second,&arr[i].second);
std::sort(arr,arr+n); // by default sorting on the basis of starting index
for(i=0;i<n;i++)
{
high=arr[i].second;
for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high , dp[j]+arr[i].second);
}
dp[i]=high;
}
for(i=0;i<n;i++)
dp[n-1]=std::max(dp[n-1],dp[i]);
printf("%d\n",dp[n-1]);
}
int main()
{solve();return 0;}
编辑: 我的工作代码最终花了我 3 个小时来调试它!此外,由于较大的常量和错误的实现,此代码比二进制搜索和排序代码慢:P(仅供参考)
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<cstring>
#include<iostream>
#include<climits>
#define lc(idx) (2*idx+1)
#define rc(idx) (2*idx+2)
#define mid(l,r) ((l+r)/2)
using namespace std;
int Tree[4*2*10000-1];
void update(int L,int R,int qe,int idx,int value)
{
if(value>Tree[0])
Tree[0]=value;
while(L<R)
{
if(qe<= mid(L,R))
{
idx=lc(idx);
R=mid(L,R);
}
else
{
idx=rc(idx);
L=mid(L,R)+1;
}
if(value>Tree[idx])
Tree[idx]=value;
}
return ;
}
int Get(int L,int R,int idx,int q)
{
if(q<L )
return 0;
if(R<=q)
return Tree[idx];
return max(Get(L,mid(L,R),lc(idx),q),Get(mid(L,R)+1,R,rc(idx),q));
}
bool cmp(pair < pair < int , int > , int > A,pair < pair < int , int > , int > B)
{
return A.first.second< B.first.second;
}
int main()
{
int N,i;
scanf("%d",&N);
pair < pair < int , int > , int > P[N];
vector < int > V;
for(i=0;i<N;i++)
{
scanf("%d%d%d",&P[i].first.first,&P[i].first.second,&P[i].second);
V.push_back(P[i].first.first);
V.push_back(P[i].first.second);
}
sort(V.begin(),V.end());
for(i=0;i<N;i++)
{
int &l=P[i].first.first,&r=P[i].first.second;
l=lower_bound(V.begin(),V.end(),l)-V.begin();
r=lower_bound(V.begin(),V.end(),r)-V.begin();
}
sort(P,P+N,cmp);
int ans=0;
memset(Tree,0,sizeof(Tree));
for(i=0;i<N;i++)
{
int aux=Get(0,2*N-1,0,P[i].first.first)+P[i].second;
if(aux>ans)
ans=aux;
update(0,2*N-1,P[i].first.second,0,ans);
}
printf("%d\n",ans);
return 0;
}
high=arr[i].second;
for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high, dp[j]+arr[i].second);
}
dp[i]=high;
这可以在 O(log n)
中用线段树完成。
首先,让我们重写一下。您取的最大值有点复杂,因为它取了涉及 i
和 j
的总和的最大值。但是i
这部分是常量,所以我们把它拿出来。
high=dp[0];
for(j=1;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
{
if(arr[i].first.first>=arr[j].first.second)
high=std::max(high, dp[j]);
}
dp[i]=high + arr[i].second;
很好,现在我们已将问题简化为从满足您的 if
条件的值中确定 [0, i - 1]
中的最大值。
如果我们没有if
,那将是线段树的简单应用。
现在有两个选择。
1.为线段树处理 O(log V)
查询时间和 O(V)
内存
其中 V
是间隔端点的最大大小。
您可以构建一个线段树,在您移动 i
时将区间起点插入其中。然后查询值的范围。像这样,线段树被初始化为 -infinity
并且大小为 O(V)
.
Update(node, index, value):
if node.associated_interval == [index, index]:
node.max = value
return
if index in node.left.associated_interval:
Update(node.left, index, value)
else:
Update(node.right, index, value)
node.max = max(node.left.max, node.right.max)
Query(node, left, right):
if [left, right] does not intersect node.associated_interval:
return -infinity
if node.associated_interval included in [left, right]:
return node.max
return max(Query(node.left, left, right),
Query(node.right, left, right))
[...]
high=Query(tree, 0, arr[i].first.first)
dp[i]=high + arr[i].second;
Update(tree, arr[i].first.first, dp[i])
2。减少到 O(log n)
查询时间和 O(n)
段树的内存
由于间隔的数量可能明显少于它们的长度,因此有理由认为我们可以以某种方式更好地对它们进行编码,因此它们的长度也是 O(n)
。确实,我们可以。
这涉及在 [1, 2*n]
范围内标准化您的间隔。考虑以下区间
8 100
3 50
90 92
让我们把它们画成一条直线。它们看起来像这样:
3 8 50 90 92 100
现在用它们的索引替换它们中的每一个:
1 2 3 4 5 6
3 8 50 90 92 100
并写下你的新间隔:
2 6
1 3
4 5
请注意,它们保留了初始间隔的属性:相同的重叠,相同的相互包含等。
这可以通过排序来完成。您现在可以应用相同的线段树算法,除了您为大小 2*n
.