How do I load multiple matrices into a Chainer model using the iterator interface?



I have a question about Chainer's iterator interface, and how it connects to the Trainer, Updater, and Model interfaces.

My data are graphs, so each sample has matrices of different shapes. I have concatenated the feature matrices into one big dense matrix, the adjacency matrices into one big sparse COO matrix, and the summation operators into another big sparse COO matrix. Because this is molecular data, every sample has both an atom graph and a bond graph, so the input data is a six-tuple which, for deep learning purposes, I treat as one giant training data point. (To keep my code simple, I am not doing a train/test split until I have this giant matrix working.)

xs = (atom_Fs, atom_As, atom_Ss, bond_Fs, bond_As, bond_Ss)
ts = data['target'].values
dataset = [(xs, ts)]
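
For context, here is a minimal sketch of how matrices like these could be assembled with NumPy/SciPy (the list names and toy shapes below are made up for illustration; they are not the actual concatenation code):

import numpy as np
import scipy.sparse as sps

# Two toy molecules with 2 and 3 atoms and 4 atom features each (dummy data).
atom_feature_list = [np.random.rand(2, 4), np.random.rand(3, 4)]
atom_adjacency_list = [np.eye(2), np.eye(3)]

# Stack the feature matrices row-wise into one big dense matrix ...
atom_Fs = np.vstack(atom_feature_list)                       # shape (5, 4)
# ... and put the per-molecule adjacency matrices on the block diagonal
# of one big sparse COO matrix.
atom_As = sps.block_diag(atom_adjacency_list, format='coo')  # shape (5, 5)
# One way to express a summation operator: an (n_atoms, n_graphs) indicator
# matrix that sums each molecule's atom rows during the gather step.
atom_Ss = sps.block_diag([np.ones((2, 1)), np.ones((3, 1))], format='coo')  # shape (5, 2)

The bond-side matrices (bond_Fs, bond_As, bond_Ss) would be built the same way from the bond graphs.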

My model's forward pass is as follows:

# model boilerplate above this comment
def forward(self, data):
    atom_feats, atom_adjs, atom_sums, bond_feats, bond_adjs, bond_sums = data
    atom_feats = self.atom_mp1(atom_feats, atom_adjs)
    atom_feats = self.atom_mp2(atom_feats, atom_adjs)
    atom_feats = self.atom_gather(atom_feats, atom_sums)
    bond_feats = self.atom_mp1(bond_feats, bond_adjs)
    bond_feats = self.atom_mp2(bond_feats, bond_adjs)
    bond_feats = self.atom_gather(bond_feats, bond_sums)
    feats = F.hstack([atom_feats, bond_feats])
    feats = F.tanh(self.dense1(feats))
    feats = F.tanh(self.dense2(feats))
    feats = self.dense3(feats)
    return feats

I then hand everything off to a trainer:

from chainer import iterators, training
from chainer.optimizers import SGD, Adam
iterator = iterators.SerialIterator(dataset, batch_size=1)
optimizer = Adam()
optimizer.setup(mpnn)
updater = training.updaters.StandardUpdater(iterator, optimizer)
max_epoch = 50
trainer = training.Trainer(updater, (max_epoch, 'epoch'))
trainer.run()

However, when I run the trainer, I get the following error:

Exception in main training loop: forward() takes 2 positional arguments but 3 were given
Traceback (most recent call last):
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py", line 315, in run
update()
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py", line 165, in update
self.update_core()
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py", line 177, in update_core
optimizer.update(loss_func, *in_arrays)
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/optimizer.py", line 680, in update
loss = lossfun(*args, **kwds)
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/link.py", line 242, in __call__
out = forward(*args, **kwargs)
Will finalize trainer extensions and updater before reraising the exception.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-45-ea26cece43b3> in <module>
9 max_epoch = 50
10 trainer = training.Trainer(updater, (max_epoch, 'epoch'))
---> 11 trainer.run()
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
327                 f.write('Will finalize trainer extensions and updater before '
328                         'reraising the exception.\n')
--> 329             six.reraise(*sys.exc_info())
330         finally:
331             for _, entry in extensions:
~/anaconda/envs/mpnn/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
691             if value.__traceback__ is not tb:
692                 raise value.with_traceback(tb)
--> 693             raise value
694         finally:
695             value = None
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
313                 self.observation = {}
314                 with reporter.scope(self.observation):
--> 315                     update()
316                     for name, entry in extensions:
317                         if entry.trigger(self):
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py in update(self)
163 
164         """
--> 165         self.update_core()
166         self.iteration += 1
167 
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py in update_core(self)
175 
176         if isinstance(in_arrays, tuple):
--> 177             optimizer.update(loss_func, *in_arrays)
178         elif isinstance(in_arrays, dict):
179             optimizer.update(loss_func, **in_arrays)
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/optimizer.py in update(self, lossfun, *args, **kwds)
678         if lossfun is not None:
679             use_cleargrads = getattr(self, '_use_cleargrads', True)
--> 680             loss = lossfun(*args, **kwds)
681             if use_cleargrads:
682                 self.target.cleargrads()
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/link.py in __call__(self, *args, **kwargs)
240         if forward is None:
241             forward = self.forward
--> 242         out = forward(*args, **kwargs)
243 
244         # Call forward_postprocess hook
TypeError: forward() takes 2 positional arguments but 3 were given

This puzzles me, because I set up the dataset in the same way as the MNIST example, with the input data paired with the output data in a tuple. Because of the layers of abstraction in Chainer, I am not quite sure how to debug this. Does anyone have any insight?

Are you using an mpnn model that takes only xs (or data) and outputs feats? I think the problem lies in the model, not in the iterator or the dataset.
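
The traceback shows what happens: when the converted batch is a tuple, the StandardUpdater unpacks it and passes both elements to the model as positional arguments. Roughly (a simplified paraphrase of the Chainer internals quoted above, not the exact implementation):

from chainer.dataset import convert

# Approximately what StandardUpdater.update_core does on each step:
batch = iterator.next()                     # [(xs, ts)] with batch_size=1
in_arrays = convert.concat_examples(batch)  # default converter -> a 2-tuple
if isinstance(in_arrays, tuple):
    optimizer.update(mpnn, *in_arrays)      # i.e. mpnn(xs_batch, ts_batch)
# mpnn.forward gets (self, xs_batch, ts_batch): 3 positional arguments,
# but its signature only accepts (self, data), hence the TypeError.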

You need to prepare a model that takes both xs and ts as input arguments and computes the loss as its output. For example,

import chainer

class GraphNodeClassifier(chainer.Chain):
    def __init__(self, mpnn):
        super().__init__()
        with self.init_scope():
            self.mpnn = mpnn

    def forward(self, xs, ts):
        feat = self.mpnn(xs)
        loss = "calculate loss between `feat` and `ts` here..."
        return loss

and use this GraphNodeClassifier as the target of the optimizer's setup.
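
Wiring that into the training code from the question would then look roughly like this (a sketch, assuming the loss placeholder above is replaced with a real loss function such as a regression loss):

from chainer import iterators, training
from chainer.optimizers import Adam

model = GraphNodeClassifier(mpnn)      # wrap the feature-producing mpnn
iterator = iterators.SerialIterator(dataset, batch_size=1)
optimizer = Adam()
optimizer.setup(model)                 # optimize the wrapper, not the bare mpnn
updater = training.updaters.StandardUpdater(iterator, optimizer)
trainer = training.Trainer(updater, (50, 'epoch'))
trainer.run()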

The MNIST example mentioned above uses Chainer's built-in L.Classifier class, which wraps the MLP model (which takes only x) so that it takes x and t and computes the classification loss.
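
For comparison, a condensed version of that pattern (the MLP below is a minimal stand-in for the MNIST tutorial's predictor, which takes only x):

import chainer
import chainer.functions as F
import chainer.links as L

class MLP(chainer.Chain):
    def __init__(self, n_units, n_out):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_out)

    def forward(self, x):
        # Predictor only: maps x to class scores, no loss computed here.
        return self.l2(F.relu(self.l1(x)))

# L.Classifier wraps the x-only predictor so that its forward takes (x, t)
# and returns the softmax cross-entropy loss, which is what the Trainer needs.
model = L.Classifier(MLP(100, 10))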
