使用可变批大小训练 TensorFlow 模型



是否可以使用 TensorFlow/Keras 训练批大小可变的模型?即每个 epoch(轮次)中的各个批次大小都不相同。我猜使用高级 tf.keras API 是做不到的,对吗?

您可以从头编写自定义训练循环(借助 tf.GradientTape)。由于手写循环不会固定批次维度,每一步都可以传入任意大小的批次。

下面是一个简单的例子:

import tensorflow as tf
# Toy model: a single sigmoid unit on one input feature.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=(1,), activation="sigmoid")
])

# Random (features, targets) batches of deliberately different sizes: 32, 7, 1.
# Features are shaped (batch, 1); targets are shaped (batch,).
data = [
    (tf.random.normal((32, 1)), tf.random.normal((32,))),
    (tf.random.normal((7, 1)), tf.random.normal((7,))),
    (tf.random.normal((1, 1)), tf.random.normal((1,))),
]

loss_func = tf.losses.MSE
opt = tf.optimizers.SGD()

# A hand-written loop never pins the batch dimension, so every step may use a
# batch of a different size.
for x, y in data:
    print(f"Batch size: {x.shape[0]}")
    with tf.GradientTape() as tape:
        # The model emits (batch, 1) while y is (batch,). In tf.keras (Keras 2)
        # MSE does not reconcile the ranks: the two would broadcast to
        # (batch, batch) and every prediction would be compared with every
        # target. Drop the trailing axis so the shapes match element-wise.
        pred = tf.squeeze(model(x, training=True), axis=-1)
        # Reduce explicitly to a scalar batch-mean loss.
        loss = tf.reduce_mean(loss_func(y, pred))
    grad = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grad, model.trainable_variables))

另一种做法是逐批调用 `model.train_on_batch`,每次调用可以传入不同大小的批次:

import tensorflow as tf
import random
import numpy as np
from sklearn.datasets import load_iris
# Load Iris and shuffle samples and labels together. The raw dataset is
# ordered by class, so unshuffled contiguous slices would be single-class.
data, target = load_iris(return_X_y=True)
rng = np.random.default_rng(0)  # fixed seed so runs are reproducible
perm = rng.permutation(len(data))
data, target = data[perm], target[perm]

# Split into batches of deliberately different sizes: 13, 31, 33, 44 and the
# remainder (29 for the 150-sample Iris set).
boundaries = [0, 13, 44, 77, 121, len(data)]
ds = [
    [data[start:stop], target[start:stop]]
    for start, stop in zip(boundaries, boundaries[1:])
]

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,), activation="relu"),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax")
])
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

for i in range(100):
    # train_on_batch resets metrics on every call by default, so each call
    # reports only that batch's loss/accuracy. Weight them by batch size to
    # report whole-epoch figures instead of the last batch's numbers.
    total_loss = 0.0
    total_acc = 0.0
    n_seen = 0
    for x, y in ds:
        loss, acc = model.train_on_batch(x, y)
        total_loss += loss * len(x)
        total_acc += acc * len(x)
        n_seen += len(x)
    if not i % 10:
        print(f'epoch {i:2d} loss {total_loss / n_seen:5.3f}, '
              f'acc {total_acc / n_seen:6.2%}')
输出示例(具体数值取决于数据的打乱方式,仅供参考):
epoch  0 loss 1.144, acc 37.93%
epoch 10 loss 0.960, acc 37.93%
epoch 20 loss 0.831, acc 65.52%
epoch 30 loss 0.712, acc 65.52%
epoch 40 loss 0.606, acc 65.52%
epoch 50 loss 0.522, acc 68.97%
epoch 60 loss 0.440, acc 86.21%
epoch 70 loss 0.371, acc 93.10%
epoch 80 loss 0.317, acc 96.55%
epoch 90 loss 0.273, acc 96.55%

相关内容

  • 没有找到相关文章

最新更新