Converting HDF5 files to a tf.data Dataset



I want to convert multiple HDF5 files into a tf.data.Dataset using tf.data.Dataset.from_tensor_slices:

dataset = tf.data.Dataset.from_tensor_slices(filepath)  # filepath: list containing all HDF5 file paths
dataset = (dataset
           .shuffle(1024)
           .map(wrapper, num_parallel_calls=AUTOTUNE)
           .cache()
           .repeat()
           .batch(BS)
           .prefetch(AUTOTUNE)
           )
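A small illustration (my sketch, not from the original post; the example path is hypothetical): at this point each dataset element is just a scalar tf.string tensor holding a file path, which TensorFlow cannot decode natively for HDF5; that is why the wrapper described below is needed.

for path in tf.data.Dataset.from_tensor_slices(filepath).take(1):
    # Each element is a scalar string tensor, e.g.
    # tf.Tensor(b'/data/sample_000.h5', shape=(), dtype=string)
    print(path)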

I wrote the load_file method with a wrapper, because I cannot use tf.io.read_file / tf.io.decode_png the way one usually would for, e.g., images:

def load_file(file):
    # Read one HDF5 file eagerly with h5py
    hf = h5py.File(file.numpy(), 'r')
    epsilon = np.array(hf.get('epsilon'))  # array of shape (128, 128, 1)
    field = np.array(hf.get('field'))      # array of shape (128, 128, 6)
    hf.close()
    return epsilon, field

def wrapper(file):
    e, f = tf.py_function(load_file, [file], (tf.float64, tf.float64))
    return e, f
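One detail worth noting (a minimal sketch, assuming the (128, 128, 1) and (128, 128, 6) shapes above): tf.py_function returns tensors with unknown static shape, so it can help to restore the shapes inside the wrapper so that Keras can infer the input shape later:

def wrapper(file):
    e, f = tf.py_function(load_file, [file], (tf.float64, tf.float64))
    # tf.py_function erases static shape information; set it explicitly.
    e.set_shape((128, 128, 1))
    f.set_shape((128, 128, 6))
    return e, f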

Iterating over my dataset gives:

for e, f in dataset.take(5):
    print(e[0][0][0], f[0][0][0])
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(-2.815057100053552e-33, shape=(), dtype=float64)
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(-2.5008625074043214e-33, shape=(), dtype=float64)
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(1.352042249055261e-33, shape=(), dtype=float64)
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(8.932832186890058e-34, shape=(), dtype=float64)
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(-1.0549327174460344e-33, shape=(), dtype=float64)
for e in dataset.take(5):
    print(e)
(<tf.Tensor: shape=(128, 128, 1), dtype=float64, numpy=
array([[[1.],....), <tf.Tensor: shape=(128, 128, 6), dtype=float64, numpy=array([[[-2.81505710e-33,...)

However, when I try to use it to train my autoencoder:

m.fit(dataset, epochs=args.ep, callbacks=[tboard_callback])

Autoencoder:

def getModel():
    model = Sequential()
    model.add(Conv2D(64, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same', input_shape=(128, 128, 1)))
    model.add(MaxPooling2D((2, 2), padding='same'))
    model.add(Conv2D(32, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same'))
    model.add(MaxPooling2D((2, 2), padding='same'))
    model.add(Conv2D(16, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same'))
    model.add(MaxPooling2D((2, 2), padding='same'))
    model.add(Conv2D(16, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(32, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation=keras.layers.LeakyReLU(alpha=0.01), padding='same'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(6, (3, 3), activation='linear', padding='same'))
    model.compile(optimizer='adam', loss='mean_squared_error')  # Mean squared error loss. Try other losses.
    model.summary()
    return model
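A quick sanity check (my own debugging sketch, not from the original post): printing dataset.element_spec before calling fit shows whether the elements already carry the leading batch dimension that Conv2D expects.

# Conv2D needs 4-D input (batch, height, width, channels);
# an unbatched dataset yields 3-D (128, 128, channels) elements instead.
print(dataset.element_spec)

In my case it did not, which lines up with the error below.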

I get the following error:

Epoch 1/3
2022-08-18 16:58:43.843892: W tensorflow/core/common_runtime/forward_type_inference.cc:231] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
type_id: TFT_PRODUCT
args {
type_id: TFT_TENSOR
args {
type_id: TFT_BOOL
}
}
}
is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
type_id: TFT_PRODUCT
args {
type_id: TFT_TENSOR
args {
type_id: TFT_LEGACY_VARIANT
}
}
}
while inferring type of node 'mean_squared_error/cond/output/_10'
2022-08-18 16:58:43.944734: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at conv_ops_fused_impl.h:679 : INVALID_ARGUMENT: input must be 4-dimensional[128,128,1]
Traceback (most recent call last):
File "/home/lukas/Documents/ba-lukas/modelmain.py", line 154, in <module>
main()
File "/home/lukas/Documents/ba-lukas/modelmain.py", line 116, in main
m.fit(dataset, epochs=args.ep,callbacks = [tboard_callback])
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node 'sequential/conv2d/BiasAdd' defined at (most recent call last):
File "/home/lukas/Documents/ba-lukas/modelmain.py", line 154, in <module>
main()
File "/home/lukas/Documents/ba-lukas/modelmain.py", line 116, in main
m.fit(dataset, epochs=args.ep,callbacks = [tboard_callback])
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 1409, in fit
tmp_logs = self.train_function(iterator)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 1051, in train_function
return step_function(self, iterator)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 1040, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 1030, in run_step
outputs = model.train_step(data)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 889, in train_step
y_pred = self(x, training=True)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/training.py", line 490, in __call__
return super().__call__(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/sequential.py", line 374, in call
return super(Sequential, self).call(inputs, training=training, mask=mask)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/functional.py", line 458, in call
return self._run_internal_graph(
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
outputs = node.layer(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
return fn(*args, **kwargs)
File "/home/lukas/anaconda3/envs/nano/lib/python3.9/site-packages/keras/layers/convolutional/base_conv.py", line 269, in call
outputs = tf.nn.bias_add(
Node: 'sequential/conv2d/BiasAdd'
input must be 4-dimensional[128,128,1]
[[{{node sequential/conv2d/BiasAdd}}]] [Op:__inference_train_function_1448]

Where might the error be? Why does the input need to be 4-dimensional? Thanks. I got the idea from this tutorial: enter link description here

I found the solution. I had completely forgotten to call the batch() function on my dataset, so each element was a 3-D (128, 128, C) tensor instead of the 4-D (batch, 128, 128, C) input that Conv2D expects:

dataset = (dataset.map(pygen.wrapper, num_parallel_calls=AUTOTUNE)
           .cache()
           .batch(args.bs)
           .prefetch(AUTOTUNE))
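To confirm the fix, a minimal check (my sketch, assuming the batch size args.bs and the wrapper from above): iterating the batched dataset now yields elements with a leading batch dimension.

for e, f in dataset.take(1):
    # With batch() applied, shapes become (args.bs, 128, 128, 1) and (args.bs, 128, 128, 6)
    print(e.shape, f.shape)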

Now I can train my model. Only a warning remains:

W tensorflow/core/common_runtime/forward_type_inference.cc:231] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
type_id: TFT_PRODUCT
args {
type_id: TFT_TENSOR
args {
type_id: TFT_BOOL
}
}
}
is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
type_id: TFT_PRODUCT
args {
type_id: TFT_TENSOR
args {
type_id: TFT_LEGACY_VARIANT
}
}
}
while inferring type of node 'mean_squared_error/cond/output/_10
