我正试图训练CNN来检测图像是否深度伪造,但在运行代码时,我一直得到这个错误:TypeError: fit_generator()缺少1个需要的位置参数:'generator'我如何摆脱这个错误?我的代码有问题吗?我也不确定分类器类是否必要,所以我已经包含了它,但注释掉了。
我的代码全文:
import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, Concatenate, LeakyReLU
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.models import Model
# Height and width refer to the size of the image
# Channels refers to the amount of color channels (red, green, blue)
image_dimensions = {'height':256, 'width':256, 'channels':3}
# Create a Classifier class
#class Classifier():
# def __init__():
# self.model = 0
#def predict(self, x):
# return self.model.predict(x)
# def fit(self, x, y):
#return self.model.train_on_batch(x, y)
# def get_accuracy(self, x, y):
#return self.model.test_on_batch(x, y)
#def load(self, path):
# self.model.load_weights(path)
class Meso4(Model):
def __init__(self, learning_rate = 0.0001):
self.model = self.init_model()
optimizer = Adam(lr = learning_rate)
self.model.compile(optimizer = optimizer,
loss = 'mean_squared_error',
metrics = ['accuracy'])
def init_model(self):
x = Input(shape = (image_dimensions['height'],
image_dimensions['width'],
image_dimensions['channels']))
x1 = Conv2D(8, (3, 3), padding='same', activation = 'relu')(x)
x1 = BatchNormalization()(x1)
x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)
x2 = Conv2D(8, (5, 5), padding='same', activation = 'relu')(x1)
x2 = BatchNormalization()(x2)
x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)
x3 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x2)
x3 = BatchNormalization()(x3)
x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)
x4 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x3)
x4 = BatchNormalization()(x4)
x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)
y = Flatten()(x4)
y = Dropout(0.5)
y = Dense(16)
y = LeakyReLU(alpha=0.1)
y = Dropout(0.5)
y = Dense(1, activation = 'sigmoid')
return Model(inputs = x, outputs = y)
bat_size = 64
input_size = 256
# initializing a train datagenerator
train_datagen = ImageDataGenerator(rescale=1./255)
# initializing a test datagenerator
test_datagen = ImageDataGenerator(rescale=1./255)
# preprocessing for trainig set
train_set = train_datagen.flow_from_directory(
'C:\Users\Kevin\Desktop\Train', # train data directory
target_size=(input_size, input_size),
batch_size=bat_size,
class_mode='categorical',
color_mode= 'rgb'
)
# preprocessing for test set
test_set = test_datagen.flow_from_directory(
'C:\Users\Kevin\Desktop\Test', # test data directory
target_size=(input_size, input_size),
batch_size=bat_size,
shuffle=False,
class_mode='categorical',
color_mode= 'rgb'
)
filepath = "FYP.hdf5"
checkpoint = ModelCheckpoint(
filepath,
monitor='val_acc',
verbose=1,
save_best_only=True,
mode='max'
)
Meso4.fit_generator(
train_set,
steps_per_epoch=1400//bat_size + 1,
epochs=25,
callbacks=[checkpoint],
validation_data=test_set,
validation_steps=600 //bat_size + 1
)
#ERROR
TypeError Traceback (most recent call last)
<ipython-input-9-00d0b295f968> in <module>
5 callbacks=[checkpoint],
6 validation_data=test_set,
----> 7 validation_steps=600 //bat_size + 1
8 )
~Anaconda3envsTflibsite-packageskeraslegacyinterfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
TypeError: fit_generator() missing 1 required positional argument: 'generator'
我可以看到3或4个错误:
对于keras中的子类:
- 您需要呼叫
super(YourClass, self).__init__()
- 在
call
方法中定义模型
查看此链接以了解有关keras子类化的更多信息
在你的y部分停止使用函数式语法
y = Flatten()(x4)
y = Dropout(0.5)
y = Dense(16)
应该是
y = Dropout(0.5)(y)
y = Dense(16)(y)
你不直接调用class只是实例化一个新对象