如何有效地更新 EnumeratedDistribution 实例中的概率?
How to efficiently update probabilities within an EnumeratedDistribution instance?
问题总结
有没有什么方法可以在不创建全新实例的情况下更新 class EnumeratedIntegerDistribution 的现有实例中的概率?
背景
我正在尝试使用 android phone 实现简化的 Q 学习风格演示。我需要在学习过程中的每个循环中更新每个项目的概率。目前我无法从我的 enumeratedIntegerDistribution
实例中找到任何可以让我重置|更新|修改这些概率的方法。因此,我能看到的唯一方法是在每个循环中创建一个新的 EnumeratedIntegerDistribution 实例。请记住,每个循环只有 20 毫秒长,据我了解,与创建一个实例并更新现有实例中的值相比,这将是非常低效的内存。是否没有标准的集合式方法来更新这些概率?如果没有,是否有推荐的解决方法(即使用不同的 class、制作我自己的 class、覆盖某些内容以使其可访问等?)
跟进将是这个问题是否是一个没有实际意义的努力。通过尝试在每个循环中避免这个新实例,编译后的代码实际上是否 more/less 有效? (我的知识不够了解编译器将如何处理此类事情)。
代码
下面是一个最小的例子:
package com.example.mypackage.learning;
import android.app.Activity;
import android.os.Bundle;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
public class Qlearning extends Activity {
private int selectedAction;
private int[] actions = {0, 1, 2};
private double[] weights = {1.0, 1.0, 1.0};
private double[] qValues = {1.0, 1.0, 1.0};
private double qValuesSum;
EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
private final double alpha = 0.001;
int action;
double reward;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
while(true){
action = determineAction();
reward = determineReward();
learn(action, reward);
}
}
public void learn(int action, double reward) {
qValues[selectedAction] = (alpha * reward) + ((1.0 - alpha) * qValues[selectedAction]);
qValuesSum = 0;
for (int i = 0; i < qValues.length; i++){
qValuesSum += Math.exp(qValues[i]);
}
weights[selectedAction] = Math.exp(qValues[selectedAction]) / qValuesSum;
// *** This seems inefficient ***
EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
}
}
请不要关注 determineAction()
或 determineReward()
方法的缺失,因为这只是一个最小的示例。如果你想要一个工作示例,你可以很容易地在那里使用固定值(例如 1 和 1.5)。
此外,我很清楚无限 while 循环对于 GUI 来说会很麻烦,但同样,我只是想减少我必须在此处显示的代码以阐明要点。
编辑:
作为对评论的回应,我在下面发布了我对类似 class 的评论。请注意,我已经一年多没用过它了,可能会坏掉。仅供参考:
public class ActionDistribution{
private double reward = 0;
private double[] weights = {0.34, 0.34, 0.34};
private double[] qValues = {0.1, 0.1, 0.1};
private double learningRate = 0.1;
private double temperature = 1.0;
private int selectedAction;
public ActionDistribution(){}
public ActionDistribution(double[] weights, double[] qValues, double learningRate, double temperature){
this.weights = weights;
this.qValues = qValues;
this.learningRate = learningRate;
this.temperature = temperature;
}
public int actionSelect(){
double sumOfWeights = 0;
for (double weight: weights){
sumOfWeights = sumOfWeights + weight;
}
double randNum = Math.random() * sumOfWeights;
double selector = 0;
int iterator = -1;
while (selector < randNum){
try {
iterator++;
selector = selector + weights[iterator];
}catch (ArrayIndexOutOfBoundsException e){
Log.e("abcvlib", "weight index bound exceeded. randNum was greater than the sum of all weights. This can happen if the sum of all weights is less than 1.");
}
}
// Assigning this as a read-only value to pass between threads.
this.selectedAction = iterator;
// represents the action to be selected
return iterator;
}
public double[] getWeights(){
return weights;
}
public double[] getqValues(){
return qValues;
}
public double getQValue(int action){
return qValues[action];
}
public double getTemperature(){
return temperature;
}
public int getSelectedAction() {
return selectedAction;
}
public void setWeights(double[] weights) {
this.weights = weights;
}
public void setQValue(int action, double qValue) {
this.qValues[action] = qValue;
}
public void updateValues(double reward, int action){
double qValuePrev = getQValue(action);
// update qValues due to current reward
setQValue(action,(learningRate * reward) + ((1.0 - learningRate) * qValuePrev));
// update weights from new qValues
double qValuesSum = 0;
for (double qValue : getqValues()) {
qValuesSum += Math.exp(temperature * qValue);
}
// update weights
for (int i = 0; i < getWeights().length; i++){
getWeights()[i] = Math.exp(temperature * getqValues()[i]) / qValuesSum;
}
}
public double getReward() {
return reward;
}
public void setReward(double reward) {
this.reward = reward;
}
}
很遗憾,无法更新现有的 EnumeratedIntegerDistribution。我过去遇到过类似的问题,每次需要更新机会时我都重新创建了实例。
我不会太担心内存分配,因为它们是短期对象。这些是您不必担心的微优化。
在我的项目中,我确实实现了一种更简洁的方法,使用接口来创建这些 EnumeratedDistribution
class.
的实例
这不是直接的答案,但可能会引导您朝着正确的方向前进。
public class DistributedProbabilityGeneratorBuilder<T extends DistributedProbabilityGeneratorBuilder.ProbableItem> {
private static final DistributedProbabilityGenerator EMPTY = () -> {
throw new UnsupportedOperationException("Not supported");
};
private final Map<Integer, T> distribution = new HashMap<>();
private DistributedProbabilityGeneratorBuilder() {
}
public static <T extends ProbableItem> DistributedProbabilityGeneratorBuilder<T> newBuilder() {
return new DistributedProbabilityGeneratorBuilder<>();
}
public DistributedProbabilityGenerator build() {
return build(ProbableItem::getChances);
}
/**
* Returns a new instance of probability generator at every call.
* @param chanceChangeFunction - Function to modify existing chances
*/
public DistributedProbabilityGenerator build(Function<T, Double> chanceChangeFunction) {
if (distribution.isEmpty()) {
return EMPTY;
} else {
return new NonEmptyProbabilityGenerator(createPairList(chanceChangeFunction));
}
}
private List<Pair<Integer, Double>> createPairList(Function<T, Double> chanceChangeFunction) {
return distribution.entrySet().stream()
.map(entry -> Pair.create(entry.getKey(), chanceChangeFunction.apply(entry.getValue())))
.collect(Collectors.toList());
}
public DistributedProbabilityGeneratorBuilder<T> add(int id, T item) {
if (distribution.containsKey(id)) {
throw new IllegalArgumentException("Id " + id + " already present.");
}
this.distribution.put(id, item);
return this;
}
public interface ProbableItem {
double getChances();
}
public interface DistributedProbabilityGenerator {
int generateId();
}
public static class NonEmptyProbabilityGenerator implements DistributedProbabilityGenerator {
private final EnumeratedDistribution<Integer> enumeratedDistribution;
NonEmptyProbabilityGenerator(List<Pair<Integer, Double>> pairs) {
this.enumeratedDistribution = new EnumeratedDistribution<>(pairs);
}
@Override
public int generateId() {
return enumeratedDistribution.sample();
}
}
public static ProbableItem ofDouble(double chances) {
return () -> chances;
}
}
注意 - 我正在使用 EnumeratedDistribution<Integer>
。您可以轻松地将其更改为 EnumuratedIntegerDistribution
.
我上面class的使用方法如下
DistributedProbabilityGenerator distributedProbabilityGenerator = DistributedProbabilityGeneratorBuilder.newBuilder()
.add(0, ofDouble(10))
.add(1, ofDouble(45))
.add(2, ofDouble(45))
.build();
int generatedObjectId = distributedProbabilityGenerator.generateId();
同样,这不是对您问题的直接回答,而是更多地指导您如何更好地使用这些 classes。
问题总结
有没有什么方法可以在不创建全新实例的情况下更新 class EnumeratedIntegerDistribution 的现有实例中的概率?
背景
我正在尝试使用 android phone 实现简化的 Q 学习风格演示。我需要在学习过程中的每个循环中更新每个项目的概率。目前我无法从我的 enumeratedIntegerDistribution
实例中找到任何可以让我重置|更新|修改这些概率的方法。因此,我能看到的唯一方法是在每个循环中创建一个新的 EnumeratedIntegerDistribution 实例。请记住,每个循环只有 20 毫秒长,据我了解,与创建一个实例并更新现有实例中的值相比,这将是非常低效的内存。是否没有标准的集合式方法来更新这些概率?如果没有,是否有推荐的解决方法(即使用不同的 class、制作我自己的 class、覆盖某些内容以使其可访问等?)
跟进将是这个问题是否是一个没有实际意义的努力。通过尝试在每个循环中避免这个新实例,编译后的代码实际上是否 more/less 有效? (我的知识不够了解编译器将如何处理此类事情)。
代码
下面是一个最小的例子:
package com.example.mypackage.learning;
import android.app.Activity;
import android.os.Bundle;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
public class Qlearning extends Activity {
private int selectedAction;
private int[] actions = {0, 1, 2};
private double[] weights = {1.0, 1.0, 1.0};
private double[] qValues = {1.0, 1.0, 1.0};
private double qValuesSum;
EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
private final double alpha = 0.001;
int action;
double reward;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
while(true){
action = determineAction();
reward = determineReward();
learn(action, reward);
}
}
public void learn(int action, double reward) {
qValues[selectedAction] = (alpha * reward) + ((1.0 - alpha) * qValues[selectedAction]);
qValuesSum = 0;
for (int i = 0; i < qValues.length; i++){
qValuesSum += Math.exp(qValues[i]);
}
weights[selectedAction] = Math.exp(qValues[selectedAction]) / qValuesSum;
// *** This seems inefficient ***
EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
}
}
请不要关注 determineAction()
或 determineReward()
方法的缺失,因为这只是一个最小的示例。如果你想要一个工作示例,你可以很容易地在那里使用固定值(例如 1 和 1.5)。
此外,我很清楚无限 while 循环对于 GUI 来说会很麻烦,但同样,我只是想减少我必须在此处显示的代码以阐明要点。
编辑:
作为对评论的回应,我在下面发布了我对类似 class 的评论。请注意,我已经一年多没用过它了,可能会坏掉。仅供参考:
public class ActionDistribution{
private double reward = 0;
private double[] weights = {0.34, 0.34, 0.34};
private double[] qValues = {0.1, 0.1, 0.1};
private double learningRate = 0.1;
private double temperature = 1.0;
private int selectedAction;
public ActionDistribution(){}
public ActionDistribution(double[] weights, double[] qValues, double learningRate, double temperature){
this.weights = weights;
this.qValues = qValues;
this.learningRate = learningRate;
this.temperature = temperature;
}
public int actionSelect(){
double sumOfWeights = 0;
for (double weight: weights){
sumOfWeights = sumOfWeights + weight;
}
double randNum = Math.random() * sumOfWeights;
double selector = 0;
int iterator = -1;
while (selector < randNum){
try {
iterator++;
selector = selector + weights[iterator];
}catch (ArrayIndexOutOfBoundsException e){
Log.e("abcvlib", "weight index bound exceeded. randNum was greater than the sum of all weights. This can happen if the sum of all weights is less than 1.");
}
}
// Assigning this as a read-only value to pass between threads.
this.selectedAction = iterator;
// represents the action to be selected
return iterator;
}
public double[] getWeights(){
return weights;
}
public double[] getqValues(){
return qValues;
}
public double getQValue(int action){
return qValues[action];
}
public double getTemperature(){
return temperature;
}
public int getSelectedAction() {
return selectedAction;
}
public void setWeights(double[] weights) {
this.weights = weights;
}
public void setQValue(int action, double qValue) {
this.qValues[action] = qValue;
}
public void updateValues(double reward, int action){
double qValuePrev = getQValue(action);
// update qValues due to current reward
setQValue(action,(learningRate * reward) + ((1.0 - learningRate) * qValuePrev));
// update weights from new qValues
double qValuesSum = 0;
for (double qValue : getqValues()) {
qValuesSum += Math.exp(temperature * qValue);
}
// update weights
for (int i = 0; i < getWeights().length; i++){
getWeights()[i] = Math.exp(temperature * getqValues()[i]) / qValuesSum;
}
}
public double getReward() {
return reward;
}
public void setReward(double reward) {
this.reward = reward;
}
}
很遗憾,无法更新现有的 EnumeratedIntegerDistribution。我过去遇到过类似的问题,每次需要更新机会时我都重新创建了实例。
我不会太担心内存分配,因为它们是短期对象。这些是您不必担心的微优化。
在我的项目中,我确实实现了一种更简洁的方法,使用接口来创建这些 EnumeratedDistribution
class.
这不是直接的答案,但可能会引导您朝着正确的方向前进。
public class DistributedProbabilityGeneratorBuilder<T extends DistributedProbabilityGeneratorBuilder.ProbableItem> {
private static final DistributedProbabilityGenerator EMPTY = () -> {
throw new UnsupportedOperationException("Not supported");
};
private final Map<Integer, T> distribution = new HashMap<>();
private DistributedProbabilityGeneratorBuilder() {
}
public static <T extends ProbableItem> DistributedProbabilityGeneratorBuilder<T> newBuilder() {
return new DistributedProbabilityGeneratorBuilder<>();
}
public DistributedProbabilityGenerator build() {
return build(ProbableItem::getChances);
}
/**
* Returns a new instance of probability generator at every call.
* @param chanceChangeFunction - Function to modify existing chances
*/
public DistributedProbabilityGenerator build(Function<T, Double> chanceChangeFunction) {
if (distribution.isEmpty()) {
return EMPTY;
} else {
return new NonEmptyProbabilityGenerator(createPairList(chanceChangeFunction));
}
}
private List<Pair<Integer, Double>> createPairList(Function<T, Double> chanceChangeFunction) {
return distribution.entrySet().stream()
.map(entry -> Pair.create(entry.getKey(), chanceChangeFunction.apply(entry.getValue())))
.collect(Collectors.toList());
}
public DistributedProbabilityGeneratorBuilder<T> add(int id, T item) {
if (distribution.containsKey(id)) {
throw new IllegalArgumentException("Id " + id + " already present.");
}
this.distribution.put(id, item);
return this;
}
public interface ProbableItem {
double getChances();
}
public interface DistributedProbabilityGenerator {
int generateId();
}
public static class NonEmptyProbabilityGenerator implements DistributedProbabilityGenerator {
private final EnumeratedDistribution<Integer> enumeratedDistribution;
NonEmptyProbabilityGenerator(List<Pair<Integer, Double>> pairs) {
this.enumeratedDistribution = new EnumeratedDistribution<>(pairs);
}
@Override
public int generateId() {
return enumeratedDistribution.sample();
}
}
public static ProbableItem ofDouble(double chances) {
return () -> chances;
}
}
注意 - 我正在使用 EnumeratedDistribution<Integer>
。您可以轻松地将其更改为 EnumuratedIntegerDistribution
.
我上面class的使用方法如下
DistributedProbabilityGenerator distributedProbabilityGenerator = DistributedProbabilityGeneratorBuilder.newBuilder()
.add(0, ofDouble(10))
.add(1, ofDouble(45))
.add(2, ofDouble(45))
.build();
int generatedObjectId = distributedProbabilityGenerator.generateId();
同样,这不是对您问题的直接回答,而是更多地指导您如何更好地使用这些 classes。