在训练某些网络时,Keras(Tensorflow后端)在GPU上比在CPU上慢



我很难理解为什么 GPU 和 CPU 速度与小型网络相似(CPU 有时更快(,而 GPU 在较大尺寸的网络中更快。问题底部的代码在 i7-6700k 上运行时间为 103.7 秒,但当使用 tensorflow-GPU 时,代码在 29.5 秒内运行。

但是,当我训练一个具有 100 个隐藏神经元的网络时,而不是像下面的示例那样的 1000 个,我在使用 GPU 时得到 ~20 秒,在使用 CPU 时得到 ~15 秒。

我在另一个堆栈溢出答案中读到,CPU>GPU 传输需要很长时间,我假设这是参考在 GPU 上加载数据示例。

有人可以解释为什么会发生这种情况,并可能引用我可以对代码进行的一些更改以最大限度地提高速度吗?

import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.utils import np_utils
from keras.layers.core import Dense, Activation, Flatten, Dropout
from sklearn.preprocessing import normalize
## Importing the MNIST dataset using Keras
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape for vector input
N, x, y = X_train.shape
X_train = normalize(np.reshape(X_train, (N, x * y)))
N, x, y = X_test.shape
X_test = normalize(np.reshape(X_test, (N, x * y)))
# one-hot encoding
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
model = Sequential()
model.add(Dense(output_dim=750, input_dim=784))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(150))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(50))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(50))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='Nadam', metrics=['accuracy'])
fit = model.fit(X_train, y_train, batch_size=128, nb_epoch=10, verbose=0)
## Printing the accuracy of our model, according to the loss function specified in model.compile above
score = model.evaluate(X_test, y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])

在微型网络的情况下,批量加载可能是这里的罪魁祸首。

Keras 在每次迭代开始时将每个小批量从 RAM 加载到 GPU,从而在微型网络中造成瓶颈(其中前向/后向计算非常快(。
您可以尝试使用model.fit_generator而不是普通fit,以便加载小批量的CPU线程并行工作。

不幸的是,我不知道在 GPU 上为 Keras 预加载整个数据集(请参阅我的问题(

如果您使用的是Tensorflow后端,则可以使用Google Timeline分析工具来查看导致速度变慢的原因。有关参考,请参阅此问题

最新更新