通过 java 运行时使用 pickled python class

Using pickled python class through java runtime

我使用 scikit-learn 从一些字符串构建了一个二进制 SVM classifier:

count_vect = TfidfVectorizer(sublinear_tf=True, min_df=10, norm='l2', encoding='latin-1', ngram_range=(1, 4))
X_counts = count_vect.fit_transform(X)
tfidf_transformer = TfidfTransformer()
X_tfidf = tfidf_transformer.fit_transform(X_counts)
clf = SGDClassifier(loss='hinge', penalty='l2', tol=1e-3)
clf.fit(X_tfidf, Y)

然后我在经过训练的 classifier 周围创建了一个包装器 class 并将其腌制:

class Classifier:
    def __init__(self, clf, vect):
        self.classifier = clf
        self.vectorizer = vect

    def classify(self, s):
        return self.classifier.predict(self.vectorizer.transform([s]))[0]

with open('my_classifier.pkl', 'wb') as fout:
    pickle.dump(Classifier(clf, count_vect), fout)

在单独的脚本中,我可以加载 pickled classifier 并正确使用它:

with open('my_classifier.pkl', 'rb') as fin:
    clf = pickle.load(fin)

result = clf.classify(sys.argv[1])
print(result)

但是,当我尝试通过 java 运行时执行脚本时,它显示不正确的输出。

public boolean classify(String s) throws IOException {
        String cmd = "python3 pkl_classifier.py \"" + s + "\"";
        Process p = Runtime.getRuntime().exec(cmd);
        BufferedReader stdIn = new BufferedReader(new InputStreamReader(p.getInputStream()));

        String out = stdIn.readLine();
        if (out != null) {
            switch (Integer.parseInt(out)) {
                case 0: return false;
                case 1: return true;
                default: throw new RuntimeException("Error with classifier script:" + out);
            }
        }

        return false;
    }

classifier 的输出可以是 0 或 1。但是这个 java 代码总是产生 0。我已经从 java 并直接在终端中执行它并产生了正确的输出。但是 java 运行时产生的输出始终为 0。

有什么我遗漏的吗?

刚发现问题。

它在我通过 java 执行的脚本中。

假设我要 class 化的字符串是 "lorem ipsum dolor sit amet"。当我形成命令 class 在我的 java class 中将其定义为

String cmd = "python3 pkl_classifier.py \"" + s + "\"";

显示为

python3 pkl_classifier.py "lorem ipsum dolor sit amet"

这正是我要执行的命令。当我打印出变量 cmd 时,它看起来没问题。但是当我使用

在 python 脚本中打印命令时

print(sys.argv)

显示为

['pkl_classifier.py', '"lorem', 'ipsum', 'dolor', 'sit', 'amet"']

很明显它是通过忽略引号的空格分割命令。

然后我将脚本从

result = clf.classify(sys.argv[1])

result = clf.classify(' '.join(sys.argv[1:]))

效果很好!

更好的解决方案:

与其调整 python 脚本,不如使用 Process class 的另一个 exec() 方法,该方法将参数作为字符串数组。它将在内部处理参数中的空格。

String[] cmd = { "python3", "pkl_classifier.py", s };
p = Runtime.getRuntime().exec(cmd);