我正在尝试使用KERAS实现LSTM网络,但我在接受输入方面遇到问题。我的数据集是多个CSV文件的形式(所有文件具有相同的尺寸68x250,每个条目包含2个值)。各个类之间大约有200个CSV文件。CSVS之一的预览
如何将这些多个CSV作为输入?
我最近做了类似的事情,因为佩德罗说,您使用fit_generator并写下您的自定义生成器。
这是生成器的示例:
def generator(files):
print('start generator')
while 1:
print('loop generator')
for file in files:
try:
df = pd.read_csv(file)
batches = int(np.ceil(len(df)/batch_size))
for i in range(0, batches):
yield pad_batch(df[i*batch_size:min(len(df), i*batch_size+batch_size)])
except EOFError:
print("error" + file)
您将文件名列表传递给发电机,然后通过文件迭代并批次返回内容。在我的情况下,load_data
是读取pandas并进行一些预处理的函数。pad_batch
对LSTM进行填充。
用法:
model.fit_generator(
generator=generator(trainingFiles),
steps_per_epoch=steps,
epochs=num_epochs,
validation_data=[x_test, y_test],
verbose=1)
定义一个实现:https://keras.io/utils/#sequence
并使用方法model.fit_generator。