# We first define the observations as a list and then also as a table for the experienced worker's performance.
Observation1 = [2.0, 6.0, 2.0]
Observation2 = [1.0, 5.0, 7.0]
Observation3 = [5.0, 2.0, 1.0]
Observation4 = [2.0, 3.0, 8.0]
Observation5 = [4.0, 4.0, 0.0]
ObservationTable = [
Observation1,
Observation2,
Observation3,
Observation4,
Observation5
]
# Then we define our learning rate, number of observations, and the epoch counters we will be utilizing (10, 100, and 1000).
LearningRate = 0.01
ObservationCounter = 5
EpochVersion1 = 10
EpochVersion2 = 100
EpochVersion3 = 1000
# Thus, we are now ready to define the Stochastic Gradient Descent Algorithm:
def StochasticGradientDescent(EpochCounter):
Theta0 = 10.0
Theta1 = 0.0
Theta2 = -1.0
while (EpochCounter != 0):
ObservationCounter = 5
while (ObservationCounter >= 0):
Theta0_Old = float(Theta0)
Theta1_Old = float(Theta1)
Theta2_Old = float(Theta2)
n = 5 - ObservationCounter
x = ObservationTable [n]
x0 = float(x[0])
x1 = float(x[1])
x2 = float(x[2])
Theta0_New = Theta0_Old - LearningRate*[(Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2)]
Theta1_New = Theta1_Old - LearningRate*[(Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2)]*float(x0)
Theta2_New = Theta2_Old - LearningRate*[(Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2)]*float(x1)
print(Theta0_New, Theta1_New, Theta2_New)
ObservationCounter -= 1
else:
EpochCounter -= 1
if (EpochCounter == 0):
print(Theta0_New, Theta1_New, Theta2_New)
StochasticGradientDescent(int(EpochVersion1))
代码输出TypeError:无法将序列与"float"类型的非int相乘。我已经在每个可能的步骤中将值转换为浮点值,但错误仍然存在。关键行主要是与SGD的定义功能相关的行。
我对随机梯度下降了解不多,但我在您的代码中发现了两个改进。
首先,错误是因为您试图将浮点值与列表相乘并将其添加到浮点值中。这是通过使用圆形括号而不是方形来解决的:
Theta0_New = Theta0_Old - LearningRate*((Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2))
Theta1_New = Theta1_Old - LearningRate*((Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2))*float(x0)
Theta2_New = Theta2_Old - LearningRate*((Theta0_Old+Theta1_Old*float(x0)+Theta2_Old*float(x1))-float(x2))*float(x1)
其次,while循环应该提前结束一次迭代,否则在尝试访问ObservationTable[5]
时会出现错误。
因此,将while循环更改为:
while (ObservationCounter >= 1):
输出:
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.98 -0.04 -1.12
10.02 0.02 -0.9
9.93 -0.35000000000000003 -1.1400000000000001
10.01 0.02 -0.97
9.94 -0.24 -1.24
9.94 -0.24 -1.24
最后一句话:您不需要将EpochVersion1转换为整数:int(EpochVersion1)
。当您将其声明为EpochVersion1 = 10
时,它已经是一个整数。