将浮点向量舍入到整数向量

Round Float Vector to Integer Vector

作为一个更大的模拟程序的一部分(我很乐意分享它的背景,但与问题无关),我 运行 遇到了以下问题并且正在寻找一个好算法。

问题: 给定一个长度为 n 的浮点数组 f(具有元素 f_1,...,f_n),指定 n 维 space 中的一个点。可以公平地假设 f 的总和为 0.0(取决于浮点精度)。

所寻求的: 求长度为n的整数数组i(元素为i_1, ..., i_n),在n维space中指定一个网格点,使得i的和恰好为0 和 d(f,i),f 和 i 之间距离的合适度量,被最小化。

至于合适的度量,我认为最好的度量是最小化相对误差的度量(即,对 (i_j/f_j-1)^2 的 j 求和),但最小化普通欧氏距离(即,对 (i_j-f_j)^2) 的 j 求和也可能有效。

我想到的最好的算法是在第 i 个网格(和为 0)上猜测一个合适的点,然后重复切换到具有最小距离的相邻网格点(和为 0),直到所有邻居都具有更大的距离。鉴于距离函数的凹性,这应该收敛于解决方案。

但该算法似乎很笨拙。谁能做得更好?

How to round floats to integers while preserving their sum?有相关讨论,但没有达到我要找的最优点。

上下文附录: 如果您对此感兴趣(也因为我认为它很酷),让我具体说明问题出现的背景。

该问题作为交易模拟的一部分出现(这是更大模拟的一部分)。在每个地点,代理商提供交易多种商品的服务。由于每个地点和商品都是分开处理的,所以我们可以集中在一个地点和商品上,依次处理。

每个代理人 j 都有一定数量的货币 c_j 和一定数量的商品 q_j,它们必须保持完整。每个代理人还指定了一个实值的、连续的、非负的、非单调递减的需求函数 d_j(p) ,它本质上代表了代理人在任何给定价格下想要拥有多少单位的商品。

交易按以下步骤执行。

  1. 对于每个代理人 j 计算预算约束需求函数 b_j(p) = min (d_j(p), q_j + c_j/p)
  2. 定义总量 q_tot(q_j 的 j 求和)和总预算约束需求函数 b_tot(p)(b_j 的 j 求和( p).
  3. 使用牛顿法求 b_tot(p_eq) = q_tot 的均衡价格 p_eq。如果没有均衡价格,return.
  4. 定义每个代理人的交易量f_j为b_j(p_eq)-q_j(净购买为正,净销售为负)。
  5. 从计算中删除 f_j 非常小(例如,小于 1/10)的代理人
  6. 每个q_j应调整为q_j+f_j,每个c_j应调整为c_j-p_eq*f_j
  7. 问题就出现在这个点上。 f_j 和 p_eq 是浮点数,但 q_j 和 c_j 需要保持整数。为避免创建或销毁货币或商品,数组 f_j 和 (-p_eq*f_j) 需要一致地舍入为整数

这是一个整数规划问题。分支定界法是一种简单的方法,在具有良好边界条件的情况下在实践中非常有效。

我实现了一个简单的分支限界算法。主要思想是为数组的每个成员尝试下一个更高和更低的整数。在每一步,我们先尝试产生较少损失的那个。一旦我们找到了一个潜在的解决方案,如果该解决方案比我们迄今为止找到的最好的解决方案更好,我们就会保留它,然后我们回溯。如果在任何时候我们发现我们的部分解决方案的损失比我们找到的最佳总损失更糟糕,那么我们可以修剪那个分支。

这是一个 Python 基本品牌绑定解决方案的实现。有很多方法可以进一步优化,但这显示了基本思想:

from sys import stderr, stdout
from math import floor, ceil
import random

def debug(msg):
  #stderr.write(msg)
  pass

def loss(pf,pi):
  return sum((pf[i]-pi[i])**2 for i in range(0,len(pf)))

class Solver:
  def __init__(self,pf):
    n = len(pf)
    self.pi = [0]*n
    self.pf = pf
    self.min_loss = n
    self.min_pi = None
    self.attempt_count = 0

  def test(self):
    """Test a full solution"""
    pi = self.pi
    pf = self.pf
    assert sum(pi)==0
    l = loss(pf,pi)
    debug('%s: %s\n'%(l,pi))
    if l<self.min_loss:
      self.min_loss = l
      self.min_pi = pi[:]

  def attempt(self,i,value):
    """Try adding value to the partial solution"""
    self.pi[i] = int(value)
    self.extend(i+1)
    self.attempt_count += 1

  def extend(self,i):
    """Extend the partial solution"""
    partial = self.pi[:i]
    loss_so_far = loss(self.pf[:i],partial)
    debug('%s: pi=%s\n'%(loss_so_far,partial))
    if loss_so_far>=self.min_loss:
      return
    if i==len(self.pf)-1:
      self.pi[i] = -sum(partial)
      self.test()
      return
    value = self.pf[i]
    d = value-floor(value)
    if d<0.5:
      # The the next lower integer first, since that causes less loss
      self.attempt(i,floor(value))
      self.attempt(i,ceil(value))
    else:
      # Try the next higher integer first
      self.attempt(i,ceil(value))
      self.attempt(i,floor(value))

def exampleInput(seed):
  random.seed(seed)
  n = 10
  p = [random.uniform(-100,100) for i in range(0,n)]
  average = sum(p)/n
  pf = [x-average for x in p]
  return pf

input = exampleInput(42)
stdout.write('input=%s\n'%input)
stdout.write('sum(input)=%s\n'%sum(input))

solver=Solver(input)
solver.extend(0)

stdout.write('best solution: %s\n'%solver.min_pi)
stdout.write('sum(best): %s\n'%sum(solver.min_pi))
stdout.write('attempts: %s\n'%solver.attempt_count)
stdout.write('loss: %s\n'%loss(input,solver.min_pi))
assert sum(solver.min_pi)==0

当 n = 2

这很简单。定义要使用以下算法计算的 ik。 ik 之和永远为 0,欧式距离永远最小。

if (f[k] < 0)
   i[k] = int(f[k]-0.5);
else
   i[k] = int(f[k]+0.5);

int() 是 returns 浮点数的整数部分的函数。它会截断浮点数。

说明

让我们定义 ek 使得 fk = ik + ek.

对于 n = 2,f0 = -f1。两个fk大小相同但符号相反。通过向 0 舍入,这两个误差也具有相同的大小但符号相反。因为 sum(f) = sum(i) + sum(e) 和 sum(e) 和 sum(f) 等于 0,sum(i) = 0.

除了平衡两个误差的大小外,通过四舍五入到最接近的整数,我们将误差最小化。总和(e2) 将是最小的。

当 n = 3

我们计算 ik 如上。然后 sum(i) 可能取值 -1、0 或 1。

当 sum(i) = -1 时,我们必须增加 ik 之一。选择ek最大的ik(ek都是正数)

当sum(i) = 1时,我们要减1k。选择ek最小的ik(ek都是负数)

说明

当 n=3 时,我们有 3 个错误值 |ek| < 0.5。结果 |sum(e)|永远不能是 2 或更多。由于sum(i)只能取整数值,sum(e)只能取值-1、0或1,sum(i)也一样。

|总和(e)| = 1 当所有 ek 具有相同符号时。这是因为 |ek| < 0.5。您总是需要三个相同符号的错误才能达到 1 或 -1。请注意,与它们具有不同符号且 sum(i) = 0.

的情况相比,这在统计上较少出现

我们如何决定选择哪个 ik

当 sum(i) = 1 时,sum(e) = -1 并且所有 ek 都是负数。 我们必须减一 ik。减少一个 ik 将通过增加其 ek 来平衡,因为 sum(i) + sum(e) = 0。因此我们应该选择ek 以便增加它,产生最小幅度的错误。这是最接近-0.5 的 ek,因此也是最小的 ek。这确保 sum(e2) 是最小的。

当 sum(i) = -1 时,同样的逻辑适用。在这种情况下,sum(e) = 1 和所有 ek 都是正数。增加一个 ik 通过减少其 ek 来平衡。通过选择最接近 0.5 的 ek,我们得到最小的总和 (e2).

当 n = 4

在这种情况下,sum(i) 仍然限于值 -1、0 和 1。

当sum(i) = -1时,取最大的ek.

当sum(i) = 1时,取最小的ek.

说明

|总和(e)|无法达到 2。这就是为什么 sum(i) 被限制为值 -1、0 和 1 的原因。

n = 3 的不同之处在于,现在我们可能有 3 或 4 个 ek 具有相同的符号。但是通过应用上述规则,sum(e2) 保持最小值。

泛化

当 n > 4 |sum(e)|可以大于1。在这种情况下,我们必须修改多个 ik.

一般算法如下

sum(i) -> m
when m = 0, 
    we are done.
when m < 0, 
    increment the m i<sub>k</sub> with biggest e<sub>k</sub>. 
when m > 0, 
    decrement the m i<sub>k</sub> with smallest e<sub>k</sub>.

这是 python 2.7 代码

def solve(pf):
    n = len(pf)
    
    # construct array pi from pf
    pi = [round(val) for val in pf]
    print "pi~:", pi
    
    # compute the sum of the integers
    m = sum(val for val in pi)
    print "m :", m
    
    # if the sum is zero, we are done
    if m == 0:
        return pi
        
    # compute the errors
    pe = [pf[k]-pi[k] for k in xrange(n)]
    print "pe :", pe

    # correct pi when m is negative    
    while m < 0:
        # locate the pi with biggest error
        biggest = 0
        for k in xrange(1,n):
            if pe[k] > pe[biggest]:
                biggest = k
                
        # adjust this integer i
        pi[biggest] += 1
        pe[biggest] -= 1
        m += 1
    
    # correct pi when m is positive    
    while m > 0:
        # locate the pi with smallest error
        smallest = 0
        for k in xrange(1,n):
            if pe[k] < pe[smallest]:
                smallest = k
        
        # adjust this integer i
        pi[smallest] -= 1
        pe[smallest] += 1
        m -= 1
   
    return pi
    
    
if __name__ == '__main__': 
    print "Example case when m is 0"    
    pf = [1.1, 2.2, -3.3]
    print "pf :", pf
    pi = solve( pf )
    print "pi :", pi
    
    print "Example case when m is 1"    
    pf = [0.6, 0.7, -1.3]
    print "pf :", pf
    pi = solve( pf )
    print "pi :", pi
    
    print "Example case when m is -1"    
    pf = [0.4, 1.4, -1.8]
    print "pf :", pf
    pi = solve( pf )
    print "pi :", pi