save()接受1个位置参数,但给出了2个


import deepchem as dc
import pandas as pd
import numpy as np
import os, glob
tasks, datasets, transformers = dc.molnet.load_hiv(featurizer='GraphConv')
train_dataset, valid_dataset, test_dataset = datasets
print(datasets)
n_tasks = len(tasks)
model = dc.models.GraphConvModel(n_tasks, mode='classification')
hist = model.fit(train_dataset, nb_epoch=50)
metric = dc.metrics.Metric(dc.metrics.roc_auc_score)
print('Training set score:', model.evaluate(train_dataset, [metric], transformers))
print('Test set score:', model.evaluate(test_dataset, [metric], transformers))
model.save("HIV_test1.h5")

我想保存模型,但它有错误。

TypeError:save((接受1个位置参数,但为2个提供了

deepchemmodel.save()的文档实际上并没有说太多。但是您不能为它提供文件名(它不需要额外的参数,只需要model(。

事实证明,为了";保存";,初始化模型时需要指定一个model_dir。我没有检查,但显然您的模型将在某些步骤后保存到该位置,如.fit()

使用此示例并应用于您的代码:

# while creating
model = dc.models.GraphConvModel(n_tasks, mode='classification',
model_dir='/home/user/HIV_test1')
# or model_dir=r'C:UsersmeHIV_test1'
hist = model.fit(train_dataset, nb_epoch=50)
# etc.

# and later to load it:
model = dc.models.GraphConvModel(model_dir='/home/user/HIV_test1')
model.restore()

(请确保您使用的model_dir始终指向同一个目录。仅使用文件名意味着执行脚本的当前目录可能会有所不同。此外,这应该只是一个目录路径,没有文件名-该目录用于该模型的检查点。因此,对不同的模型使用单独的目录。(

此外,请查看这些相关方法:

  • Model.get_model_filename()
  • TorchModel.save_checkpoint()(TorchModel(
  • TorchModel.get_checkpoints()(TorchModel(

最新更新