我有一个大数据集,它的初始维度是(453732,839)
。在此数据集中,我有相互关联的子组,这些子组具有可变维度。由于我必须训练 LSTM,因此每个子组的大小必须相同,因此我为每个子组应用填充,以便它们都具有相同的长度。 填充后,数据集变为大约2000000
行。
所以我在一个循环中执行model.fit()
函数,其中model.fit()
为数据集的每个部分执行一个。在循环中,我在线填充数据集的一部分以传递给model.fit()
,但在第二部分,在model.fit()
之前,RAM 填满了,我无法继续训练。
这是我pad
和fit
模型的代码:
training_set_portion_size = int(training_dataset.shape[0] / 6)
start_portion_index = 0
for epoch in range(0, 50):
for part in range(0, 4):
end_portion_index = start_portion_index + training_set_portion_size
training_set_portion = training_dataset[start_portion_index:end_portion_index]
training_set_portion_labels = training_set_portion[:, training_set_portion.shape[1]-1]
portion_groups = get_groups_id_count(training_set_portion[:,0])
# Scale dataset portion
training_set_portion = scaler.transform(training_set_portion[:,0:training_set_portion.shape[1]-1])
training_set_portion = np.concatenate((training_set_portion, training_set_portion_labels[:, np.newaxis]), axis=1)
# Pad dataset portion
training_set_portion = pad_groups(training_set_portion, portion_groups)
training_set_portion_labels = training_set_portion[:, training_set_portion.shape[1]-1]
# Exluding group and label from training_set_portion
training_set_portion = training_set_portion[:, 1:training_set_portion.shape[1] - 1]
# Reshape data for LSTM
training_set_portion = training_set_portion.reshape(int(training_set_portion.shape[0]/timesteps), timesteps, features)
training_set_portion_labels = training_set_portion_labels.reshape(int(training_set_portion_labels.shape[0]/timesteps), timesteps)
model.fit(training_set_portion, training_set_portion_labels, validation_split=0.2, shuffle=False, epochs=1,
batch_size=1, workers=0, max_queue_size=1, verbose=1)
* **更新 ***
我现在正在使用pandas
,带有chunksize
,但似乎张量在循环中连接。
pandas
迭代器:
training_dataset_iterator = pd.read_csv('/content/drive/My Drive/Tesi_magistrale/en-train.H',
chunksize=80000, sep=",", header=None, dtype=np.float64)
新代码:
for epoch in range(0, 50):
for chunk in training_dataset_iterator:
training_set_portion = chunk.values
training_set_portion_labels = training_set_portion[:, training_set_portion.shape[1]-1]
portion_groups = get_groups_id_count(training_set_portion[:,0])
# Scale dataset portion
training_set_portion = scaler.transform(training_set_portion[:,0:training_set_portion.shape[1]-1])
training_set_portion = np.concatenate((training_set_portion, training_set_portion_labels[:, np.newaxis]), axis=1)
# Pad dataset portion
print('Padding portion...n')
training_set_portion = pad_groups(training_set_portion, portion_groups)
training_set_portion_labels = training_set_portion[:, training_set_portion.shape[1]-1]
# Exluding group and label from training_set_portion
training_set_portion = training_set_portion[:, 1:training_set_portion.shape[1] - 1]
# Reshape data for LSTM
training_set_portion = training_set_portion.reshape(int(training_set_portion.shape[0]/timesteps), timesteps, features)
training_set_portion_labels = training_set_portion_labels.reshape(int(training_set_portion_labels.shape[0]/timesteps), timesteps)
print('Training set portion shape: ', training_set_portion.shape)
model.fit(training_set_portion, training_set_portion_labels, validation_split=0.2, shuffle=False, epochs=1,
batch_size=1, workers=0, max_queue_size=1, verbose=1)
第一个print('Training set portion shape: ', training_set_portion.shape)
给了我(21327, 20, 837)
,但第二个给了我(43194, 20, 837)
。我不明白为什么。
更新 2
我注意到training_set_portion = pad_groups(training_set_portion, portion_groups)
以某种方式重复数据。
垫组代码:
def pad_groups(dataset, groups):
max_subtree_length= 20
start = 0
rows, cols = dataset.shape
padded_dataset = []
index = 1
for group in groups:
pad = [group[0]] + [0] * (cols - 1)
stop = start + group[1]
subtree = dataset[start:stop].tolist()
padded_dataset.extend(subtree)
subtree_to_pad = max_subtree_length - group[1]
pads = [pad] * subtree_to_pad
padded_dataset.extend(pads)
start = stop
index+=1
padded_dataset = np.array(padded_dataset)
return padded_dataset
我该怎么做? 提前谢谢你。
我在TowardsDataScience中找到了一个链接,他们向您展示了3种方法来解决此问题,方法是使用一个名为pandas
的小型库,该库广泛用于数据集处理。我希望它对解决您的问题有所帮助。这是链接:-
https://towardsdatascience.com/3-simple-ways-to-handle-large-data-with-pandas-d9164a3c02c1
问候 尼尔·古普塔
我解决了我的问题:我的代码中有一个错误。