S Items, M Buckets Weighted Selection Algorithm

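I want to draw a total of S samples from M buckets. Each bucket has a weight W that describes how strongly items from that bucket should be represented in the final sample. For example, if I have buckets A, B, and C with weights 0.5, 0.2, and 0.3, and each bucket holds enough samples, then for a final sample size of S = 10 I want 5 samples from A, 2 from B, and 3 from C. The problem gets more complicated when a bucket may not contain the number of samples that its weight and the total sample size call for. In that case the other weights need to be adjusted so that the sample comes as close as possible to the desired weighted representation. Does anyone know an algorithm for doing this?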

If I understand correctly, I would take the floor of each bucket's exact share and then distribute the remainder.

Let's use your example of three buckets A, B, and C with weights 0.5, 0.2, and 0.3, but this time with S = 13.

A: floor(13 * 0.5) = floor(6.5) = 6 
B: floor(13 * 0.2) = floor(2.6) = 2
C: floor(13 * 0.3) = floor(3.9) = 3

So we start with 6 samples from A, 2 from B, and 3 from C, which leaves 2 samples still to place.

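To choose which buckets receive the remaining samples, we sort the buckets by the fractional part of their exact share, in descending order. A has a fractional part of 0.5, B has 0.6, and C has 0.9, so the two remaining samples go to C and B, and the final result is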

A: 6
B: 3
C: 4
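
The floor-then-distribute steps above can be sketched in Java roughly as follows. This is only an illustration under some assumptions: the class and method names (`LargestRemainder`, `allocate`) are made up for the example, the weights are assumed to be non-negative with a positive sum, and bucket capacity is ignored entirely.

```java
import java.util.Arrays;

class LargestRemainder {
    /**
     * Splits `total` samples across buckets in proportion to `weights`.
     * Hypothetical helper: ignores how many items each bucket really holds.
     */
    static int[] allocate(double[] weights, int total) {
        int n = weights.length;
        double weightSum = 0;
        for (double w : weights) weightSum += w;

        int[] counts = new int[n];
        double[] fractions = new double[n];
        int assigned = 0;
        for (int i = 0; i < n; i++) {
            double exact = total * weights[i] / weightSum;
            counts[i] = (int) Math.floor(exact);   // whole part goes in first
            fractions[i] = exact - counts[i];      // fractional part decides the leftovers
            assigned += counts[i];
        }

        // Order bucket indices by fractional part, largest first.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Double.compare(fractions[b], fractions[a]));

        // Hand out the leftover samples one at a time in that order.
        int leftover = total - assigned;
        for (int k = 0; k < leftover; k++) {
            counts[order[k % n]]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        // The S = 13 example from this answer; prints [6, 3, 4]
        System.out.println(Arrays.toString(allocate(new double[] {0.5, 0.2, 0.3}, 13)));
    }
}
```

Dividing by the sum of the weights means the weights do not have to add up to 1.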

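I wrote a solution in Java. Because each child's share is rounded independently, it may return one or two samples more than requested, but that is fine for my application. If you see any way to improve the algorithm, feel free to post a solution.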

SampleNode.java

public abstract class SampleNode {
    protected double weight;

    protected abstract int getNumSamplesAvailable();
    protected abstract boolean hasSamples();
    protected abstract int takeAllSamples();
    protected abstract void sample(int target);
    public abstract boolean takeOneSample();
}

LeafSampleNode.java

public class LeafSampleNode extends SampleNode {
    private int numselected;
    private int numsamplesavailable;

    public LeafSampleNode(double weight, int numsamplesavailable) {
        this.weight = weight;
        this.numsamplesavailable = numsamplesavailable;
        this.numselected = 0;
    }

    protected void sample(int target) {
        if(target >= numsamplesavailable) {
            takeAllSamples();
        }
        else {
            numselected += target;
            numsamplesavailable -= target;
        }
    }

    @Override
    protected int getNumSamplesAvailable() {
        return numsamplesavailable;     
    }

    protected boolean hasSamples() {
        return numsamplesavailable > 0;
    }

    protected int getNumselected() {
        return numselected;
    }

    protected int takeAllSamples() {
        int samplestaken = numsamplesavailable;
        numselected += numsamplesavailable;
        numsamplesavailable = 0;
        return samplestaken;
    }
    @Override
    public boolean takeOneSample() {
        if(hasSamples()) {
            numsamplesavailable--;
            numselected++;
            return true;
        }
        return false;
    }
}

RootSampleNode.java:

import java.util.ArrayList;
import java.util.List;

public class RootSampleNode extends SampleNode {    
    private List<SampleNode> children;

    public RootSampleNode(double weight) {
        this.children = new ArrayList<SampleNode>();
        this.weight = weight;
    }

    public void selectSample(int target) {
        int totalsamples = getNumSamplesAvailable();
        if(totalsamples < target) { 
            //not enough samples to meet target, simply take everything
            for(int i = 0; i < children.size(); i++) {
                children.get(i).takeAllSamples();
            }
        }
        else {
            //there are enough samples to meet target, distribute to meet quotas as closely as possible
            sample(target);
        }
    }

    protected void sample(int target) {
        int samplestaken = 0;
        double totalweight = getTotalWeight(children);
        samplestaken +=  sample(totalweight, target, children);
        if(samplestaken < target) {
            sample(target - samplestaken);
        }
    }

    private int sample(double totalweight, int target, List<SampleNode> children) {
        int samplestaken = 0;
        for(int i = 0; i < children.size(); i++) {
            SampleNode child = children.get(i);
            if(child.hasSamples()) {
                int desired = (int) (target * (child.weight / totalweight) + 0.5);
                if(desired >= child.getNumSamplesAvailable()) {
                    samplestaken += child.takeAllSamples();
                }
                else {
                    child.sample(desired);
                    samplestaken += desired;
                }
            }           
        }
        if(samplestaken == 0) { //avoid deadlock / stack overflow...someone just take a sample
            for(int i = 0; i < children.size(); i++) {
                if(children.get(i).takeOneSample()) {
                    samplestaken++;
                    break;
                }
            }
        }
        return samplestaken;
    }

    @Override
    public boolean takeOneSample() {
        if(hasSamples()) {
            for(int i = 0; i < children.size(); i++) {
                if(children.get(i).takeOneSample()) {
                    return true;
                }
            }
        }
        return false;
    }

    protected double getTotalWeight(List<SampleNode> children) {
        double totalweight = 0;
        for(int i = 0; i < children.size(); i++) {
            SampleNode child = children.get(i);
            if(child.hasSamples()) {
                totalweight += child.weight;
            }
        }
        return totalweight;
    }

    protected boolean hasSamples() {
        for(int i = 0; i < children.size(); i++) {
            if(children.get(i).hasSamples()) {
                return true;
            }
        }
        return false;
    }

    protected int takeAllSamples() {
        int samplestaken = 0;
        for(int i = 0; i < children.size(); i++) {
            samplestaken += children.get(i).takeAllSamples();
        }
        return samplestaken;
    }

    protected int getNumSamplesAvailable() {
        int numsamplesavailable = 0;
        for(int i = 0; i < children.size(); i++) {
            numsamplesavailable += children.get(i).getNumSamplesAvailable();
        }
        return numsamplesavailable;
    }

    public void addChild(SampleNode sn) {
        this.children.add(sn);
    }
}

Some unit tests:

import static org.junit.Assert.assertTrue;

import org.junit.Test;

public class SampleNodeTest {

    @Test
    public void test1() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        root.selectSample(9);
        assertTrue(bucketA.getNumselected() == 5);
        assertTrue(bucketB.getNumselected() == 2);
        assertTrue(bucketC.getNumselected() == 3);
    }

    @Test
    public void test2() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        root.selectSample(13);
        assertTrue(bucketA.getNumselected() == 5);
        assertTrue(bucketB.getNumselected() == 2);
        assertTrue(bucketC.getNumselected() == 3);
    }

    @Test
    public void test3() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        RootSampleNode branch1 = new RootSampleNode(0.5);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 10);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 12);
        branch1.addChild(bucketD);
        branch1.addChild(bucketE);
        root.addChild(branch1);
        root.selectSample(13);
        assertTrue(bucketA.getNumselected() == 4);
        assertTrue(bucketB.getNumselected() == 2);
        assertTrue(bucketC.getNumselected() == 3);
        assertTrue(bucketD.getNumselected() == 3);
        assertTrue(bucketE.getNumselected() == 1);
    }

    @Test
    public void test4() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        RootSampleNode branch1 = new RootSampleNode(1);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 10);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 12);
        branch1.addChild(bucketD);
        branch1.addChild(bucketE);
        root.addChild(branch1);
        root.selectSample(13);
        assertTrue(bucketA.getNumselected() == 3);
        assertTrue(bucketB.getNumselected() == 1);
        assertTrue(bucketC.getNumselected() == 2);
        assertTrue(bucketD.getNumselected() == 5);
        assertTrue(bucketE.getNumselected() == 2);
    }

    @Test
    public void test5() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        RootSampleNode branch1 = new RootSampleNode(1);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 10);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 12);
        branch1.addChild(bucketD);
        branch1.addChild(bucketE);
        root.addChild(branch1);
        root.selectSample(4);
        assertTrue(bucketA.getNumselected() == 1);
        assertTrue(bucketB.getNumselected() == 0);
        assertTrue(bucketC.getNumselected() == 1);
        assertTrue(bucketD.getNumselected() == 1);
        assertTrue(bucketE.getNumselected() == 1);
    }

    @Test
    public void test6() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 5);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 2);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 3);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        RootSampleNode branch1 = new RootSampleNode(1);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 10);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 12);
        branch1.addChild(bucketD);
        branch1.addChild(bucketE);
        root.addChild(branch1);
        root.selectSample(2);
        assertTrue(bucketA.getNumselected() == 1);
        assertTrue(bucketB.getNumselected() == 0);
        assertTrue(bucketC.getNumselected() == 0);
        assertTrue(bucketD.getNumselected() == 1);
        assertTrue(bucketE.getNumselected() == 0);
    }

    @Test
    public void test7() {
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 50);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 20);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 33);
        root.addChild(bucketA);
        root.addChild(bucketB);
        root.addChild(bucketC);
        RootSampleNode branch1 = new RootSampleNode(1);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 100);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 120);
        branch1.addChild(bucketD);
        branch1.addChild(bucketE);
        root.addChild(branch1);
        root.selectSample(200);
        assertTrue(bucketA.getNumselected() == 50);
        assertTrue(bucketB.getNumselected() == 20);
        assertTrue(bucketC.getNumselected() == 30);
        assertTrue(bucketD.getNumselected() == 71);
        assertTrue(bucketE.getNumselected() == 29);
    }

    @Test
    public void test8() {
        RootSampleNode root = new RootSampleNode(1);
        RootSampleNode branch1 = new RootSampleNode(5);
        LeafSampleNode bucketA = new LeafSampleNode(0.5, 50);
        LeafSampleNode bucketB = new LeafSampleNode(0.2, 20);
        LeafSampleNode bucketC = new LeafSampleNode(0.3, 33);
        branch1.addChild(bucketA);
        branch1.addChild(bucketB);
        branch1.addChild(bucketC);
        RootSampleNode branch2 = new RootSampleNode(1);
        LeafSampleNode bucketD = new LeafSampleNode(0.5, 100);
        LeafSampleNode bucketE = new LeafSampleNode(0.2, 120);
        branch2.addChild(bucketD);
        branch2.addChild(bucketE);
        root.addChild(branch1);
        root.addChild(branch2);
        root.selectSample(200);
        assertTrue(bucketA.getNumselected() == 50);
        assertTrue(bucketB.getNumselected() == 20);
        assertTrue(bucketC.getNumselected() == 33);
        assertTrue(bucketD.getNumselected() == 70);
        assertTrue(bucketE.getNumselected() == 27);
    }
}

Hope someone finds this useful.
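
For a flat list of buckets, one possible way to land exactly on the target instead of overshooting is to empty every bucket whose proportional share it cannot cover, renormalize the weights of the rest, and finish with floor plus largest fractional part. The sketch below is not the tree-based code above. `CappedAllocation` and `allocate` are hypothetical names, the weights are assumed positive, and the target non-negative.

```java
import java.util.ArrayList;
import java.util.List;

class CappedAllocation {
    /** Returns per-bucket counts that sum to min(target, total capacity). */
    static int[] allocate(double[] weights, int[] capacity, int target) {
        int n = weights.length;
        int[] counts = new int[n];
        List<Integer> active = new ArrayList<>();
        int totalCapacity = 0;
        for (int i = 0; i < n; i++) {
            active.add(i);
            totalCapacity += capacity[i];
        }
        int remaining = Math.min(target, totalCapacity);

        // Phase 1: a bucket whose share meets or exceeds its capacity is
        // emptied, and its weight drops out of the next round's total.
        boolean changed = true;
        while (changed && !active.isEmpty()) {
            changed = false;
            double weightSum = 0;
            for (int i : active) weightSum += weights[i];
            List<Integer> stillActive = new ArrayList<>();
            int takenThisRound = 0;
            for (int i : active) {
                double share = remaining * weights[i] / weightSum;
                if (share >= capacity[i]) {
                    counts[i] = capacity[i];
                    takenThisRound += capacity[i];
                    changed = true;
                } else {
                    stillActive.add(i);
                }
            }
            remaining -= takenThisRound;
            active = stillActive;
        }

        // Phase 2: floor each remaining share, then give the leftovers to
        // the buckets with the largest fractional parts.
        double activeWeight = 0;
        for (int i : active) activeWeight += weights[i];
        double[] fraction = new double[n];
        int leftover = remaining;
        for (int i : active) {
            double share = remaining * weights[i] / activeWeight;
            counts[i] = (int) Math.floor(share);
            fraction[i] = share - counts[i];
            leftover -= counts[i];
        }
        active.sort((a, b) -> Double.compare(fraction[b], fraction[a]));
        while (leftover > 0) {
            for (int i : active) {
                if (leftover > 0 && counts[i] < capacity[i]) {
                    counts[i]++;
                    leftover--;
                }
            }
        }
        return counts;
    }
}
```

With weights 0.5, 0.2, 0.3, capacities 5, 2, 3 and a target of 9, this yields 4, 2, 3, which sums to exactly 9.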

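My suggestion is to write a loop that, on each iteration, samples from the non-empty bucket whose current weight is furthest from its desired weight. Here is some pseudocode. Obviously you would want to generalize it to more buckets, but it should give you the idea.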

set buckets[] = { /* original items */ };
double weights[] = { 0.5, 0.2, 0.3 };  // the desired weights
int counts[] = { 0, 0, 0 };            // number of items sampled so far

for (i = 0; i < n; i++) {
  double errors[] = { 0.0, 0.0, 0.0 };
  for (j = 0; j < 3; j++) {
    if (!empty(buckets[j]))
      // signed shortfall: positive while the bucket is under-represented
      // (with abs(), an over-represented bucket would keep being chosen)
      errors[j] = weights[j] - ((double) counts[j] / n);
    else
      errors[j] = -infinity;  // never choose an empty bucket
  }
  if (all buckets are empty) break;
  // choose the non-empty bucket whose current weight is
  // furthest below the desired weight
  k = argmax(errors);
  sample(buckets[k]);  // take an item out of that bucket
  counts[k]++;         // increment count
}

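If you need it translated into valid Java, I could be persuaded :). This will always produce n samples (assuming at least n items are available; otherwise it samples every item), with a distribution that stays as close to the desired weights as the bucket contents allow.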
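
For what it's worth, a rough Java rendering of that loop might look like the sketch below. The names `GreedyDeficitSampler` and `sample` are invented for illustration, and buckets are represented only by how many items they hold. It scores each bucket by the signed shortfall `weights[j] - counts[j] / n`, so a bucket that is already over its target is never preferred to one that is still under it.

```java
class GreedyDeficitSampler {
    /**
     * Picks up to n samples one at a time. `available[i]` is the number of
     * items in bucket i; the result says how many were taken from each bucket.
     */
    static int[] sample(double[] weights, int[] available, int n) {
        int m = weights.length;
        int[] remaining = available.clone();   // leave the caller's array untouched
        int[] counts = new int[m];
        for (int taken = 0; taken < n; taken++) {
            int best = -1;
            double bestDeficit = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < m; j++) {
                if (remaining[j] == 0) continue;   // skip empty buckets
                // How far bucket j's current share is below its desired weight.
                double deficit = weights[j] - (double) counts[j] / n;
                if (deficit > bestDeficit) {
                    bestDeficit = deficit;
                    best = j;
                }
            }
            if (best < 0) break;   // every bucket is empty: fewer than n items exist
            remaining[best]--;
            counts[best]++;
        }
        return counts;
    }
}
```

Calling `GreedyDeficitSampler.sample(new double[] {0.5, 0.2, 0.3}, new int[] {100, 100, 100}, 10)` gives 5, 2, and 3. Each pick scans all M buckets, so the cost is O(n * M), which should be fine for small bucket counts.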