优化:具有最大值的受限整数分区

Optimize: Restricted integer partioning with max value

用下面的代码,我统计每个分区中有k个数字的限制整数分区(每个数字在每个分区中只能出现一次),每个数字等于或大于1且不大于 m。此代码会生成大量缓存值,因此会很快耗尽内存。

示例:

sum := 15, k := 4, m:= 10 预期结果是 6

具有以下受限整数分区:

1,2,3,9,1,2,4,8,1,2,5,7,1,3,4,7,1,3,5,7,2,3,4,6

public class Key{
  private final int sum;
  private final short k1;
  private final short start;
  private final short end;

  public Key(int sum, short k1, short start, short end){
    this.sum = sum;
    this.k1 = k1;
    this.start = start;
    this.end = end;
  }
  // + hashcode and equals
}

public BigInteger calcRestrictedIntegerPartitions(int sum,short k,short m){
  return calcRestrictedIntegerPartitionsHelper(sum,(short)0,k,(short)1,m,new HashMap<>());
}

private BigInteger calcRestrictedIntegerPartitionsHelper(int sum, short k1, short k, short start, short end, Map<Key,BigInteger> cache){
  if(sum < 0){
    return BigInteger.ZERO;
  }
  if(k1 == k){
    if(sum ==0){
      return BigInteger.ONE;
    }
    return BigInteger.ZERO;
  }
  if(end*(k-k1) < sum){
    return BigInteger.ZERO;
  }

  final Key key = new Key(sum,(short)(k-k1),start,end);

  BigInteger fetched = cache.get(key);

  if(fetched == null){
    BigInteger tmp = BigInteger.ZERO;

    for(short i=start; i <= end;i++){
      tmp = tmp.add(calcRestrictedIntegerPartitionsHelper(sum-i,(short)(k1+1),k,(short)(i+1),end,cache));
    }

    cache.put(key, tmp);
    return tmp;
  }

  return fetched;
}

avoid/reduce缓存有公式吗?或者我如何用 k and m?

计算受限制的整数部分

您的密钥包含 4 个部分,因此散列 space 可能会达到这些部分最大值的乘积值。使用向后循环和零值作为自然限制可以将键减少到 3 个部分。

Python 示例使用 in-built 功能 lru_cache,哈希表大小 = N*K*M

@functools.lru_cache(250000)
def diff_partition(N, K, M):
    '''Counts integer partitions of N with K distint parts <= M'''
    if K == 0:
        if N == 0:
            return 1
        return 0
    res = 0
    for i in range(min(N, M), -1, -1):
        res += diff_partition(N - i, K - 1, i - 1)
    return res

def diffparts(Sum, K, M):   #diminish problem size allowing zero part
    return diff_partition(Sum - K, K, M-1)

print(diffparts(500, 25, 200))

>>>147151784574

你的问题可以转移,所以你只需要缓存中的 3 个键和更少的启动运行时间。更少的不同键意味着更好的缓存(比我聪明的人可能仍然会找到更便宜的解决方案)。

让我们将分区视为集合。每组的元素应排序(升序)。 当您将 sum := 15, k := 4, m:= 10 的预期结果声明为 [1, 2, 3, 9]; [1, 2, 4, 8] ....

时,您已经隐含地完成了此操作

您为分区定义的限制是:

  • 每组 k 个元素
  • max m 作为元素
  • 不同的值
  • non-zero 正整数

区分的限制其实有点麻烦,所以我们解除它。 为此,我们需要稍微改变一下问题。因为你的集合的元素是升序的(并且是不同的),我们知道,每个元素的最小值是一个升序序列(如果我们忽略总和必须是sum),所以最小值是:[1, 2, 3, ...]。 例如,如果 m 小于 k,则可能的分区数将始终为零。同样,如果 [1, 2, 3, ... k] 的总和大于 sum,那么结果也为零。我们在一开始就排除了这些边缘情况,以确保转换是合法的。

让我们看一下 'legal partition' 的几何表示以及我们要如何对其进行变换。我们有 k 列,m 行和 sum 方块填充蓝色(浅蓝色或深蓝色)。

红色和深蓝色方块是无关紧要的,正如我们已经知道的那样,深蓝色方块必须始终被填充,而红色方块必须始终为空。因此,我们可以将它们从我们的计算中排除,并在我们进行时假设它们各自的状态。结果框显示在右侧。每列的位置都是 'shifted down',红色和深蓝色区域被切断。 我们现在有一个更小的整体框,一列现在可以是空的(我们可能在相邻列中有相同数量的蓝色框)。

从算法上讲,转换现在是这样进行的: 对于合法分区中的每个元素,我们减去它的位置(从 1 开始)。因此对于 [1, 2, 4, 8] 我们得到 [0, 0, 1, 4]。此外,我们必须相应地调整边界(summ):

// from the sum, we subtract the sum of [1, 2, 3, ... k], which is (k * (k + 1) / 2)
sum_2 = sum - (k * (k + 1) / 2)

// from m we subtract the maximum position (which is k)
m_2 = m - k

现在我们已经将分区问题转换为另一个分区问题,没有具有不同元素的限制!此外,此分区可以包含元素 0,而我们的原始分区不能包含该元素。 (我们保持内部升序)。

现在我们需要稍微改进一下递归。如果我们知道元素是递增的,不一定是不同的并且总是 less-equal 到 m_2,那么我们就已经将可能的元素绑定到一个范围内。示例:

[0, 1, 3, n1, n2]
=> 3 <= n1 <= m_2
=> 3 <= n2 <= m_2

因为我们知道示例中的 n1n23 或更大,所以在调用递归时,我们也可以将它们都减少 3并将sum_2减少2 * 3(一个是'open'个元素的个数,一个是最后一个'fixed'个元素的值)。这样,我们在递归中传递的就没有上界和下界了,只有只有一个上界,也就是我们之前的(m).

因此,我们可以抛出您的缓存键的 1 个值:start。相反,在解决这个简化的问题时,我们现在只有 3 个:summk

以下实现可达到此效果:

@Test
public void test() {
    calcNumRIPdistinctElementsSpecificKmaxM(600, (short) 25, (short) 200);
}

public BigInteger calcNumRIPdistinctElementsSpecificKmaxM(int sum, short k, short m) {
    // If the biggest allowed number in a partition is less than the number of parts, then
    // they cannot all be distinct, therefore we have zero results.
    if (m < k) {
        return BigInteger.ZERO;
    }
    
    // If the sum of minimum element-values for k is less than the expected sum, then
    // we also have no results.
    final int v = ((k * ((int) k + 1)) / 2);
    if (sum < v) {
        return BigInteger.ZERO;
    }
    
    // We normalize the problem by lifting the distinction restriction.
    final Cache cache = new Cache();
    final int sumNorm = sum - v;
    final short mNorm = (short) (m - k);
    
    BigInteger result = calcNumRIPspecificKmaxM(sumNorm, k, mNorm, cache);

    System.out.println("Calculation (n=" + sum + ", k=" + k + ", m=" + m + ")");
    System.out.println("p = " + result);
    System.out.println("entries = " + cache.getNumEntries());
    System.out.println("c-rate = " + cache.getCacheRate());
    
    return result;
}

public BigInteger calcNumRIPspecificKmaxM(int sum, short k, short m, Cache cache) {
    
    // We can improve cache use by standing the k*m-rectangle upright (k being the 'bottom').
    if (k > m) {
        final short c = k;
        k = m;
        m = c;
    }
    
    // If the result is trivial, we just calculate it. This is true for k < 3
    if (k < 3) {
        if (k == 0) {
            return sum == 0 ? BigInteger.ONE : BigInteger.ZERO;
            
        } else if (k == 1) {
            return sum <= m ? BigInteger.ONE : BigInteger.ZERO;
            
        } else {
            final int upper = Math.min(sum, m);
            final int lower = sum - upper;
            
            if (upper < lower) {
                return BigInteger.ZERO;
            }
            
            final int difference = upper - lower;
            final int numSubParts = difference / 2 + 1;
            return BigInteger.valueOf(numSubParts);
        }
    }
    
    // If k * m / 2 < sum, we can 'invert' the sub problem to reduce the number of keys further.
    sum = Math.min(sum, k * m - sum);
    
    // If the sum is less than m and maybe even k, we can reduce the box. This improves the cache size even further.
    if (sum < m) {
        m = (short) sum;
        
        if (sum < k) {
            k = (short) sum;

            if (k < 3) {
                return calcNumRIPspecificKmaxM(sum, k, m, cache);
            }
        }
    }
    
    // If the result is non-trivial, we check the cache or delegate.
    final Triple<Short, Short, Integer> key = Triple.of(k, m, sum);
    final BigInteger cachedResult = cache.lookUp(key);
    if (cachedResult != null) {
        return cachedResult;
    }
    
    BigInteger current = BigInteger.ZERO;
    
    // i = m is reached in case the result is an ascending stair e.g. [1, 2, 3, 4]
    for (int i = 0; i <= m; ++i) {
        final int currentSum = sum - (i * k);
        if (currentSum < 0) {
            break;
        }
        
        short currentK = (short) (k - 1);
        short currentM = (short) (m - i);
        
        current = current.add(calcNumRIPspecificKmaxM(currentSum, currentK, currentM, cache));
    }
    
    // We cache this new result and return it.
    cache.enter(key, current);
    return current;
}

public static class Cache {
    private final HashMap<Triple<Short, Short, Integer>, BigInteger> map = new HashMap<>(1024);
    private long numLookUps = 0;
    private long numReuse = 0;
    
    public BigInteger lookUp(Triple<Short, Short, Integer> key) {
        ++numLookUps;
        
        BigInteger value = map.get(key);
        if (value != null) {
            ++numReuse;
        }
        
        return value;
    }
    
    public void enter(Triple<Short, Short, Integer> key, BigInteger value) {
        map.put(key, value);
    }
    
    public double getCacheRate() {
        return (double) numReuse / map.size();
    }
    
    public int getNumEntries() {
        return map.size();
    }
    
    public long numLookUps() {
        return numLookUps;
    }
    
    public long getNumReuse() {
        return numReuse;
    }
}

注意:我在这里使用 apache-common 的 Triple-class 作为键,以节省显式 key-class 的实现,但这不是优化在运行时,它只是保存代码。

编辑:除了修复@MBo(谢谢)发现的问题外,我还添加了一些快捷方式来达到相同的结果。该算法现在性能更好,缓存(重用)率也更高。也许这会满足您的要求?

优化说明(仅适用于上述问题的转置):

  • 如果k > m,我们可以'flip'将矩形竖直,对于合法的分区数,仍然得到相同的结果。这会将一些 'lying' 配置映射到 'upright' 配置并减少不同密钥的总量。

  • 如果矩形中的方格数大于'empty spaces',我们可以将'empty spaces'视为正方形,这将映射另一串键。

  • 如果sum < kand/orsum < m,我们可以将kand/orm归约求和,还是得到相同的个数分区。 (这是影响最大的优化,因为它经常跳过多个冗余的中间步骤并经常达到 m = k = sum

另一种方法是使用约束求解器并将其配置为显示所有解决方案。这是一个使用 MiniZinc 的解决方案:

include "globals.mzn";

int: sum = 15;
int: k = 4;
int: m = 10;

array[1..k] of var 1..m: numbers;

constraint sum(numbers) = sum;

constraint alldifferent(numbers);

constraint increasing(numbers);

solve satisfy;