How to build a loss function for regression when each label is an array of values (Graph execution error)



I want to use a pretrained classification model for regression.

from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model

base_model = InceptionV3(weights='imagenet')
x = base_model.output
x = Dense(1000, activation='relu')(x)
x = Dense(128, activation='relu')(x)
prediction = Dense(1, activation='linear')(x)
model = Model(inputs=base_model.input, outputs=prediction)
for layer in base_model.layers:
    layer.trainable = True

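The input is an image and the output is an array like (125.258155.2163). I don't know which loss function to use. I used "mse", and when I trained the model I got this error: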

Epoch 1/25
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-26-a0aaf400ecb5> in <module>()
5     callbacks=callbacks,
6     validation_data=valid_dataloader,
----> 7     validation_steps=len(valid_dataloader),
8 
9 )
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53     ctx.ensure_initialized()
54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
56   except core._NotOkStatusException as e:
57     if name is not None:
InvalidArgumentError: Graph execution error:
Detected at node 'mean_squared_error/remove_squeezable_dimensions/Squeeze' defined at (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
app.start()
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
self.io_loop.start()
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
self._run_once()
File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
handle._run()
File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.7/dist-packages/tornado/ioloop.py", line 758, in _run_callback
ret = callback()
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 536, in <lambda>
self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
self._handle_recv()
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
self._run_callback(callback, msg)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
callback(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
handler(stream, idents, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-26-a0aaf400ecb5>", line 7, in <module>
validation_steps=len(valid_dataloader),
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
tmp_logs = self.train_function(iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
return step_function(self, iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step
outputs = model.train_step(data)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in train_step
loss = self.compute_loss(x, y, y_pred, sample_weight)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 919, in compute_loss
y, y_pred, sample_weight, regularization_losses=self.losses)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 201, in __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 141, in __call__
losses = call_fn(y_true, y_pred)
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 242, in call
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
File "/usr/local/lib/python3.7/dist-packages/keras/utils/losses_utils.py", line 188, in squeeze_or_expand_dimensions
y_true, y_pred)
File "/usr/local/lib/python3.7/dist-packages/keras/utils/losses_utils.py", line 130, in remove_squeezable_dimensions
labels = tf.squeeze(labels, [-1])
Node: 'mean_squared_error/remove_squeezable_dimensions/Squeeze'
Can not squeeze dim[2], expected a dimension of 1, got 2
[[{{node mean_squared_error/remove_squeezable_dimensions/Squeeze}}]] [Op:__inference_train_function_79237]
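The last frames of the traceback point at `remove_squeezable_dimensions`, where Keras tries to drop a trailing size-1 axis from the labels so they line up with the predictions. The NumPy sketch below mimics only that check. It assumes (an inference from the error message, not something the question states) that each batch of labels arrives with shape `(batch, 1, 2)` while the `Dense(1)` head predicts `(batch, 1)`. `squeeze_labels_like_keras` is a hypothetical, simplified helper, not a Keras function.

```python
import numpy as np


def squeeze_labels_like_keras(labels, predictions):
    """Simplified mimic of the label branch of Keras' remove_squeezable_dimensions."""
    if labels.ndim - predictions.ndim == 1:
        # Labels carry one extra axis: Keras squeezes it, which only works if its size is 1.
        if labels.shape[-1] != 1:
            raise ValueError(
                f"Can not squeeze dim[{labels.ndim - 1}], "
                f"expected a dimension of 1, got {labels.shape[-1]}"
            )
        labels = np.squeeze(labels, axis=-1)
    return labels


batch = 4
rng = np.random.default_rng(0)
labels = rng.random((batch, 1, 2))      # assumed label layout: (batch, 1, 2)
preds_one_unit = np.zeros((batch, 1))   # shape produced by a Dense(1) head

message = None
try:
    squeeze_labels_like_keras(labels, preds_one_unit)
except ValueError as err:
    message = str(err)
print(message)  # Can not squeeze dim[2], expected a dimension of 1, got 2

# Possible fix: flatten the labels to (batch, 2) and give the model two output units.
labels_2d = labels.reshape(batch, 2)
preds_two_units = np.zeros((batch, 2))  # shape produced by a Dense(2) head
aligned = squeeze_labels_like_keras(labels_2d, preds_two_units)
per_sample_mse = np.mean((aligned - preds_two_units) ** 2, axis=-1)
print(per_sample_mse.shape)  # (4,)
```

In the real pipeline, the equivalent change would be a final `Dense(2, activation='linear')` layer plus a dataloader that yields labels shaped `(batch, 2)` (for example via `np.reshape` or `tf.reshape`).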

I'm a beginner with neural networks. I'm working in Colab; here is the link to my code: https://colab.research.google.com/drive/1DDUSKyiQvQ3VCV9cMR74EE5ar7x_mCBe?usp=sharing
Thank you very much for any advice.
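In the code above, `base_model.output` is the 1000-way softmax of the ImageNet classifier, because `InceptionV3` defaults to `include_top=True`. A common way to reuse a pretrained classifier for regression is to drop that head (`include_top=False`, with global average pooling) and end with a linear `Dense` layer that has one unit per target value. The sketch below assumes each label is a pair of numbers fed as shape `(batch, 2)`. `build_coordinate_regressor` is a hypothetical helper name, and the 299x299 input size is InceptionV3's default, not something the question specifies.

```python
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model


def build_coordinate_regressor(num_targets=2, weights="imagenet", input_shape=(299, 299, 3)):
    """Hypothetical helper: pretrained InceptionV3 backbone with a regression head."""
    # include_top=False removes the softmax classifier; pooling="avg" yields (batch, 2048) features.
    base_model = InceptionV3(weights=weights, include_top=False,
                             input_shape=input_shape, pooling="avg")
    x = Dense(1000, activation="relu")(base_model.output)
    x = Dense(128, activation="relu")(x)
    # One linear unit per regression target, so predictions have shape (batch, num_targets).
    prediction = Dense(num_targets, activation="linear")(x)
    model = Model(inputs=base_model.input, outputs=prediction)
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    return model

# Usage (downloads the ImageNet weights): model = build_coordinate_regressor()
```

With this head, plain `"mse"` works once the labels are a float array of shape `(N, 2)`. Images should also go through `tensorflow.keras.applications.inception_v3.preprocess_input`, which is the preprocessing InceptionV3's pretrained weights expect.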

When you have a multi-output regression, you can use a separate network for each output:

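import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                                     Dropout, Flatten, Input, MaxPooling2D)
from tensorflow.keras.models import Model


class MultiOutputModel():
    def make_default_hidden_layers(self, inputs):
        # Convolutional feature extractor; each call creates new layers,
        # so every branch gets its own copy.
        x = Conv2D(16, (3, 3), padding="same")(inputs)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=-1)(x)
        x = MaxPooling2D(pool_size=(3, 3))(x)
        x = Dropout(0.25)(x)
        x = Conv2D(32, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=-1)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)
        x = Conv2D(64, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=-1)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)
        return x

    def build_X_coordinate(self, inputs):
        # Regression head for the x coordinate: a single linear output unit.
        x = self.make_default_hidden_layers(inputs)
        x = Flatten()(x)
        x = Dense(100)(x)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)
        x = Dropout(0.5)(x)
        x = Dense(1)(x)
        x = Activation("linear", name="X_coordinate")(x)
        return x

    def build_Y_coordinate(self, inputs):
        # Regression head for the y coordinate: a single linear output unit.
        x = self.make_default_hidden_layers(inputs)
        x = Flatten()(x)
        x = Dense(100)(x)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)
        x = Dropout(0.5)(x)
        x = Dense(1)(x)
        x = Activation("linear", name="Y_coordinate")(x)
        return x

    def assemble_full_model(self, width, height):
        input_shape = (height, width, 3)
        inputs = Input(shape=input_shape)
        X_branch = self.build_X_coordinate(inputs)
        Y_branch = self.build_Y_coordinate(inputs)
        model = Model(inputs=inputs, outputs=[X_branch, Y_branch])
        return model


# Build the model before compiling; width/height must match your images (299x299 is only an example).
model = MultiOutputModel().assemble_full_model(width=299, height=299)

init_lr = 1e-4
# Newer Keras optimizers reject the legacy `decay` argument; this schedule applies the
# same formula: lr = init_lr / (1 + (init_lr / 25) * step).
lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=init_lr, decay_steps=1, decay_rate=init_lr / 25)
optim = keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optim,
              loss={'X_coordinate': tf.keras.losses.MeanSquaredLogarithmicError(),
                    'Y_coordinate': tf.keras.losses.MeanSquaredLogarithmicError()},
              metrics={'X_coordinate': 'mae', 'Y_coordinate': 'mae'}
              )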
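With two named outputs, `fit` needs one target per output, keyed by the layer names used in `compile` (`X_coordinate` and `Y_coordinate`). The NumPy sketch below assumes the labels are stored as an `(N, 2)` array of `(x, y)` pairs. `split_coordinate_labels` and `msle` are hypothetical helpers, and the label and prediction values are made up for illustration. The second part reproduces how the reported loss is formed: the sum of the per-output losses, each multiplied by its `loss_weights` entry (1 by default), here using the mean squared logarithmic error chosen in the compile call.

```python
import numpy as np

EPSILON = 1e-7  # Keras' default backend epsilon, used to clip values before the log


def split_coordinate_labels(labels):
    """Turn an (N, 2) array of (x, y) pairs into one (N, 1) target per named output."""
    labels = np.asarray(labels, dtype=np.float32).reshape(-1, 2)
    return {"X_coordinate": labels[:, 0:1], "Y_coordinate": labels[:, 1:2]}


def msle(y_true, y_pred):
    """Mean squared logarithmic error, averaged over the last axis and then over the batch."""
    first_log = np.log(np.maximum(y_pred, EPSILON) + 1.0)
    second_log = np.log(np.maximum(y_true, EPSILON) + 1.0)
    return float(np.mean(np.mean((first_log - second_log) ** 2, axis=-1)))


labels = np.array([[120.0, 60.0], [10.0, 20.0]])  # made-up (x, y) pairs
targets = split_coordinate_labels(labels)
predictions = {"X_coordinate": np.array([[118.0], [12.0]]),
               "Y_coordinate": np.array([[55.0], [18.0]])}  # stand-in model outputs

# With the default loss_weights of 1, the total loss is the sum of the per-output losses.
total_loss = sum(msle(targets[name], predictions[name]) for name in targets)
print(targets["X_coordinate"].shape, round(total_loss, 4))

# With a real model and image array: model.fit(images, targets, epochs=..., ...)
```

If the data comes from a generator or `tf.data` pipeline instead, each batch would need to yield `(images, {'X_coordinate': ..., 'Y_coordinate': ...})` in the same layout.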
