如何在Matlab中训练大样本集的神经网络



我正在尝试在大训练集上训练神经网络。

inputs由大约400万列128行组成,targets由62行组成。

hiddenLayerSize是128.

脚本如下:

net = patternnet(hiddenLayerSize);
net.inputs{1}.processFcns = {'removeconstantrows','mapminmax'};
net.outputs{2}.processFcns = {'removeconstantrows','mapminmax'};
net.divideFcn = 'dividerand';  % Divide data randomly
net.divideMode = 'sample';  % Divide up every sample
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
net.trainFcn = 'trainbfg';
net.performFcn = 'mse';  % Mean squared error
net.plotFcns = {'plotperform','plottrainstate','ploterrhist', ...
  'plotregression', 'plotfit'};
net.trainParam.show = 1;
net.trainParam.showCommandLine = 1;
[net,tr] = train(net,inputs,targets, 'showResources', 'yes', 'reduction', 10);

train开始执行时,Matlab挂起,Windows挂起或变慢,交换运行磁盘巨大,几十分钟没有其他事情发生。

计算机为12Gb Windows x64, Matlab也是64位。

进程管理器的内存使用情况在运行过程中变化。

除了简化列集还能做什么?

如果减少列车集,那么减少到哪个水平?除了尝试,怎么估计它的大小?

为什么函数不显示任何东西?

从远程诊断这类问题相当困难,以至于我甚至不确定任何人能回答的任何问题是否真的有帮助。此外,你同时问了几个问题,所以我会一步一步地回答。最后,我将尝试让您更好地理解脚本的内存消耗。

<标题> 内存消耗

数据集大小和副本

从你在内存中加载的数据集的大小开始,假设每个条目包含一个双浮点精度数,你的训练数据集需要(4e6 * 128 * 8) Bytes的内存,它大致解析为3.81 GB。如果我理解正确的话,你的输出数组包含(4e6 * 62)条目,这些条目变成了(4e6 * 62 * 8) Bytes,大致相当于1,15 GB。因此,即使在运行网络训练之前,您也会消耗大约5GB的内存。

现在是的MATLAB使用延迟复制所以任何赋值:

training = zeros(4e6, 128);
copy1 = training;
copy2 = training;

不需要新的内存。然而,任何切片操作:

training = zeros(4e6, 128);
part1 = training(1:1000, :);
part1 = training(1001:2000, :);

确实会分配更多内存。因此,在选择训练、验证和测试子集时:

net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;

内部train()函数可能会重新分配相同数量的内存量两次。你的总容量现在是10GB。如果您现在认为您的操作系统正在运行,以及一堆其他应用程序,那么很容易理解为什么所有程序突然变慢了。我可能在这里告诉你一些显而易见的事情,但是:你的数据集是非常大。

分析有助于

现在,虽然我很确定我的5 GB消耗计算,但我不确定这是否是一个有效的假设。底线是我不知道train()函数的内部工作。这就是为什么我敦促你用MATLAB自己的分析器来测试它。这确实会让你更好地理解函数调用和内存消耗。

减少内存使用

可以做些什么来减少内存消耗?这个问题可能从一开始就一直困扰着程序员。再一次,很难提供一个唯一的答案,因为解决方案往往取决于手头的任务、问题和工具。Matlab有一个关于如何减少内存使用的信息页面,让我们给它一个怀疑的好处。然而,问题往往在于要加载到内存中的数据的大小。

一方面,我当然会从减少数据集的大小开始。你真的需要4e6 * 128数据点吗?如果您这样做,那么您可能会考虑投资于专用解决方案,例如高性能服务器来执行计算。如果不是你,但只有你,必须查看你的数据集,并开始分析哪些特征可能是不必要的,以减少列,最重要的是,哪些样本可能是不必要的,以减少行。

<标题>

顺便说一句,您没有抱怨MATLAB的任何OutOfMemory错误,这可能是一个好兆头。也许你的机器只是挂起了,因为计算太密集了。这也是一个合理的假设,因为您正在创建一个具有128隐藏层,62输出和运行几个训练时代的网络,正如您应该做的那样。

Kill JVM

您可以做的是在没有Java环境(JVM)的情况下运行MATLAB来减少机器上的负载。这确保了MATLAB本身将需要更少的内存来运行。可以通过以下命令禁用JVM:

matlab -nojvm

如果您不需要显示任何图形,则此工作,因为MATLAB将在类似控制台的环境中运行。

最新更新