为Keras LSTM读取多个CSV



我正在尝试使用KERAS实现LSTM网络,但我在接受输入方面遇到问题。我的数据集是多个CSV文件的形式(所有文件具有相同的尺寸68x250,每个条目包含2个值)。各个类之间大约有200个CSV文件。CSVS之一的预览

如何将这些多个CSV作为输入?

我最近做了类似的事情,因为佩德罗说,您使用fit_generator并写下您的自定义生成器。

这是生成器的示例:

def generator(files):
    print('start generator')
    while 1:        
        print('loop generator')
        for file in files:
            try:                 
                df = pd.read_csv(file)
                batches = int(np.ceil(len(df)/batch_size))      
                for i in range(0, batches):                                   
                    yield pad_batch(df[i*batch_size:min(len(df), i*batch_size+batch_size)])
            except EOFError:
                print("error" + file) 

您将文件名列表传递给发电机,然后通过文件迭代并批次返回内容。在我的情况下,load_data是读取pandas并进行一些预处理的函数。pad_batch对LSTM进行填充。

用法:

model.fit_generator(
      generator=generator(trainingFiles),   
      steps_per_epoch=steps,
      epochs=num_epochs,
      validation_data=[x_test, y_test],
      verbose=1)

定义一个实现:https://keras.io/utils/#sequence

并使用方法model.fit_generator。

最新更新