具有 mallet 主题建模的相同数据的不同主题分布
Different topic distributions for the same data with mallet topic modeling
我正在使用 MALLET 进行主题建模(topic modeling),并训练了一个模型。训练结束后,我立即打印了训练集中某一篇文档的主题分布并保存下来。然后,我把同一篇文档当作测试集,让它经过相同的管道处理,再进行推断。但是得到的主题分布完全不同:训练后排名最高的主题概率约为 0.54,而把该文档用作测试集时,这个主题的概率为 0.000。这是我的训练和测试代码:
public static ArrayList<Object> trainModel() throws IOException {
String fileName = "E:\Alltogether.txt";
String stopwords = "E:\stopwords-en.txt";
// Begin by importing documents from text to feature sequences
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
// Pipes: lowercase, tokenize, remove stopwords, map to features
pipeList.add(new CharSequenceLowercase());
pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\p{L}[\p{L}\p{P}]+\p{L}")));
pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
pipeList.add(new TokenSequenceRemoveNonAlpha(true));
pipeList.add(new TokenSequence2FeatureSequence());
InstanceList instances = new InstanceList(new SerialPipes(pipeList));
Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\S*)[\s,]*(\S*)[\s,]*(.*)$"),
3, 2, 1)); // data, label, name fields
int numTopics = 75;
ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);
model.setOptimizeInterval(20);
model.addInstances(instances);
model.setNumThreads(2);
model.setNumIterations(2000);
model.estimate();
ArrayList<Object> results = new ArrayList<>();
results.add(model);
results.add(instances);
Alphabet dataAlphabet = instances.getDataAlphabet();
FeatureSequence tokens = (FeatureSequence) model.getData().get(66).instance.getData();
LabelSequence topics = model.getData().get(66).topicSequence;
Formatter out = new Formatter(new StringBuilder(), Locale.US);
for (int position = 0; position < tokens.getLength(); position++) {
out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
}
System.out.println(out);
// Estimate the topic distribution of the 66th instance,
// given the current Gibbs state.
double[] topicDistribution = model.getTopicProbabilities(66);
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
for (int topic = 0; topic < numTopics; topic++) {
Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
out = new Formatter(new StringBuilder(), Locale.US);
out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
int rank = 0;
while (iterator.hasNext() && rank < 10) {
IDSorter idCountPair = iterator.next();
out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
rank++;
}
System.out.println(out);
}
return results;
}
下面是测试部分:
private static void testModel(ArrayList<Object> results, String testDir) {
ParallelTopicModel model = (ParallelTopicModel) results.get(0);
InstanceList allTrainInstances = (InstanceList) results.get(1);
String stopwords = "E:\stopwords-en.txt";
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
pipeList.add(new CharSequenceLowercase());
pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\p{L}[\p{L}\p{P}]+\p{L}")));
pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
pipeList.add(new TokenSequenceRemoveNonAlpha(true));
pipeList.add(new TokenSequence2FeatureSequence());
InstanceList instances = new InstanceList(new SerialPipes(pipeList));
Reader fileReader = null;
try {
fileReader = new InputStreamReader(new FileInputStream(new File(testDir)), "UTF-8");
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
}
instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\S*)[\s,]*(\S*)[\s,]*(.*)$"),
3, 2, 1)); // data, label, name fields
TopicInferencer inferencer = model.getInferencer();
inferencer.setRandomSeed(1);
double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
System.out.println(testProbabilities);
int index = getMaximum(testProbabilities);
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
Alphabet dataAlphabet = allTrainInstances.getDataAlphabet();
Formatter out = new Formatter(new StringBuilder(), Locale.US);
for (int topic = 0; topic < 75; topic++) {
Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
out = new Formatter(new StringBuilder(), Locale.US);
out.format("%d\t%.3f\t", topic, testProbabilities[topic]);
int rank = 0;
while (iterator.hasNext() && rank < 10) {
IDSorter idCountPair = iterator.next();
out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
rank++;
}
System.out.println(out);
}
}
在行
double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
就能直接看到概率不同。此外,我还尝试了不同的文件,但无论输入哪个文件,排名最高的总是同一个主题。感谢任何帮助。
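我来回答自己的问题,供以后遇到同样问题的人参考。
MALLET 的文档中说,训练和测试应该使用相同的管道(pipe)。
我意识到,新建(new)一个与训练步骤配置相同的管道,并不等于使用同一个管道。每个新建的 TokenSequence2FeatureSequence 都会创建自己的 Alphabet(词表),所以同一个词在测试时得到的特征索引与训练时不同,推断出的主题分布也就没有意义。您应该在训练模型时保存管道(例如通过训练用 InstanceList 的 getPipe() 获取),并在测试时重新加载、复用它们。我参考了 this question 中的示例代码,现在可以正常工作了。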
我正在使用 MALLET 进行主题建模(topic modeling),并训练了一个模型。训练结束后,我立即打印了训练集中某一篇文档的主题分布并保存下来。然后,我把同一篇文档当作测试集,让它经过相同的管道处理,再进行推断。但是得到的主题分布完全不同:训练后排名最高的主题概率约为 0.54,而把该文档用作测试集时,这个主题的概率为 0.000。这是我的训练和测试代码:
public static ArrayList<Object> trainModel() throws IOException {
String fileName = "E:\Alltogether.txt";
String stopwords = "E:\stopwords-en.txt";
// Begin by importing documents from text to feature sequences
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
// Pipes: lowercase, tokenize, remove stopwords, map to features
pipeList.add(new CharSequenceLowercase());
pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\p{L}[\p{L}\p{P}]+\p{L}")));
pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
pipeList.add(new TokenSequenceRemoveNonAlpha(true));
pipeList.add(new TokenSequence2FeatureSequence());
InstanceList instances = new InstanceList(new SerialPipes(pipeList));
Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\S*)[\s,]*(\S*)[\s,]*(.*)$"),
3, 2, 1)); // data, label, name fields
int numTopics = 75;
ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);
model.setOptimizeInterval(20);
model.addInstances(instances);
model.setNumThreads(2);
model.setNumIterations(2000);
model.estimate();
ArrayList<Object> results = new ArrayList<>();
results.add(model);
results.add(instances);
Alphabet dataAlphabet = instances.getDataAlphabet();
FeatureSequence tokens = (FeatureSequence) model.getData().get(66).instance.getData();
LabelSequence topics = model.getData().get(66).topicSequence;
Formatter out = new Formatter(new StringBuilder(), Locale.US);
for (int position = 0; position < tokens.getLength(); position++) {
out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
}
System.out.println(out);
// Estimate the topic distribution of the 66th instance,
// given the current Gibbs state.
double[] topicDistribution = model.getTopicProbabilities(66);
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
for (int topic = 0; topic < numTopics; topic++) {
Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
out = new Formatter(new StringBuilder(), Locale.US);
out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
int rank = 0;
while (iterator.hasNext() && rank < 10) {
IDSorter idCountPair = iterator.next();
out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
rank++;
}
System.out.println(out);
}
return results;
}
下面是测试部分:
private static void testModel(ArrayList<Object> results, String testDir) {
ParallelTopicModel model = (ParallelTopicModel) results.get(0);
InstanceList allTrainInstances = (InstanceList) results.get(1);
String stopwords = "E:\stopwords-en.txt";
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
pipeList.add(new CharSequenceLowercase());
pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\p{L}[\p{L}\p{P}]+\p{L}")));
pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
pipeList.add(new TokenSequenceRemoveNonAlpha(true));
pipeList.add(new TokenSequence2FeatureSequence());
InstanceList instances = new InstanceList(new SerialPipes(pipeList));
Reader fileReader = null;
try {
fileReader = new InputStreamReader(new FileInputStream(new File(testDir)), "UTF-8");
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
}
instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\S*)[\s,]*(\S*)[\s,]*(.*)$"),
3, 2, 1)); // data, label, name fields
TopicInferencer inferencer = model.getInferencer();
inferencer.setRandomSeed(1);
double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
System.out.println(testProbabilities);
int index = getMaximum(testProbabilities);
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
Alphabet dataAlphabet = allTrainInstances.getDataAlphabet();
Formatter out = new Formatter(new StringBuilder(), Locale.US);
for (int topic = 0; topic < 75; topic++) {
Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
out = new Formatter(new StringBuilder(), Locale.US);
out.format("%d\t%.3f\t", topic, testProbabilities[topic]);
int rank = 0;
while (iterator.hasNext() && rank < 10) {
IDSorter idCountPair = iterator.next();
out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
rank++;
}
System.out.println(out);
}
}
在行
double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
就能直接看到概率不同。此外,我还尝试了不同的文件,但无论输入哪个文件,排名最高的总是同一个主题。感谢任何帮助。
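我来回答自己的问题,供以后遇到同样问题的人参考。
MALLET 的文档中说,训练和测试应该使用相同的管道(pipe)。
我意识到,新建(new)一个与训练步骤配置相同的管道,并不等于使用同一个管道。每个新建的 TokenSequence2FeatureSequence 都会创建自己的 Alphabet(词表),所以同一个词在测试时得到的特征索引与训练时不同,推断出的主题分布也就没有意义。您应该在训练模型时保存管道(例如通过训练用 InstanceList 的 getPipe() 获取),并在测试时重新加载、复用它们。我参考了 this question 中的示例代码,现在可以正常工作了。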