我已经为二元分类数据集编写了一个MLP ANN代码,并且我的训练数据集的准确率为0.88
(88%)。我的测试数据集为我提供了0.37 - 0.55
准确性。
我注意到这是由于我的参数在使用 UpdateParameters 方法后没有更新,如下所示:
def update_parameters(parameters, grads, lr):
param1 = parameters
L = len(parameters) // 2
for l in range(L):
parameters["W" + str(l+1)] = parameters["W" + str(l+1)] - lr * grads["dW"+str(l+1)]
parameters["b" + str(l+1)] = parameters["b" + str(l+1)] - lr * grads["db"+str(l+1)]
print(param1==parameters)
return parameters
上面的函数为我提供了所有初始值和更新值比较True
。
在以下函数中调用更新参数函数:
def ann(X, Y, dimensions, lr, lr_decay, batch_size, epochs, loss, activations, gradient_alg):
L = len(dimensions) # number of layers in the neural networks
m = X.shape[1]
costs = [] # to keep track of the cost
parameters = initialize_parameters(dimensions)
param1 = parameters
if (gradient_alg == "b"):
batch_size = X.shape[1]
for i in range(epochs):
minibatches = random_mini_batches(X, Y, batch_size)
cost_total = 0
for minibatch in minibatches:
(minibatch_X,minibatch_Y) = minibatch
last_A, caches = forward_prop_layers(minibatch_X, parameters, activations)
cost_total += compute_cost(last_A, minibatch_Y, loss)
gradients = backward_prop_layers(last_A, minibatch_Y, caches, activations)
parameters = update_parameters(parameters, gradients, lr)
cost_avg = cost_total /m
if i %10 == 0:
print ("Cost after epoch %i: %f" %(i, cost_avg))
costs.append(cost_avg)
plt.plot(costs)
plt.ylabel('cost')
plt.xlabel('epochs')
plt.title("Learning rate = " + str(lr))
plt.show()
parameters1 = [parameters, param1, dimensions, activations, costs, lr, batch_size]
return parameters1
我的函数没有被正确调用吗?我在实施中究竟哪里出了问题?
哦,是的,这就是它返回 True 的原因。首先,您将param1
分配给parameters
。然后你正在更新parameters
.但是由于param1
指向parameters
,即使在更新parameters
之后,param1
仍然指向与parameters
相同的内存位置。在python中,一切都被视为一个对象。尝试在更新前后打印出一些parameters
,然后使用deepcopy手动检查它们是否正在更改或创建parameters
的副本,这会将parameters
中的所有内容复制到单独的内存位置。
from copy import deepcopy
def update_parameters(parameters, grads, lr):
param1 = deepcopy(parameters)
L = len(parameters) // 2
for l in range(L):
parameters["W" + str(l+1)] = parameters["W" + str(l+1)] - lr * grads["dW"+str(l+1)]
parameters["b" + str(l+1)] = parameters["b" + str(l+1)] - lr * grads["db"+str(l+1)]
print(param1==parameters)
return parameters
还要尝试在每次迭代后打印出损失,如果它正在更改,则parameters
正在更新,如果没有,则您的parameters
没有得到正确更新。