逻辑回归中计算成本的麻烦



我正在从Andrew Ng上获得Coursera的机器学习课程。在此评估中,我正在使用MATLAB中的Logistic回归来计算成本函数,但是正在使用SFMINBX(第27行(接收"错误"(第27行(目标函数在初始点不确定。fminunc不能继续。"。

我应该补充说,下面的成本功能功能内的成本j为NAN,因为日志(sigmoid(x * theta((是-inf向量。我确定这与例外有关。你能帮忙吗?

我的成本功能看起来如下:

function [J, grad] = costFunction(theta, X, y)
  m = length(y); % number of training examples
  J = 0;
  grad = zeros(size(theta));
  h = sigmoid(theta * X);
  J    = - (1 / m) * ((log(h)' * y) + (log(1 - h)' * (1 - y)));
  grad = (1 / m) * X' * (h - y);
end

我调用此功能的代码如下:

data = load('ex2data1.txt');
X = data(:, [1, 2]); y = data(:, 3);
[m, n] = size(X);
% Add intercept term to x and X_test
X = [ones(m, 1) X];
% Initialize fitting parameters
initial_theta = zeros(n + 1, 1);
% Compute and display initial cost and gradient
[cost, grad] = costFunction(initial_theta, X, y);
fprintf('Cost at initial theta (zeros): %fn', cost);
fprintf('Expected cost (approx): 0.693n');
fprintf('Gradient at initial theta (zeros): n');
fprintf(' %f n', grad);
fprintf('Expected gradients (approx):n -0.1000n -12.0092n -11.2628n');
% Compute and display cost and gradient with non-zero theta
test_theta = [-24; 0.2; 0.2];
[cost, grad] = costFunction(test_theta, X, y);
fprintf('nCost at test theta: %fn', cost);
fprintf('Expected cost (approx): 0.218n');
fprintf('Gradient at test theta: n');
fprintf(' %f n', grad);
fprintf('Expected gradients (approx):n 0.043n 2.566n 2.647n');
fprintf('nProgram paused. Press enter to continue.n');
pause;

%% ============= Part 3: Optimizing using fminunc  =============
%  In this exercise, you will use a built-in function (fminunc) to find the
%  optimal parameters theta.
%  Set options for fminunc
options = optimset('GradObj', 'on', 'MaxIter', 400, 'Algorithm', 'trust-
region');
%  Run fminunc to obtain the optimal theta
%  This function will return theta and the cost 
[theta, cost] = ...
    fminunc(@(t)(costFunction(t, X, y)), initial_theta, options);
end

数据集看起来如下:

34.62365962451697,78.0246928153624,030.28671076822607,43.89499752400101,035.84740876993872,72.90219802708364,060.18259938620976,86.30855209546826,179.0327360507101,75.3443764369103,145.0832774766839,56.3163717815305,061.10666453684766,96.51142588489624,175.0247455673889,46.55401354116538,176.09878670226257,87.42056971926803,184.43281996120035,43.53333331072109,195.86155507093572,38.22527805795094,075.01365838958247,30.60326323428011,082.30705333739482,76.48196330235604,169.3645875970939,97.71869196188608,139.53833914367223,76.03681085115882,053.9710521485623,89.20735013750205,169.07014406283025,52.74046973016765,167.94685547711617,46.67857410673128,070.66150955499435,92.92713789364831,176.97878372747498,47.57596364975532,167.37202754570876,42.83843832029179,089.67677575072079,65.79936592745237,150.53478828983,48.85581152764205,034.21206097786789,44.20952859866288,077.9240914545704,68.9723599933059,162.27101367004632,69.95445795447587,180.1901807509566,44.82162893218353,193.114388797442,38.80067033713209,061.83020602312595,50.25610789244621,038.78580379679423,64.99568095539578,061.379289447425,72.80788731317097,185.40451939411645,57.05198397627122,152.10797973193984,63.12762376881715,052.04540476831827,69.4328601204522,140.23689373545111,71.16774802184875,054.63510555424817,52.21388588061123,033.91550010906887,98.86943574220611,064.17698887494485,80.90806058670817,174.78925295941542,41.57341522824434,034.1836400264419,75.237720360134,083.90239366249155,56.30804621605327,151.54772026906181,46.85629026349976,094.44336776917852,65.56892160559052,182.36875375713919,40.61825515970618,051.0477517128865,45.8227014576001,062.22267576120188,52.0609919483679,077.19303492601364,70.45820000180959,197.77159928000232,86.7278223300282,162.07306379667647,96.76882412413983,191.56497449807442,88.69629254546599,179.9448179406932,74.16311935043758,199.2725269292572,60.99903099844988,190.54671411399852,43.39060180650027,134.52451385320009,60.396342458373,050.2864961189907,49.80453881323059,049.58667721632031,59.80895099453265,097.6456339600767,68.86157272420604,132.57720016809309,95.59854761387875,074.24869136721598,69.82457122657193,171.79646205863379,78.45356224515052,175.3956114656803,85.75993667331619,135.28611281526193,47.02051394723416,056.25381749711624,39.26147251058019,030.05882244669796,49.59297386723685,044.66826172480893,66.45008614558913,066.56089447242954,41.09209807936973,040.45755098375164,97.53518548909936,149.0725632190844,51.883211820739666,080.2795740146698,92.11606081344084,166.74671856944039,60.99139402740988,132.72283304060323,43.30717306430063,064.0393204150601,78.03168802018232,172.34649422579923,96.22759296761404,160.45788573918959,73.09499809758037,158.84095621726802,75.85844831279042,199.82785779692128,72.3692519338385,147.26426910848174,88.47586499559782,150.45815980285988,75.80985952982456,160.455555629271532,42.50840943572217,082.22666157785568,42.71987853716458,088.9138964166533,69.803788889835472,194.83450672430196,45.69430680250754,167.31925746917527,66.58935317747915,157.23870631569862,59.51428198012956,180.36675600171273,90.96014789746954,168.46852178591112,85.59430710452014,142.0754545384731,78.84478600148043,075.47770200533905,90.42453899753964,178.63542434898018,96.64742716885644,152.34800398794107,60.76950525602592,094.09433112516793,77.15910509073893,190.44855097096364,87.50879176484702,155.48216114069585,35.57070347228866,074.49269241843041,84.84513684930135,189.84580670720979,45.35828361091658,183.48916274498238,48.38028579728175,142.2617008099817,87.10385094025457,199.31500880510394,68.77540947206617,155.34001756003703,64.9319380069486,174.777589300092767,89.52981289513276,1

我看到的唯一问题是您应该写入h = sigmoid(X * theta)而不是h = sigmoid(theta * X)。更改此版本后,我从代码中获得了相同的答案,因为我从代码中获得了相同的作业。

最新更新