在具有TF_HUB的keras中实施BERT



我试图使用tensorflow hub在tensorflow-keras中实现Google Bert模型。为此,我设计了一个自定义的keras层"Bertlayer"。现在的问题是,当我编译 keras 模型时,它一直显示

属性错误:"Bertlayer"对象没有属性"_keras_style">

我不知道我错在哪里,_keras_style属性是什么。请帮助查找代码中的错误。

这是完整代码的github链接:https://github.com/PradyumnaGupta/BERT/blob/master/Untitled21.ipynb

class BertLayer(tf.layers.Layer):
    def __init__(self, n_fine_tune_layers=10, **kwargs):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768
        super(BertLayer, self).__init__(**kwargs)
    def build(self, input_shape):
        self.bert = hub.Module(
            bert_path,
            trainable=self.trainable,
            name="{}_module".format(self.name)
        )
        trainable_vars = self.bert.variables
        # Remove unused layers
        trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name]
        # Select how many layers to fine tune
        trainable_vars = trainable_vars[-self.n_fine_tune_layers :]
        # Add to trainable weights
        for var in trainable_vars:
            self._trainable_weights.append(var)
        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)
        super(BertLayer, self).build(input_shape)
    def call(self, inputs):
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "pooled_output"
        ]
        return result
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_size)

所以,tensorflow version 1.* 有点误导。它实际上有 2 个称为 Layer 的基类。一个 - 您正在使用的那个。它旨在通过常规 TF 操作实现快捷方式包装器。另一个from tensorflow.keras.layers import Layer是类似Keras的模型和续集。

从您的错误来看,您正在使用 keras/模型来进一步训练。

您可能应该从keras.layers.Layer而不是tf.layers.Layer开始对图层进行剥离。

最新更新