亲爱的stackoverflow成员,
我目前正在尝试实现我自己的keras调谐器训练循环。在这个循环中,我想通过模型多次传递输入变量,例如:
Y = Startvalue
for i in range(x):
Y = model(Y)
我想看看这个方法是否为我的自我反馈问题创造了更稳定的模拟。当我实现它,我得到一个OOM错误,即使我不循环。当我正常操作时,不会出现此错误。我的Class示例(当我将logits切换为logits2时发生OOM错误:
)class MyTuner(kt.Tuner):
def run_trial(self, trial, train_ds, validation_data):
model = self.hypermodel.build(trial.hyperparameters)
optimizer = tf.keras.optimizers.Adam()
epoch_loss_metric = tf.keras.metrics.MeanSquaredError()
def microbatch(T_IN, A_IN, D_IN):
OUT_T = []
OUT_A = []
for i in range(len(T_IN)):
A_IN_R = tf.expand_dims(tf.squeeze(A_IN[i]), 0)
T_IN_R = tf.expand_dims(tf.squeeze(T_IN[i]), 0)
D_IN_R = tf.expand_dims(tf.squeeze(D_IN[i]), 0)
(OUT_T_R, OUT_A_R) = model((A_IN_R, T_IN_R, D_IN_R))
OUT_T.append(tf.squeeze(OUT_T_R))
OUT_A.append(tf.squeeze(OUT_A_R))
return(tf.squeeze(tf.stack(OUT_T)), tf.squeeze(tf.stack(OUT_A)))
def run_train_step(data):
T_IN = tf.dtypes.cast(data[0][0], 'float32')
A_IN = tf.dtypes.cast(data[0][1], 'float32')
D_IN = tf.dtypes.cast(data[0][2], 'float32')
A_Ta = tf.dtypes.cast(data[1][0], 'float32')
T_Ta = tf.dtypes.cast(data[1][1], 'float32')
mse = tf.keras.losses.MeanSquaredError()
with tf.GradientTape() as tape:
logits2 = microbatch(T_IN, A_IN, D_IN)
logits = model([A_IN, T_IN, D_IN])
loss = mse((T_Ta, A_Ta), logits2)
# Add any regularization losses.
if model.losses:
loss += tf.math.add_n(model.losses)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
epoch_loss_metric.update_state((T_Ta, A_Ta), logits2)
return loss
for epoch in range(1000):
print('Epoch: {}'.format(epoch))
self.on_epoch_begin(trial, model, epoch, logs={})
for batch, data in enumerate(train_ds):
self.on_batch_begin(trial, model, batch, logs={})
batch_loss = float(run_train_step(data))
self.on_batch_end(trial, model, batch, logs={'loss': batch_loss})
if batch % 100 == 0:
loss = epoch_loss_metric.result().numpy()
print('Batch: {}, Average Loss: {}'.format(batch, loss))
epoch_loss = epoch_loss_metric.result().numpy()
self.on_epoch_end(trial, model, epoch, logs={'loss': epoch_loss})
epoch_loss_metric.reset_states()
````
在我的理解中,微批函数没有实现自反馈循环(尽管它不影响OOM)
我猜发生的事情是,因为你计算网络的输出k次,网络消耗的内存量增加了k倍(因为它需要为backprop存储中间张量)。
你能做的是,在每个自反馈实例中,你反向支持梯度,这样所有的中间张量不会超过极限。
如果你有任何疑问,请告诉我,