在使用matlab训练神经网络的过程中,我对性能图的输出有点问题。我使用lambda函数(plot_rain_loss(来绘制训练进度,它得到一个参数"信息">作为传递参数:
options = trainingOptions('sgdm', ...
'MaxEpochs',num_epochs,...
'InitialLearnRate',learning_rate, ...
'ValidationData',imdsTest, ...
'ValidationFrequency',validationFrequency, ...
'Verbose',false, ...
'Plots','none',...
'OutputFcn',@(info)plot_train_loss(info,3), ...
'MiniBatchSize', mini_batch_size);
[net,info] = trainNetwork(imdsTrain,my_net,options);
这也非常有效。至少是第一次。然而,一旦我想再次训练,函数就会记住持久函数变量(train_iteration、train_loss、train_accurcy、bestValAccuracy、valLag(,并且之前训练的图是下一次训练的初始输出。训练后,有没有办法清除绘图函数中隐含的持久变量?
以下是";plot_ train_loss";功能:
function stop = plot_train_loss(info,N)
% initialize persitent variable in order to save the past values
persistent train_iteration
persistent train_loss
persistent train_accuracy
% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag
global training_figure;
global num_iterations;
stop = false;
% check in which state the training is
% if start: do nothing
if info.State == "start"
info.Iteration = {};
info.TrainingLoss = {};
info.TrainingAccuracy = {};
bestValAccuracy = 0;
valLag = 0;
return
end
% assign values of training to variable
train_iteration(info.Iteration) = info.Iteration;
train_loss(info.Iteration) = info.TrainingLoss;
train_accuracy(info.Iteration) = info.TrainingAccuracy;
% plot training progress
t = tiledlayout(2,1);
nexttile;
plot(train_iteration,train_loss,'g');
xlim([0 num_iterations]);
max_train_loss = max(train_loss(:)) + 0.2;
ylim([0 max_train_loss]);
grid on;
title('Training loss');
nexttile;
plot(train_iteration,train_accuracy,'b');
xlim([0 num_iterations]);
ylim([0 100]);
grid on;
title('Training accuracy');
因此,函数是";"召回";在两种不同的情况下:
- 它在一个单一的训练周期内被召回:-->不清除变量
- 它在上一个训练周期之后被召回:-->清除变量
或在此方案中:(r:召回;|=下一周期(
<-----cycle 1----------><-----cycle 2----------->
<----------------------clear------------------- ->
r; r; r; r; r; r; r; r; | r; r; r; r; r; r; r; r;
希望有人能帮我解决这个问题,真的有点棘手
clear plot_train_loss
将从内存中清除函数及其所有持久变量。下次调用函数时,将从文件中读取并重新编译它,并重新创建它的持久变量。