tf.data.Dataset.map() 用于由多个切片组成的数据集



从单个切片创建的数据集的tf.data.Dataset.map()看起来像dataset.map(lambda x: x/2)。如果数据集是从两个切片创建的,会是什么样子?例如,请参阅以下代码。代码最后一行中的map()函数适用于从单个切片创建的数据集,但会导致我的双切片情况出错。

import tensorflow as tf, numpy as np     # tensorflow 2.0
from tensorflow import keras as kr
dataset = tf.data.Dataset.from_tensor_slices((features_int8, labels_int8)) # features, labels are numpy arrays
model = kr.Sequential()
model.add(kr.layers.InputLayer(6)
model.add(kr.layers.Dense(     8, activation=tf.nn.tanh))
model.add(kr.layers.Dense(     3, activation=tf.nn.tanh))
model.compile(optimizer = kr.optimizers.RMSprop(), loss = kr.losses.MeanSquaredError())
model.fit(dataset.batch(64).map(lambda x: x/9), epochs = 10)

在单独的函数中传递 lambda 函数,如下所示

def map_fn(x, y):
return x / 9, y
model.fit(dataset.batch(64).map(map_fn), epochs = 10)

最新更新