np.random.randint 导致 ValueError: low >= high



我正在这里开发CapsNet,它是在具有10个数字的MNIST数据集上实现的,但我已经更改了代码以使用具有三个类的数据集。模型训练和测试工作良好,但操纵潜在功能会导致错误:

def manipulate_latent(model, data, args):
x_test, y_test = data
index = np.argmax(y_test, 1) == args.digit
print(index)
number = np.random.randint(low=0, high=sum(index) - 1)
x, y = x_test[index][number], y_test[index][number]
x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
noise = np.zeros([1, 3, 16])
x_recons = []
for dim in range(16):
for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
tmp = np.copy(noise)
tmp[:,:,dim] = r
x_recon = model.predict([x, y, tmp])
x_recons.append(x_recon)
x_recons = np.concatenate(x_recons)
img = combine_images(x_recons, height=16)
image = img*255
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)

输出为:

number=np随机随机随机数(low=0,high=sum(index(-1(ValueError:低>=高

函数调用:

model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
routings=args.routings)
manipulate_latent(manipulate_model, (x_test, y_test), args)

这是因为您使用的是sum()而不是len()

x = [False, False, False]
print(sum(x))
print(len(x))

输出

0
3

请注意,False数组的sum()等于0。而CCD_ 5是阵列的大小。

最新更新