Tensorflow 2自定义数据集序列



我在python字典中有一个数据集。结构如下:

data.data['0']['input'],data.data['0']['target'],data.data['0']['length']

inputtarget都是尺寸为(n,)的阵列,而length是内部

我用tf.keras.utils.Sequence创建了一个类对象,并将__getitem__指定为:

def __getitem__(self, idx):
idx = str(idx)
return {
'input': np.asarray(self.data[idx]['input']),
'target': np.asarray(self.data[idx]['target']),
'length': self.data[idx]['length']
}

如何使用tf.data.Dataset对这样的数据集进行迭代?如果我尝试使用from_tensor_slices,就会出现此错误

ValueError:尝试将不支持类型的值(<class'dict'>(转换为张量。

我认为您应该按照这里的建议将字典修改为张量将字典转换为张量或者将字典更改为文本文件或tfrecords。希望这对你有帮助!

相关内容

  • 没有找到相关文章

最新更新