我已经训练了cnn模型并将参数保存在五个文件中,但是当我使用这些参数测试照片时,我遇到了这样的问题:在这里输入图像描述
load_data的代码为:
def load_data(pag_name):``
k = 0
for filename in os.listdir(pag_name):
if (filename != '.DS_Store'):
k = k + 1
num = k
# test_per = k*4
print k
i = 0
j = 0
label = 0
train_set = numpy.empty((num, 1, 56, 56))
while (j < 1):
for filename in os.listdir(pag_name):
if (filename != '.DS_Store'):
filename = pag_name+ '/' + filename
image = Image.open(filename)
#print image.size
#print image
img_ndarray = numpy.asarray(image, dtype='float64') / 256
img_ndarray = numpy.asarray([img_ndarray])
# train_set[i] = numpy.ndarray.flatten(img_ndarray)
train_set[i] = img_ndarray
#print train_set.shape
# print filename1
# print 'label:', label
# print 'i:',i
i = i + 1
j = j + 1
def shared_dataset(data_x, borrow=True):
shared_x = theano.shared(numpy.asarray(data_x,
dtype=theano.config.floatX),
borrow=borrow)
return shared_x
train_set = shared_dataset(train_set)
print train_set.get_value(borrow=True).shape
return train_set
use_CNN的代码为:
def use_CNN(pag_name,nkerns=[20,40,60]):
data = load_data(pag_name)
data_num = data.get_value(borrow=True).shape[0]
layer0_params,layer01_params,layer1_params,layer2_params,layer3_params = load_params()
x = T.matrix('x')
layer0_input = x.reshape((data_num,1,56,56))
layer0 = LeNetConvPoolLayer(
input=layer0_input,
params_W = layer0_params[0],
params_b = layer0_params[1],
image_shape=(data_num, 1, 56, 56),
filter_shape=(nkerns[0], 1, 5,5),
poolsize=(2, 2)`
)
我还没有遇到这个问题,我不知道在哪里以及如何修改我的代码
这个错误的结果是参数不是4D的,我加载的参数是3D的,就像我的W和b是(20,1,5,5),但我加载(1,5,5),所以我遇到了这个问题