How can I improve the efficiency and/or performance of my relatively simple Java counting method?
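I am building a classifier that has to read a large number of text documents, but I found that my countWordFrequencies method gets slower the more documents it has processed. The method below takes 60 ms (on my PC), while reading, normalizing, tokenizing, updating my vocabulary and equalizing the various lists of integers together take only 3-5 ms (on my PC). My countWordFrequencies method is as follows: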
public List<Integer> countWordFrequencies(String[] tokens)
{
    List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
    int counter = 0;
    for (int i = 0; i < vocabulary.size(); i++)
    {
        for (int j = 0; j < tokens.length; j++)
            if (tokens[j].equals(vocabulary.get(i)))
                counter++;
        wordFreqs.add(i, counter);
        counter = 0;
    }
    return wordFreqs;
}
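What is the best way to speed up this process? What is wrong with this method?
This is my entire class. There is also another class, Category; would it be a good idea to post that here too, or do you not need it?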
public class BayesianClassifier
{
    private Map<String,Integer> vocabularyWordFrequencies;
    private List<String> vocabulary;
    private List<Category> categories;
    private List<Integer> wordFrequencies;
    private int trainTextAmount;
    private int testTextAmount;
    private GUI gui;

    public BayesianClassifier()
    {
        this.vocabulary = new ArrayList<>();
        this.categories = new ArrayList<>();
        this.wordFrequencies = new ArrayList<>();
        this.trainTextAmount = 0;
        this.gui = new GUI(this);
        this.testTextAmount = 0;
    }

    public List<Category> getCategories()
    {
        return categories;
    }

    public List<String> getVocabulary()
    {
        return this.vocabulary;
    }

    public List<Integer> getWordFrequencies()
    {
        return wordFrequencies;
    }

    public int getTextAmount()
    {
        return testTextAmount + trainTextAmount;
    }

    public void updateWordFrequency(int index, Integer frequency)
    {
        equalizeIntList(wordFrequencies);
        this.wordFrequencies.set(index, wordFrequencies.get(index) + frequency);
    }

    public String readText(String path)
    {
        BufferedReader br;
        String result = "";
        try
        {
            br = new BufferedReader(new FileReader(path));
            StringBuilder sb = new StringBuilder();
            String line = br.readLine();
            while (line != null)
            {
                sb.append(line);
                sb.append("\n");
                line = br.readLine();
            }
            result = sb.toString();
            br.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        return result;
    }

    public String normalizeText(String text)
    {
        String fstNormalized = Normalizer.normalize(text, Normalizer.Form.NFD);
        fstNormalized = fstNormalized.replaceAll("[^\\p{ASCII}]","");
        fstNormalized = fstNormalized.toLowerCase();
        fstNormalized = fstNormalized.replace("\n","");
        fstNormalized = fstNormalized.replaceAll("[0-9]","");
        fstNormalized = fstNormalized.replaceAll("[/()!?;:,.%-]","");
        fstNormalized = fstNormalized.trim().replaceAll(" +", " ");
        return fstNormalized;
    }

    public String[] handleText(String path)
    {
        String text = readText(path);
        String normalizedText = normalizeText(text);
        return tokenizeText(normalizedText);
    }

    public void createCategory(String name, BayesianClassifier bc)
    {
        Category newCategory = new Category(name, bc);
        categories.add(newCategory);
    }

    public List<String> updateVocabulary(String[] tokens)
    {
        for (int i = 0; i < tokens.length; i++)
            if (!vocabulary.contains(tokens[i]))
                vocabulary.add(tokens[i]);
        return vocabulary;
    }

    public List<Integer> countWordFrequencies(String[] tokens)
    {
        List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
        int counter = 0;
        for (int i = 0; i < vocabulary.size(); i++)
        {
            for (int j = 0; j < tokens.length; j++)
                if (tokens[j].equals(vocabulary.get(i)))
                    counter++;
            wordFreqs.add(i, counter);
            counter = 0;
        }
        return wordFreqs;
    }

    public String[] tokenizeText(String normalizedText)
    {
        return normalizedText.split(" ");
    }

    public void handleTrainDirectory(String folderPath, Category category)
    {
        File folder = new File(folderPath);
        File[] listOfFiles = folder.listFiles();
        if (listOfFiles != null)
        {
            for (File file : listOfFiles)
            {
                if (file.isFile())
                {
                    handleTrainText(file.getPath(), category);
                }
            }
        }
        else
        {
            System.out.println("There are no files in the given folder" + " " + folderPath.toString());
        }
    }

    public void handleTrainText(String path, Category category)
    {
        long startTime = System.currentTimeMillis();
        trainTextAmount++;
        String[] text = handleText(path);
        updateVocabulary(text);
        equalizeAllLists();
        List<Integer> wordFrequencies = countWordFrequencies(text);
        long finishTime = System.currentTimeMillis();
        System.out.println("That took 1: " + (finishTime-startTime)+ " ms");
        long startTime2 = System.currentTimeMillis();
        category.update(wordFrequencies);
        updatePriors();
        long finishTime2 = System.currentTimeMillis();
        System.out.println("That took 2: " + (finishTime2-startTime2)+ " ms");
    }

    public void handleTestText(String path)
    {
        testTextAmount++;
        String[] text = handleText(path);
        List<Integer> wordFrequencies = countWordFrequencies(text);
        Category category = guessCategory(wordFrequencies);
        boolean correct = gui.askFeedback(path, category);
        if (correct)
        {
            category.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        }
        else
        {
            Category correctCategory = gui.askCategory();
            correctCategory.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        }
    }

    public void updatePriors()
    {
        for (Category category : categories)
        {
            category.updatePrior();
        }
    }

    public Category guessCategory(List<Integer> wordFrequencies)
    {
        List<Double> chances = new ArrayList<>();
        for (int i = 0; i < categories.size(); i++)
        {
            double chance = categories.get(i).getPrior();
            System.out.println("The prior is:" + chance);
            for(int j = 0; j < wordFrequencies.size(); j++)
            {
                chance = chance * categories.get(i).getWordProbabilities().get(j);
            }
            chances.add(chance);
        }
        double max = getMaxValue(chances);
        int index = chances.indexOf(max);
        System.out.println(max);
        System.out.println(index);
        return categories.get(index);
    }

    public double getMaxValue(List<Double> values)
    {
        Double max = 0.0;
        for (Double dubbel : values)
        {
            if(dubbel > max)
            {
                max = dubbel;
            }
        }
        return max;
    }

    public void equalizeAllLists()
    {
        for(Category category : categories)
        {
            if (category.getWordFrequencies().size() < vocabulary.size())
            {
                category.setWordFrequencies(equalizeIntList(category.getWordFrequencies()));
            }
        }
        for(Category category : categories)
        {
            if (category.getWordProbabilities().size() < vocabulary.size())
            {
                category.setWordProbabilities(equalizeDoubleList(category.getWordProbabilities()));
            }
        }
    }

    public List<Integer> equalizeIntList(List<Integer> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0);
        }
        return list;
    }

    public List<Double> equalizeDoubleList(List<Double> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0.0);
        }
        return list;
    }

    public void selectFeatures()
    {
        for(int i = 0; i < wordFrequencies.size(); i++)
        {
            if(wordFrequencies.get(i) < 2)
            {
                vocabulary.remove(i);
                wordFrequencies.remove(i);
                for(Category category : categories)
                {
                    category.removeFrequency(i);
                }
            }
        }
    }
}
Your method runs in O(n*m) time (n is the vocabulary size, m is the number of tokens). With hashing, this can be reduced to O(m), which is clearly better.
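// Count every token in one pass; HashMap lookups are O(1) on average
Map<String, Integer> map = new HashMap<>();
for (String token : tokens) {
    if (!map.containsKey(token)) {
        map.put(token, 0);
    }
    map.put(token, map.get(token) + 1);
}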
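Instead of using one list to store the vocabulary and another list to store the frequencies, I would use a single Map that stores word -> frequency. That way you avoid the double loop, which in my opinion is what is killing your performance.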
public Map<String,Integer> countWordFrequencies(String[] tokens) {
    // vocabulary is Map<String,Integer> initialized with all words as keys and 0 as value
    for (String word: tokens)
        if (vocabulary.containsKey(word)) {
            vocabulary.put(word, vocabulary.get(word)+1);
        }
    return vocabulary;
}
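If the rest of the classifier still expects a List<Integer> aligned by index with the vocabulary list (as category.update(wordFrequencies) in the question appears to), a sketch along these lines keeps that shape while dropping the nested loop. The class name FrequencyCounter and the static signature that takes the vocabulary as a parameter are assumptions made for illustration; they are not part of the original code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper class; the name and static signature are illustrative only.
public class FrequencyCounter {

    // Returns one count per vocabulary word, in vocabulary order.
    public static List<Integer> countWordFrequencies(List<String> vocabulary, String[] tokens) {
        // Pass 1: count every token once, O(m) on average.
        Map<String, Integer> counts = new HashMap<>();
        for (String token : tokens) {
            counts.merge(token, 1, Integer::sum);
        }
        // Pass 2: look up each vocabulary word, O(n) on average.
        List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
        for (String word : vocabulary) {
            wordFreqs.add(counts.getOrDefault(word, 0));
        }
        return wordFreqs;
    }

    public static void main(String[] args) {
        List<String> vocabulary = Arrays.asList("a", "b", "c");
        String[] tokens = {"a", "b", "a", "d"};
        System.out.println(countWordFrequencies(vocabulary, tokens)); // prints [2, 1, 0]
    }
}
```

This does O(n + m) work per document instead of O(n*m). Tokens that are not in the vocabulary (such as "d" above) are simply ignored.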
Using a Map should significantly improve performance, as Sleiman Jneidi suggested in his answer. However, this can be done more elegantly with the Java 8 streams API:
Map<String, Long> frequencies =
    Arrays.stream(tokens)
          .collect(Collectors.groupingBy(Function.identity(),
                                         Collectors.counting()));
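For completeness, here is the same idea as a small self-contained program with the imports it needs. The wrapper class StreamFrequencies, the method name count, and the sample tokens are made up for this sketch.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical wrapper class so the snippet compiles on its own.
public class StreamFrequencies {

    // Groups equal tokens together and counts the size of each group.
    public static Map<String, Long> count(String[] tokens) {
        return Arrays.stream(tokens)
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    public static void main(String[] args) {
        String[] tokens = {"spam", "ham", "spam"};
        Map<String, Long> frequencies = count(tokens);
        System.out.println(frequencies.get("spam"));              // prints 2
        System.out.println(frequencies.getOrDefault("eggs", 0L)); // prints 0
    }
}
```

Two things to keep in mind: the counts come back as Long rather than Integer, and a word that never occurs in tokens has no entry at all, so use getOrDefault when looking up vocabulary words.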
If you don't want to use Java 8 features, you could try Guava's Multiset.