我发现机器学习教程/书籍/文章非常困难的是,当解释模型时(即使是用代码(,你只有在训练(和/或测试(模型之前才能得到代码。然后它停止了。我找不到从示例(例如主题建模(开始的教程/书籍,它们从数据集开始,训练模型并展示如何使用模型。在下面的代码中,我有一个新闻文章的数据集,存储在每个主题的文件夹中。使用Mallet,我可以创建模型(并保存它(,但它到此为止。
我现在如何使用它?我给模型写了一篇文章,作为输出,它符合主题。请不要参考Mallet文档,因为这也没有提供从一开始到使用模型的完整示例。
下面是一个取自《Java中的机器学习》(Bostjan Kaluza(一书的例子,其中提供了创建模型和保存/加载模型的代码。这对我来说是一个很好的起点,但如果我现在想使用这个经过训练的模型呢。有人能举一个Java的例子吗?它不一定要和马莱特在一起。
import cc.mallet.types.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.*;
import cc.mallet.topics.*;
import cc.mallet.util.Randoms;
import java.util.*;
import java.util.regex.*;
import java.io.*;
public class TopicModeling {
public static void main(String[] args) throws Exception {
String dataFolderPath = "data/bbc";
String stopListFilePath = "data/stoplists/en.txt";
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
pipeList.add(new Input2CharSequence("UTF-8"));
Pattern tokenPattern = Pattern.compile("[\p{L}\p{N}_]+");
pipeList.add(new CharSequence2TokenSequence(tokenPattern));
pipeList.add(new TokenSequenceLowercase());
pipeList.add(new TokenSequenceRemoveStopwords(new File(stopListFilePath), "utf-8", false, false, false));
pipeList.add(new TokenSequence2FeatureSequence());
pipeList.add(new Target2Label());
SerialPipes pipeline = new SerialPipes(pipeList);
FileIterator folderIterator = new FileIterator(
new File[] {new File(dataFolderPath)},
new TxtFilter(),
FileIterator.LAST_DIRECTORY);
// Construct a new instance list, passing it the pipe
// we want to use to process instances.
InstanceList instances = new InstanceList(pipeline);
// Now process each instance provided by the iterator.
instances.addThruPipe(folderIterator);
// Create a model with 100 topics, alpha_t = 0.01, beta_w = 0.01
// Note that the first parameter is passed as the sum over topics, while
// the second is the parameter for a single dimension of the Dirichlet prior.
int numTopics = 5;
ParallelTopicModel model = new ParallelTopicModel(numTopics, 0.01, 0.01);
model.addInstances(instances);
// Use two parallel samplers, which each look at one half the corpus and combine
// statistics after every iteration.
model.setNumThreads(4);
// Run the model for 50 iterations and stop (this is for testing only,
// for real applications, use 1000 to 2000 iterations)
model.setNumIterations(50);
model.estimate();
/*
* Saving model
*/
String modelPath = "myTopicModel";
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream (new File(modelPath+".model")));
oos.writeObject(model);
oos.close();
oos = new ObjectOutputStream(new FileOutputStream (new File(modelPath+".pipeline")));
oos.writeObject(pipeline);
oos.close();
System.out.println("Model saved.");
/*
* Loading the model
*/
// ParallelTopicModel model;
// SerialPipes pipeline;
ObjectInputStream ois = new ObjectInputStream (new FileInputStream (new File(modelPath+".model")));
model = (ParallelTopicModel) ois.readObject();
ois.close();
ois = new ObjectInputStream (new FileInputStream (new File(modelPath+".pipeline")));
pipeline = (SerialPipes) ois.readObject();
ois.close();
System.out.println("Model loaded.");
// Show the words and topics in the first instance
// The data alphabet maps word IDs to strings
Alphabet dataAlphabet = instances.getDataAlphabet();
FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
LabelSequence topics = model.getData().get(0).topicSequence;
Formatter out = new Formatter(new StringBuilder(), Locale.US);
for (int position = 0; position < tokens.getLength(); position++) {
out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
}
System.out.println(out);
// Estimate the topic distribution of the first instance,
// given the current Gibbs state.
double[] topicDistribution = model.getTopicProbabilities(0);
// Get an array of sorted sets of word ID/count pairs
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
// Show top 5 words in topics with proportions for the first document
for (int topic = 0; topic < numTopics; topic++) {
Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
out = new Formatter(new StringBuilder(), Locale.US);
out.format("%dt%.3ft", topic, topicDistribution[topic]);
int rank = 0;
while (iterator.hasNext() && rank < 5) {
IDSorter idCountPair = iterator.next();
out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
rank++;
}
System.out.println(out);
}
/*
* Testing
*/
System.out.println("Evaluation");
// Split dataset
InstanceList[] instanceSplit= instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});
// Use the first 90% for training
model.addInstances(instanceSplit[0]);
model.setNumThreads(4);
model.setNumIterations(50);
model.estimate();
// Get estimator
MarginalProbEstimator estimator = model.getProbEstimator();
double loglike = estimator.evaluateLeftToRight(instanceSplit[1], 10, false, null);//System.out);
System.out.println("Total log likelihood: "+loglike);
}
}
/** This class illustrates how to build a simple file filter */
class TxtFilter implements FileFilter {
/** Test whether the string representation of the file
* ends with the correct extension. Note that {@ref FileIterator}
* will only call this filter if the file is not a directory,
* so we do not need to test that it is a file.
*/
public boolean accept(File file) {
return file.toString().endsWith(".txt");
}
}
我还发现ML包有时会忘记";生产模式";。也就是说,LDA最常见的用例是你有一个集合,并在上面进行训练。对于新文档的推理,你总是可以使用文档中描述的命令行,但如果你需要一个Java接口,你可能需要把一些例子放在一起。您包含的代码支持加载保存的模型,您只需要使用TopicInferencer
而不是MarginalProbabilityEstimator
。将getProbEstimator()
替换为getInferencer()
。TopicInferencer
的源具有处理实例的示例。可以使用pipeline
对象将文档字符串导入Mallet实例格式。它可能看起来像
Instance instance = pipeline.pipe(new Instance(inputText, null, null, null);
double[] distribution = inferencer.getSampledDistribution(instance, 10, 0, 5);
(我没有测试过(这些数字是估计后验概率的合理值,但它们也是粗略的猜测。