类型错误: fit_generator() 缺少 1 个必需的位置参数:'generator'



我正试图训练CNN来检测图像是否深度伪造,但在运行代码时,我一直得到这个错误:TypeError: fit_generator()缺少1个需要的位置参数:'generator'我如何摆脱这个错误?我的代码有问题吗?我也不确定分类器类是否必要,所以我已经包含了它,但注释掉了。

我的代码全文:

import tensorflow as tf 
config = tf.compat.v1.ConfigProto() 
config.gpu_options.allow_growth = True 
sess = tf.compat.v1.Session(config=config)
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, Concatenate, LeakyReLU
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.models import Model
# Height and width refer to the size of the image
# Channels refers to the amount of color channels (red, green, blue)
image_dimensions = {'height':256, 'width':256, 'channels':3}
# Create a Classifier class
#class Classifier():

# def __init__():
# self.model = 0

#def predict(self, x):
# return self.model.predict(x)

# def fit(self, x, y):
#return self.model.train_on_batch(x, y)

# def get_accuracy(self, x, y):
#return self.model.test_on_batch(x, y)

#def load(self, path):
# self.model.load_weights(path)
class Meso4(Model):
def __init__(self, learning_rate = 0.0001):
self.model = self.init_model()
optimizer = Adam(lr = learning_rate)
self.model.compile(optimizer = optimizer,
loss = 'mean_squared_error',
metrics = ['accuracy'])   



def init_model(self):
x = Input(shape = (image_dimensions['height'],
image_dimensions['width'],
image_dimensions['channels']))

x1 = Conv2D(8, (3, 3), padding='same', activation = 'relu')(x)
x1 = BatchNormalization()(x1)
x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)

x2 = Conv2D(8, (5, 5), padding='same', activation = 'relu')(x1)
x2 = BatchNormalization()(x2)
x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)

x3 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x2)
x3 = BatchNormalization()(x3)
x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)

x4 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x3)
x4 = BatchNormalization()(x4)
x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)

y = Flatten()(x4)
y = Dropout(0.5)
y = Dense(16)
y = LeakyReLU(alpha=0.1)
y = Dropout(0.5)
y = Dense(1, activation = 'sigmoid')

return Model(inputs = x, outputs = y)
bat_size = 64
input_size = 256
# initializing a train datagenerator
train_datagen = ImageDataGenerator(rescale=1./255)
# initializing a test datagenerator
test_datagen = ImageDataGenerator(rescale=1./255)
# preprocessing for trainig set
train_set = train_datagen.flow_from_directory(
'C:\Users\Kevin\Desktop\Train', # train data directory
target_size=(input_size, input_size), 
batch_size=bat_size,
class_mode='categorical',
color_mode= 'rgb'
)
# preprocessing for test set
test_set = test_datagen.flow_from_directory(
'C:\Users\Kevin\Desktop\Test', # test data directory
target_size=(input_size, input_size),
batch_size=bat_size,
shuffle=False,
class_mode='categorical',
color_mode= 'rgb'
)
filepath = "FYP.hdf5"
checkpoint = ModelCheckpoint(
filepath,
monitor='val_acc',
verbose=1,
save_best_only=True,
mode='max'
)
Meso4.fit_generator(
train_set,
steps_per_epoch=1400//bat_size + 1,
epochs=25,
callbacks=[checkpoint],
validation_data=test_set,
validation_steps=600 //bat_size + 1
)
#ERROR
TypeError                                 Traceback (most recent call last)
<ipython-input-9-00d0b295f968> in <module>
5                                 callbacks=[checkpoint],
6                                 validation_data=test_set,
----> 7                                 validation_steps=600 //bat_size + 1
8                                 )
~Anaconda3envsTflibsite-packageskeraslegacyinterfaces.py in wrapper(*args, **kwargs)
89                 warnings.warn('Update your `' + object_name + '` call to the ' +
90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
92         wrapper._original_function = func
93         return wrapper
TypeError: fit_generator() missing 1 required positional argument: 'generator'

我可以看到3或4个错误:

对于keras中的子类:

  • 您需要呼叫super(YourClass, self).__init__()
  • call方法中定义模型

查看此链接以了解有关keras子类化的更多信息

在你的y部分停止使用函数式语法

y = Flatten()(x4)
y = Dropout(0.5)
y = Dense(16)

应该是

y = Dropout(0.5)(y)
y = Dense(16)(y)

你不直接调用class只是实例化一个新对象

最新更新