训练损失是 Nan 使用 TFrecords 在 TPU 中使用图像分割



我是一个初学者,试图在Kaggle Kernels中使用Tensorflow来处理TPU。我之前使用 GPU 中的数据集训练了一个 Unet 模型,现在我正在尝试在 TPU 中实现它。我用数据集图像和掩码制作了一个 tfrecord,TFrecord 返回图像和掩码。当我尝试在 TPU 中训练时,损失始终是 Nan,即使指标准确性正常。由于这与我在GPU中使用的模型和损失相同,因此我猜测问题出在tfrecord或加载数据集中。 加载数据的代码如下:

def decode_image(image_data):
image = tf.image.decode_jpeg(image_data, channels=3)
image = tf.cast(image, tf.float32) / (255.0)  # convert image to floats in [0, 1] range
image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
return image
def decode_image_mask(image_data):
image = tf.image.decode_jpeg(image_data, channels=3)
image = tf.cast(image, tf.float64) / (255.0)  # convert image to floats in [0, 1] range
image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
image=tf.image.rgb_to_grayscale(image)
image=tf.math.round(image)
return image
def read_tfrecord(example):
TFREC_FORMAT = {
"image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
"mask": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
}
example = tf.io.parse_single_example(example, TFREC_FORMAT)
image = decode_image(example['image'])
mask=decode_image_mask(example['mask'])
return image, mask 

def load_dataset(filenames, ordered=False):
# Read from TFRecords. For optimal performance, reading from multiple files at once and
# disregarding data order. Order does not matter since we will be shuffling the data anyway.
ignore_order = tf.data.Options()
if not ordered:
ignore_order.experimental_deterministic = False # disable order, increase speed
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
return dataset

def get_training_dataset():
dataset = load_dataset(TRAINING_FILENAMES)
dataset = dataset.repeat() # the training dataset must repeat for several epochs
dataset = dataset.shuffle(2048)
dataset = dataset.batch(BATCH_SIZE,drop_remainder=True)
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset
def get_validation_dataset(ordered=False):
dataset = load_dataset(VALIDATION_FILENAMES, ordered=ordered)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.cache()
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset

def count_data_items(filenames):
# the number of data items is written in the name of the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
n = [int(re.compile(r"-([0-9]*).").search(filename).group(1)) for filename in filenames]
return np.sum(n)

那么,我做错了什么?

事实证明,问题是我正在取消批处理数据并将其批处理到 20 以在 matplotlib 中正确查看图像和掩码,这搞砸了数据发送到模型的方式,因此 Nan 丢失。制作数据集的另一个副本并使用它来查看图像,同时发送原始数据集进行训练解决了这个问题。

最新更新