当试图在tensorflow中设置权重时,似乎会生成随机权重



我试图编写代码来手动设置keras网络中的权重,但当我构建网络时,似乎有额外的权重。

>>> import numpy as np
>>> import tensorflow as tf
>>> from tensorflow import keras as ke
>>> from tensorflow.keras import layers
>>> lay=layers.Dense(1,activation="relu")
>>> lay.add_weight(shape=(1,),)
<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([-0.05657911], dtype=float32)>
>>> lay.set_weights(np.array([[0.5]]))
>>> lay.get_weights()
[array([0.5], dtype=float32)] # EXPEXCTD
>>> net=ke.Sequential([ke.Input(shape=(1,)),lay])
>>> net.get_weights()
[array([0.5], dtype=float32), array([[1.4100171]], dtype=float32), array([0.], dtype=float32)] # ACTUAL
>>> net.get_layer(index=0).get_weights()
[array([0.5], dtype=float32), array([[1.4100171]], dtype=float32), array([0.], dtype=float32)] # ACTUAL

正如您所看到的,在我构建网络时,它额外创建了2个numpy数组。为什么会这样?这些额外的重量是干什么的?他们是偏见吗?为什么一个只有2个神经元的网络会有3个不同的权重?我应该如何设置它们?

编辑:

有人建议先建立网络,然后设置权重。这也不起作用:

>>> import numpy as np
>>> import tensorflow as tf
>>> from tensorflow import keras as ke
>>> from tensorflow.keras import layers
>>> key=layers.Dense(1,activation="sigmoid")
>>> net=ke.Sequential([ke.Input(shape=(1,)),key])
>>> key.set_weights(np.array([[0.5]]))
Traceback (most recent call last):
File "<pyshell#44>", line 1, in <module>
key.set_weights(np.array([[0.5]]))
File "D:UsersStudentAppDataLocalProgramsPythonPython38libsite-packageskerasenginebase_layer.py", line 1832, in set_weights
raise ValueError(
ValueError: You called `set_weights(weights)` on layer "dense" with a weight list of length 1, but the layer was expecting 2 weights. Provided weights: [[0.5]]...

目前还不清楚阵列的形状

解决方案是在创建网络后设置权重,但必须在不添加权重的情况下设置权重,并且必须以包含二维numpy数组和一维numpy数组的列表形式给出权重。即:

>>> import numpy as np
>>> import tensorflow as tf
>>> from tensorflow import keras as ke
>>> from tensorflow.keras import layers
>>> key=layers.Dense(1,activation="sigmoid")
>>> net=ke.Sequential([ke.Input(shape=(1,)),key])
>>> key.set_weights([np.array([[0.5]]), np.array([0.])])
>>> key.get_weights()
[array([[0.5]], dtype=float32), array([0.], dtype=float32)]

最新更新