vowpal wabbit java:获取原始预测

vowpal wabbit java: get raw predictions

我正在使用 Vowpal Wabbit 的 Java API 进行预测。我需要原始预测(与命令行选项 -r output.txt 的输出相同),但在 VWMulticlassLearner 类中找不到任何这样的方法。我通过命令行使用了下面的参数:

我在 Python 中这样训练模型:
vw -f model_filepath -c --cache_file cache_filepath -k --csoaa 40 -b 24 -q cd -q .... -q n: --ignore a --ignore x

我们在 Java 中使用以下代码来获得预测 -

VWLearners.create("-i ./data/train.model  -t --quiet"); // VWMulticlassLearner
VWLearners.create("-i ./data/train.model  -t --quiet --csoaa_ldf=mc --loss_function=logistic --probabilities"); //VWProbLearner
这两个类(VWMulticlassLearner 和 VWProbLearner)都没有任何返回原始预测的方法。

我想要与下面相同的预测 -

$ echo ' .. sample string .. ' | vw -i data/train.model -t -r test -p /dev/stdout
creating quadratic features for pairs: cd ce cu cw de du dw eu ew uw n:
ignoring namespaces beginning with: a x
only testing
predictions = /dev/stdout
raw predictions = test
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
39
0.000000 0.000000            1            1.0    known       39      171

finished run
number of examples per pass = 1
passes used = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 171

$ cat test
0:1.05645 1:0.83437 2:-0.210798 3:-2.81048 4:-4.47558 5:-4.45883 6:-3.65177 7:-3.71191 8:-2.96008 9:-2.82846 10:-2.31816 11:0.925984 12:3.28547 13:5.20375 14:6.34244 15:6.13525 16:1.65726 17:1.22801 18:1.35034 19:3.27091 20:2.94066 21:-0.0276409 22:0.391437 23:1.267 24:-0.689573 25:0.0171876 26:3.12935 27:3.95045 28:3.86978 29:1.18468 30:0.0921049 31:0.436564 32:0.98946 33:1.00963 34:-0.265355 35:-3.02128 36:-2.52846 37:-2.8066 38:-3.50639 39:-4.6184

如何在 Java 中把文件 test 里的这些值直接作为方法的返回值获取?我不想在 Java 中通过读取文件来得到结果,那样会很慢。

我最终采用了一个已被废弃的 PR 中的改动。下面是我的 git 补丁文件:

diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
index 6b51c4d30..f3ccb6621 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
@@ -11,3 +11,17 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict(JNI
 JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline(JNIEnv *env, jobject obj, jobjectArray example_strings, jboolean learn, jlong vwPtr)
 { return base_predict<jint>(env, example_strings, learn, vwPtr, multiclass_predictor);
 }
+
+jfloatArray multiclass_raw_predictor(example* vec, JNIEnv *env){
+  size_t num_values = vec->l.cs.costs.size();
+  jfloatArray j_labels = env->NewFloatArray(num_values);
+  for (int i=0 ; i<num_values; i++) {
+    jfloat f[] = { vec->l.cs.costs[i].partial_prediction };
+    env->SetFloatArrayRegion(j_labels, i, 1, (float*)f);
+   }
+   return j_labels;
+ }
+
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict(JNIEnv *env, jobject obj, jstring example_string, jboolean learn, jlong vwPtr){
+return base_predict<jfloatArray>(env, example_string, learn, vwPtr, multiclass_raw_predictor);
+}
diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
index 05204d53e..5610704fa 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
@@ -24,6 +24,15 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict
 JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline
 (JNIEnv *, jobject, jobjectArray, jboolean, jlong);
 
+/*
+ * Class:     vowpalWabbit_learner_VWMulticlassLearner
+ * Method:    rawPredict
+ * Signature: (Ljava/lang/String;ZJ)[F
+ */
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict
+  (JNIEnv *, jobject, jstring, jboolean, jlong);
+
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
index b506cfb25..bb3156351 100644
--- a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
+++ b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
@@ -13,4 +13,25 @@ final public class VWMulticlassLearner extends VWIntLearner {
 
     @Override
     protected native int predictMultiline(String[] example, boolean learn, long nativePointer);
+
+    protected native float[] rawPredict(String example, boolean learn, long nativePointer);
+
+    /**
+     * Get raw prediction output.
+     *
+     * @param example a single vw example string
+     * @return Raw prediction
+     */
+
+    public float[] rawPredict(final String example) {
+        lock.lock();
+        try {
+            if (isOpen()) {
+                return rawPredict(example, false, nativePointer);
+            }
+            throw new IllegalStateException("Already closed.");
+        } finally {
+            lock.unlock();
+        }
+    }
 }