依赖于Keras中其他输出节点的自定义激活函数



我想使用长短期记忆(LSTM)网络预测一个多维数组,同时对感兴趣的表面的形状施加限制。

我想通过设置输出的一些元素(表面的区域)与其他元素(简单的缩放条件)的函数关系来实现这一点。

是否有可能在Keras中为输出设置这样的自定义激活函数,其参数是其他输出节点?如果没有,是否有其他接口允许这样做?你有手册的资料吗?

GitHub上的keras-team回答了关于如何创建自定义激活函数的问题。

还有一个带有自定义激活函数的代码问题。

这些页面可能对你有帮助!


附加评论

这些页面还不够这个问题所以我添加下面的评论;

也许PyTorch比Keras更适合自定义。我试着写这样一个网络,虽然它是一个非常简单的,基于PyTorch教程和"扩展PyTorch自定义激活函数">

我创建了一个自定义激活函数,其中输出向量的第1个元素(从0开始计数)等于第0个元素的两倍。我们使用了一个非常简单的单层网络进行训练。训练结束后,我检查条件是否满足。


import torch
import matplotlib.pyplot as plt
# Define the custom activation function
# reference: https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
def silu(input):
input[:,1] = input[:,0] * 2
return input 
class SiLU(torch.nn.Module):
def __init__(self):
super().__init__() # init the base class
def forward(self, input):
return silu(input) # simply apply already implemented SiLU

# Training
# reference: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
k = 10
x = torch.rand([k,3])
y = x * 2
model = torch.nn.Sequential(
torch.nn.Linear(3, 3),
SiLU()  # custom activation function
)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-3
for t in range(2000):
y_pred = model(x)
loss = loss_fn(y_pred, y)
if t % 100 == 99:
print(t, loss.item())
model.zero_grad()
loss.backward()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
# check the behaviour
yy = model(x)  # predicted
print('ground truth')
print(y)
print('predicted')
print(yy)

# examples for the first five data
colorlist = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']
plt.figure()
for i in range(5):
plt.plot(y[i,:].detach().numpy(), linestyle = "solid", label = "ground truth_" + str(i), color=colorlist[i])
plt.plot(yy[i,:].detach().numpy(), linestyle = "dotted", label = "predicted_" + str(i), color=colorlist[i])
plt.legend()
# check if the custom activation works correctly
plt.figure()
plt.plot(yy[:,0].detach().numpy()*2, label = '0th * 2')
plt.plot(yy[:,1].detach().numpy(), label = '1th')
plt.legend()
print(yy[:,0]*2)
print(yy[:,1])

最新更新