from tensorflow.keras.layers import ReLU, MaxPooling2D, Input, Dense, Conv2D, Flatten
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback
import tensorflow as tf
import numpy as np
class MyLayer(Layer):
    """Composite layer: Conv2D, then MaxPooling2D (default settings), then ReLU.

    Args:
        filters: number of output channels of the convolution. Defaults to
            128, the value that was previously hard-coded.
        kernel_size: convolution window size. Defaults to ``(2, 2)``.
        **kwargs: forwarded unchanged to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, filters=128, kernel_size=(2, 2), **kwargs):
        super(MyLayer, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        # Sub-layers are created here, as recommended for composite Keras
        # layers. Their weights are still created lazily on first call, so no
        # input shape is needed. (The previous code passed the batch-inclusive
        # build() shape as Conv2D's `input_shape`, which expects no batch dim.)
        self.conv = Conv2D(filters=filters, kernel_size=kernel_size)
        self.m_max = MaxPooling2D()
        self.relu = ReLU()

    def call(self, inputs, **kwargs):
        """Apply convolution, max-pooling and ReLU to ``inputs`` in that order."""
        x = self.conv(inputs)
        x = self.m_max(x)
        return self.relu(x)

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super(MyLayer, self).get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config
class ModelCallback(Callback):
    """Training callback passed to ``model.fit``; currently performs no work."""

    def on_batch_end(self, batch, logs=None):
        """Hook invoked by Keras after each training batch.

        Args:
            batch: batch index supplied by Keras.
            logs: per-batch metrics supplied by Keras, or None.

        Intentionally a no-op for now; put per-batch logic here.
        """
# Toy end-to-end run: 32x32x3 inputs through MyLayer, flattened to a single
# output unit, trained on random images with random 0/1 labels.
inp = Input((32, 32, 3))
x = MyLayer()(inp)
x = Flatten()(x)
# No activation: the model outputs a raw logit (hence from_logits=True below).
out = Dense(1)(x)
model = Model(inputs=inp, outputs=out)
# Labels are 0/1 with one output unit, i.e. binary classification.
# 'categorical_crossentropy' on a single unit normalises the prediction to 1,
# giving a constant loss with zero gradient, so the model would never learn.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
x_train = np.random.rand(5000, 32, 32, 3)
y_train = np.random.randint(2, size=(5000, 1))
# `callbacks` is documented as a list of Callback instances.
model.fit(x_train, y_train, epochs=5, callbacks=[ModelCallback()])
def _calculate_shape(self, input_tensor_shape: tf.TensorShape):
    """Record each sub-layer's (input, output) shape by running a dummy batch.

    Builds a random NumPy batch whose shape is ``input_tensor_shape`` with the
    first (batch) entry replaced by ``self.batch_size``. It pushes that batch
    through ``self.conv``, ``self.m_max`` and ``self.relu`` in order. For each
    stage it stores a pair ``(input_shape, output_shape)``, with the batch
    dimension removed, in ``self.conv_shapes``, ``self.max_shapes`` and
    ``self.relu_shapes`` respectively.

    The sub-layers are marked non-trainable while probing. Their previous
    ``trainable`` flags are restored afterwards, even if the forward pass
    raises. They are not forced to True, which would un-freeze layers the
    caller had frozen.

    Args:
        self: object exposing callable ``conv``, ``m_max`` and ``relu`` layers
            and an integer ``batch_size`` (NOTE(review): ``batch_size`` is not
            set anywhere visible — confirm the owner defines it).
        input_tensor_shape: input shape with the batch dimension first; all
            non-batch dimensions must be known integers.
    """
    layers = (self.conv, self.m_max, self.relu)
    previous_trainable = [layer.trainable for layer in layers]
    for layer in layers:
        layer.trainable = False
    try:
        input_shape = list(input_tensor_shape)
        input_shape[0] = self.batch_size
        x = self.conv(np.random.rand(*input_shape))
        # [1:] removes the batch size from each recorded shape.
        self.conv_shapes = (input_shape[1:], tf.shape(x).numpy().tolist()[1:])
        x = self.m_max(x)
        self.max_shapes = (self.conv_shapes[1], tf.shape(x).numpy().tolist()[1:])
        x = self.relu(x)
        self.relu_shapes = (self.max_shapes[1], tf.shape(x).numpy().tolist()[1:])
    finally:
        for layer, was_trainable in zip(layers, previous_trainable):
            layer.trainable = was_trainable