使用 TensorFlow 2.0 从 TFRecords 中的 VarLenFeature 加载 3D 数组



我正在尝试把 TFRecord 数据中的 3D 数组加载到 TensorFlow 中。这些数据不是我创建的,我只拿到了 image_feature_description 字典。我不确定自己的做法是否正确,而且从结果来看似乎并不正确。

def read_tf_tens(path, file=None):
    """Parse a TFRecord file and reshape its image and mask into 3D tensors.

    Each record is decoded with a fixed feature spec: ``height``, ``width``
    and ``depth`` (one int64 each), ``spacing`` and ``origin`` (three
    float32 each) and the variable-length ``mask`` (int64) and ``image``
    (float32) features. The values left over after iterating the dataset
    are reshaped to ``(depth, width, height)``.

    NOTE(review): the indentation of this snippet was lost in the source;
    the loop is assumed to cover only the per-record assignments, with the
    reshape and return running after it -- confirm against the original.

    Args:
        path: Passed unchanged to ``tf.data.TFRecordDataset``.
        file: Not used by this function.

    Returns:
        Tuple ``(image_array, mask_array, shape, spacing, origin)``.
    """
    print("Inside read_tf_tens")
    print(path)
    records = tf.data.TFRecordDataset(path)
    print("jup")

    feature_spec = {
        'height': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'width': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'depth': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'spacing': tf.io.FixedLenFeature(
            [3], tf.float32, default_value=[0.0, 0.0, 0.0]),
        'origin': tf.io.FixedLenFeature(
            [3], tf.float32, default_value=[0.0, 0.0, 0.0]),
        'mask': tf.io.VarLenFeature(tf.int64),
        'image': tf.io.VarLenFeature(tf.float32),
    }

    def _decode_example(serialized):
        # Decode one serialized example into a dict of tensors.
        return tf.io.parse_single_example(serialized, feature_spec)

    parsed = records.map(_decode_example)

    # Placeholder values defined ahead of the loop: without them TensorFlow
    # reported that the variables needed to be initialized.
    height, width, depth = np.int64(256), np.int64(256), np.int64(709)
    spacing = origin = np.float32(256)
    mask = np.zeros((depth, height, width), dtype=np.int64)
    image = np.zeros((depth, height, width), dtype=np.float32)

    for example in parsed:
        height = example['height']
        width = example['width']
        depth = example['depth']
        spacing, origin = example['spacing'], example['origin']
        # VarLenFeature yields a sparse tensor; `.values` is its flat data.
        mask, image = example['mask'].values, example['image'].values

    shape = (depth[0], width[0], height[0])
    print(shape)
    image_array = tf.reshape(image, shape)
    print(image_array)
    mask_array = tf.reshape(mask, shape)
    print(mask_array.shape)
    return image_array, mask_array, shape, spacing, origin

这是输出:

(<tf.Tensor 'strided_slice:0' shape=<unknown> dtype=int64>, <tf.Tensor 'strided_slice_1:0' shape=<unknown> dtype=int64>, <tf.Tensor 'strided_slice_2:0' shape=<unknown> dtype=int64>)
Tensor("Reshape:0", shape=(None, None, None), dtype=float32)
(None, None, None)

如果我打印整个数据集(其中包含数组内容和一个整数标签),得到的是:

Sets:  [<DatasetV1Adapter shapes: ((None, None, None), ()), types: (tf.float32, tf.int32)>]

现在这个形状 (None, None, None) 让我很担心,看起来不对。有人能帮帮我吗?

这个问题有点奇怪。事实证明,先用 .numpy() 把值取成 numpy 数组,再将其重塑(reshape)为张量,对我来说就解决了问题。

def read_tf_tens(path, file):
    """Load the image and mask volumes stored in a TFRecord file.

    The file ``path + '/' + file`` is parsed with a fixed feature spec and
    the values of the last record in the file are converted to numpy and
    reshaped to ``(depth, width, height)``. Requires eager execution,
    because the parsed tensors are read with ``.numpy()``.

    NOTE(review): the indentation of this snippet was lost in the source;
    the reshape/return is assumed to follow the loop, so the last record
    is the one returned -- confirm against the original.

    Args:
        path: Directory containing the TFRecord file.
        file: Name of the TFRecord file inside ``path``.

    Returns:
        Tuple ``(image_array, mask_array, shape, spacing, origin)`` where
        the two arrays are tensors of shape ``shape``, ``shape`` is a tuple
        of three numpy int64 scalars, and ``spacing`` / ``origin`` are
        numpy float32 arrays of length 3.

    Raises:
        ValueError: If the file contains no records.
    """
    read_file = path + '/' + file
    raw_dataset = tf.data.TFRecordDataset(read_file)
    image_feature_description = {
        'height': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'width': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'depth': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
        'spacing': tf.io.FixedLenFeature(
            [3], tf.float32, default_value=[0.0, 0.0, 0.0]),
        'origin': tf.io.FixedLenFeature(
            [3], tf.float32, default_value=[0.0, 0.0, 0.0]),
        'mask': tf.io.VarLenFeature(tf.int64),
        'image': tf.io.VarLenFeature(tf.float32),
    }

    def _parse_image_function(proto):
        # Decode one serialized example into a dict of tensors.
        return tf.io.parse_single_example(proto, image_feature_description)

    parsed_dataset = raw_dataset.map(_parse_image_function)

    # Only the final record is returned, so just advance to it here and
    # convert to numpy once afterwards instead of once per record.
    image_features = None
    for image_features in parsed_dataset:
        pass
    if image_features is None:
        # Previously this surfaced as an opaque UnboundLocalError.
        raise ValueError(
            "no records found in TFRecord file: {!r}".format(read_file))

    height = image_features['height'].numpy()
    width = image_features['width'].numpy()
    depth = image_features['depth'].numpy()
    spacing = image_features['spacing'].numpy()
    origin = image_features['origin'].numpy()
    # VarLenFeature yields a sparse tensor; `.values` is its flat data.
    mask = image_features['mask'].values.numpy()
    image = image_features['image'].values.numpy()

    shape = (depth[0], width[0], height[0])  # without data augmentation
    # tf.reshape replaces the legacy tf.keras.backend.reshape alias.
    image_array = tf.reshape(image, shape)
    mask_array = tf.reshape(mask, shape)
    return image_array, mask_array, shape, spacing, origin

最新更新