我是 Python 新手。我试图把参数传入我的类 `inverse_model`，为此我调用了函数 `get_models`。但它报错：`__init__()` 得到了一个意外的关键字参数 `'zz'`。
我很感激你的帮助。请参阅下面的代码：
def get_models(args):
zz=torch.tensor(args.chi_Initialize)
inverse_net = inverse_model(in_channels=len(args.chi),zz=zz,resolution_ratio=args.resolution_ratio,nonlinearity=args.nonlinearity)
return inverse_net
class inverse_model(nn.Module):
    """Module that stores the inversion-network configuration.

    Only the constructor is defined here: it records its arguments on the
    instance and picks the activation submodule.

    Args:
        in_channels: Number of input channels, stored as ``self.in_channels``.
        zz: Stored verbatim as ``self.zz``.
        resolution_ratio: Vertical scale mismatch between seismic and EI
            (default 6).
        nonlinearity: ``"relu"`` selects ``nn.ReLU``; every other value
            selects ``nn.Tanh``.
    """

    def __init__(self, in_channels, zz, resolution_ratio=6, nonlinearity="tanh"):
        super().__init__()
        self.in_channels = in_channels
        # NOTE(review): zz is a plain attribute, not a registered buffer or
        # Parameter, so Module.to()/state_dict() will not track it even when
        # it is a tensor -- confirm that is intended.
        self.zz = zz
        # Vertical scale mismatch between seismic and EI.
        self.resolution_ratio = resolution_ratio
        if nonlinearity == "relu":
            self.activation = nn.ReLU()
        else:
            # NOTE(review): any value other than exactly "relu" (e.g. a typo
            # or "ReLU") silently falls back to Tanh.
            self.activation = nn.Tanh()
Python 不允许在带默认值的参数之后再出现不带默认值的参数（例如 `def f(a=1, b):` 会直接引发 `SyntaxError`）。
请按如下方式修改您的构造函数：
def __init__(self, in_channels,
chi1,chi2,chi3,chi4,chi5,chi6,chi7,chi8,chi9,
chi10,chi11,chi12, resolution_ratio=6,nonlinearity="tanh"):
更新后的答案（针对问题的 Revision 1）：
import torch
from torch import nn
class inverse_model(nn.Module):
    """Configuration holder for the inverse network.

    The constructor stores ``in_channels``, ``zz`` and ``resolution_ratio``
    unchanged and registers an ``activation`` submodule: ``nn.ReLU`` when
    ``nonlinearity == "relu"``, otherwise ``nn.Tanh``.

    Args:
        in_channels: Number of input channels.
        zz: Arbitrary value kept as ``self.zz``.
        resolution_ratio: Vertical scale mismatch between seismic and EI.
        nonlinearity: Activation selector; only ``"relu"`` is special-cased.
    """

    def __init__(self, in_channels, zz, resolution_ratio=6, nonlinearity="tanh"):
        super().__init__()
        # Decide the activation type up front; it is instantiated last so the
        # submodule is registered after the plain attributes, as before.
        # NOTE(review): unrecognized selectors silently map to Tanh.
        activation_cls = nn.ReLU if nonlinearity == "relu" else nn.Tanh
        self.in_channels = in_channels
        # NOTE(review): stored as a plain attribute, so a tensor here is not
        # moved by Module.to() nor saved in state_dict() -- confirm intent.
        self.zz = zz
        self.resolution_ratio = resolution_ratio  # seismic vs. EI vertical scale mismatch
        self.activation = activation_cls()
def get_models(args):
    """Build an ``inverse_model`` from a parsed-arguments object.

    Args:
        args: Object exposing ``chi`` (its length becomes ``in_channels``),
            ``chi_Initialize`` (converted to a tensor and passed as ``zz``),
            ``resolution_ratio`` and ``nonlinearity``. Presumably an
            ``argparse.Namespace`` -- confirm with the caller.

    Returns:
        The constructed ``inverse_model`` instance.
    """
    init = args.chi_Initialize
    if isinstance(init, torch.Tensor):
        # torch.tensor(t) on an existing tensor makes a detached copy but emits
        # a UserWarning recommending clone().detach(); produce the same copy
        # without the warning.
        zz = init.detach().clone()
    else:
        zz = torch.tensor(init)
    return inverse_model(
        in_channels=len(args.chi),
        zz=zz,
        resolution_ratio=args.resolution_ratio,
        nonlinearity=args.nonlinearity,
    )
运行后返回 `exit 0` 作为退出状态。