我正在尝试使用tf.data.experimental.make_csv_dataset
来读取一组大型csv文件。我特别使用它pandas.read_csv
,因为它经常给我内存错误,所以 tensorflow 在读取我的文件时要快得多。我的阅读代码如下所示:
url_list = ["file_1.csv", "file_2.csv", ... "file_20.csv"]
batch_size = 40000
train_dataset = tf.data.experimental.make_csv_dataset(
url_list,
batch_size,
label_name = "Outcome",
shuffle = False,
num_epochs = 1)
问题是我有一列"金额",它根据第一行的读数int32
dtype给出。这是有问题的,因为后面的一些值是 2000000000,需要存储为int64
。 如何指定"金额"列应解释为int64
?我尝试在使用tf.cast
读取文件后转换它们:
for feature, label in train_dataset:
tf.cast(feature["Amount"],tf.int64)
但是我收到此错误:InvalidArgumentError: Field 6 in record is not a valid int32: 2123456789
我知道tf.experimental.make_csv_dataset
的column_names
和column_default
论据,但是,我的专栏太多了。因此,指定它们并分配"dtype"将非常乏味。我还有字符串、浮点数和整数的组合,因此tf.experimental.make_csv_dataset
比使用tf.data.experimental.CsvDataset
更方便预处理数据,这也需要指定列名。我该如何解决这个问题?
column_defaults = [tf.string, tf.int32, tf.int64, tf.float32]
select_columns = ["str_col", "int32_col", "Amount", "Outcome"]
pattern = "file_*.csv"
batch_size = 40000
train_dataset = tf.data.experimental.make_csv_dataset(
pattern,
batch_size,
label_name = "Outcome",
shuffle = False,
num_epochs = 1,
column_defaults=column_defaults,
select_columns=select_columns
)