我有以下代码。这是在类CNNEnv中。
def step(self, a_t, i):
self.a_t = a_t
self.i = i
# Batch size of 32
# 1875 * 32 = 60000 -> # of training samples
self.X_train = self.X_train[self.i * 32:(self.i + 1) * 32]
self.y_train = self.y_train[self.i * 32:(self.i + 1) * 32]
# Train & evaluate one iteration
self.model.train_on_batch(self.X_train, self.y_train)
self.scores = self.model.test_on_batch(self.X_test, self.y_test)
self.scores = self.scores[1] * 100
return self.X_train.shape, self.y_train.shape, self.scores
下面是调用这个的外部脚本。它对第一次迭代有效。然而,在第二次迭代时,出现了一个错误。
from CNN import CNNEnv
# Instantiate class and assign to object env
env = CNNEnv()
# Call function within class
a, b, c = env.step(0.001, 1)
print(a)
print(b)
print(c)
# Call function within class second time
d, e, f = env.step(0.001, 2)
print(d)
print(e)
print(f)
第一批错误
First batch:
(32, 1, 28, 28)
(32, 10)
9.42000001669
Error on second batch:
F tensorflow/stream_executor/cuda/cuda_dnn.cc:422] could not set cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM
奇怪的是,如果我要做以下事情,当我调用train_on_batch
两次而不使用类时,它就工作了。但是我需要这个类,因为我的外部脚本必须以这种方式调用函数。任何想法?
解决了。用户错误。变量赋值问题
应该使用self。X_batch代替self.X_train。