我正在使用MALLET来训练ParallelTopicModel。训练结束后,我找了一个主题推理器,取一个句子,通过推理器推理15次,然后检查结果。我发现,对于某些主题,每次的估计都不一样,而且根本不一致。
例如,对于20个主题,这是我对同一句话的估计主题概率的输出:
[0.004888044738437717, 0.2961123293878907, 0.0023192114841146965, 0.003828168015645214, 0.3838058036596986, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.08833215086412127, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.26812948669964976, 0.0023192114841146965, 0.0038281680156452146, 0.35582296097145744, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.09765976509353493, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.052283368409032215, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.2681294866996498, 0.0023192114841146965, 0.003828168015645214, 0.3931334178891125, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839043, 0.0019749423390935396, 0.002792447952547967, 0.018537939424381665, 0.09765976509353493, 0.03773855412711243, 0.007213888668919175, 0.0029028156321696105, 0.024300525720791197, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.2588018724702361, 0.0023192114841146965, 0.0038281680156452146, 0.3278401182832166, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.06967692240529397, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.2681294866996498, 0.0023192114841146965, 0.0038281680156452146, 0.5143924028714901, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.08833215086412126, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.014972911491377543, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.20283618709375414, 0.0023192114841146965, 0.0038281680156452146, 0.29985727559497544, 0.0023130490636768045, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.11631499355236223, 0.028410939897698752, 0.007213888668919175, 0.002902815632169611, 0.024300525720791197, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437716, 0.43602654282909553, 0.0023192114841146965, 0.0038281680156452146, 0.2998572755949755, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.09765976509353493, 0.03773855412711241, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.07224958788196291, 0.0023192114841146965, 0.0038281680156452146, 0.3278401182832165, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.08833215086412129, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.04295575417961857, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.2588018724702361, 0.0023192114841146965, 0.0038281680156452146, 0.4490991032655942, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.08833215086412127, 0.028410939897698755, 0.007213888668919175, 0.002902815632169611, 0.07093859686785953, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.24014664401140884, 0.0023192114841146965, 0.0038281680156452146, 0.26254681867732077, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.018537939424381665, 0.06967692240529395, 0.05639378258593975, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.2588018724702361, 0.0023192114841146965, 0.0038281680156452146, 0.3744781894302849, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.06967692240529398, 0.047066168356526085, 0.007213888668919175, 0.002902815632169611, 0.06161098263844586, 0.0085078656328731, 0.0071022047541209835, 0.012203497697416594]
[0.004888044738437717, 0.2681294866996498, 0.0023192114841146965, 0.0038281680156452146, 0.32784011828321646, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.08833215086412127, 0.03773855412711243, 0.007213888668919175, 0.002902815632169611, 0.024300525720791197, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.10956004479961755, 0.0023192114841146965, 0.0038281680156452146, 0.3838058036596989, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.018537939424381665, 0.11631499355236223, 0.03773855412711243, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437717, 0.25880187247023617, 0.0023192114841146965, 0.0038281680156452146, 0.28120204713614816, 0.002313049063676805, 0.007598391273477824, 0.019202573922429175, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.09765976509353493, 0.03773855412711241, 0.007213888668919175, 0.002902815632169611, 0.08959382532668683, 0.0085078656328731, 0.0071022047541209835, 0.0028758834680029416]
[0.004888044738437716, 0.2214914155525815, 0.0023192114841146965, 0.0038281680156452146, 0.37447818943028494, 0.002313049063676805, 0.007598391273477824, 0.009874959693015517, 0.0030322960351839047, 0.0019749423390935396, 0.002792447952547967, 0.01853793942438167, 0.07900453663470762, 0.03773855412711243, 0.007213888668919175, 0.002902815632169611, 0.03362813995020485, 0.0085078656328731, 0.007102204754120983, 0.0028758834680029416]
正如您所看到的,有几列非常不一致。为什么会这样,有没有办法防止这种情况发生?我正在使用分布作为另一个机器学习模型的特征,而这些不一致性会使我的另一个模型失败
我的代码:
ldaModel = new ParallelTopicModel(numTopics, alphaSum, beta);
instances = new InstanceList(new SerialPipes(pipeList));
for (int i = 0; i < data.length; i++) {
String dataPt = data[i];
Instance dataPtInstance = new Instance(dataPt, null, null, dataPt);
instances.addThruPipe(dataPtInstance);
}
ldaModel.addInstances(instances);
ldaModel.setNumThreads(numThreads);
ldaModel.setNumIterations(numIterations);
try {
ldaModel.setRandomSeed(DEFAULT_SEED);
ldaModel.estimate();
inferencer = ldaModel.getInferencer();
} catch (IOException e) {
System.out.println(e);
}
String dataPt = "This is a test sentence.";
Instance dataPtInstance = new Instance(dataPt, null, null, dataPt);
InstanceList testList = new InstanceList(new SerialPipes(pipeList));
testList.addThruPipe(dataPtInstance);
double[] prob = inferencer.getSampledDistribution(testList.get(0), testIterations, thinIterations, burnInIterations);
我想我明白了原因。由于吉布斯采样,用于估计的采样输出不能保证每次都是相同的。一种解决方法是进行0次采样迭代。
如果希望推理在多次运行中保持一致,还必须设置推理器的随机种子。
inferencer = ldaModel.getInferencer();
inferencer.setRandomSeed(DEFAULT_SEED);
此外,在训练模型时,请确保使用最新版本,因为随机种子初始化的错误大约在一年前得到了修复。