我有一个csv文件(280 MB(,我使用以下代码将其加载到tensorflow中:
import tensorflow as tf
data = tf.data.experimental.make_csv_dataset("flight_2018.csv",
batch_size = 1000,
label_name="Cancelled",
num_epochs = 20,
num_parallel_reads=2)
该对象的type
是tensorflow.python.data.ops.dataset_ops.PrefetchDataset
。
所以我想知道如何将这个数据集拆分为训练和测试数据集,这些数据集在拆分之前会被打乱。
由于数据类型为tensorflow.python.data.ops.dataset_ops.PrefetchDataset
,因此可以使用take和skip方法来拆分数据。
train_data=data.take(160)
test_data=data.skip(160)
请参考这个要点。非常感谢。