How to avoid repeating kwargs in Python?



Suppose there is a method:

def train_model(self, out_dir='./out/',
                test_size=0.2, train_size=None,
                random_state=None, shuffle=True, stratify=None,
                epochs=DEFAULT_EPOCHS, batch_size=DEFAULT_BATCH_SIZE):
    ...
    self.model.train(test_size=test_size, train_size=train_size,
                     random_state=random_state, shuffle=shuffle,
                     stratify=stratify, epochs=epochs, batch_size=batch_size)

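Inside this function, another method with the same signature is called, so I have to pass every argument along by hand. I don't want to use kwargs in train_model, because it is a public method that other people may use, and I want to keep the typing information. I don't know whether there is a way to keep the typing information in the kwargs of the outer function.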

In TypeScript, the same thing can be achieved with the Parameters utility type. For example:
function sum(a: number, b: number) {
    return a + b;
}
type SumParamsType = Parameters<typeof sum>
// Then you can use SumParamsType in other places.

A failed attempt in Python:

from typing import TypeVar
T = TypeVar('T')

def f1(a=1, b=2, c=3):
    return a + b + c

# Is there anything like T = Parameters(type(f1)) in Python?
def f2(z=0, **kwargs: T):
    return z + f1(**kwargs)
# T cannot capture the kwargs of f1 (of course it won't)

This doesn't work either:

def f1(a=1, b=2, c=3):
    return a + b + c

def f2(z=0, **kwargs: f1.__annotations__['kwargs']):
    return z + f1(**kwargs)
# kwargs has the type Any

The closest you can get is a TypedDict combined with Unpack (available on Python < 3.11 via typing_extensions):

from typing_extensions import Unpack, TypedDict, NotRequired

class Params(TypedDict):
    a: NotRequired[int]
    b: NotRequired[int]
    c: NotRequired[int]

def f1(**kwargs: Unpack[Params]):
    a = kwargs.pop('a', 1)
    b = kwargs.pop('b', 1)
    c = kwargs.pop('c', 1)
    return a + b + c

def f2(z=0, **kwargs: Unpack[Params]):
    return z + f1(**kwargs)

Note that your IDE may not support Unpack unless you run mypy with --enable-incomplete-feature=Unpack. VSCode supports it out of the box. PyCharm may not.

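If you control both function definitions, you may find it easier to change your approach: have the functions accept a single dataclass that encapsulates all the parameters and their defaults, instead of accepting each parameter individually.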
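A minimal sketch of that idea, applied to a cut-down version of the train_model example from the question. The names TrainParams and inner_train and the default values are hypothetical stand-ins chosen for illustration, not part of any library:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainParams:
    """Hypothetical container for the arguments shared by both functions."""
    test_size: float = 0.2
    train_size: Optional[float] = None
    shuffle: bool = True
    epochs: int = 10       # assumed default, stands in for DEFAULT_EPOCHS
    batch_size: int = 32   # assumed default, stands in for DEFAULT_BATCH_SIZE


def inner_train(params: TrainParams) -> str:
    # Stand-in for self.model.train: it only reports what it received.
    return f"epochs={params.epochs}, batch_size={params.batch_size}, test_size={params.test_size}"


def train_model(out_dir: str = "./out/", params: Optional[TrainParams] = None) -> str:
    # The outer function forwards one typed object instead of many keywords.
    if params is None:
        params = TrainParams()
    return inner_train(params)


print(train_model(params=TrainParams(epochs=3)))
```

Because every field is declared once on the dataclass, IDEs and type checkers can complete and check the field names at the call site, and adding a parameter only touches the dataclass and the code that uses it.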

You can create a class that holds the training arguments and pass it to the train method, as the HuggingFace Transformers library does. Here is the code from their GitHub:

from dataclasses import asdict, dataclass, field, fields
#...
@dataclass
class TrainingArguments:
    framework = "pt"
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    evaluation_strategy: Union[IntervalStrategy, str] = field(
        default="no",
        metadata={"help": "The evaluation strategy to use."},
    )
    prediction_loss_only: bool = field(
        default=False,
        metadata={"help": "When performing evaluation and predictions, only returns the loss."},
    )
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    # ...

This is a bit verbose, but it is very clear and works with your IDE's type hints.

You will need a combination of locals() and .__code__.co_varnames:

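def f2(b=5, c=7):
    return b * c

def f1(a=1, b=2, c=3):
    # Snapshot f1's arguments before any other local variable is created.
    sss = locals().copy()
    # Variable names known to f2 (its parameters come first).
    f2_params = f2.__code__.co_varnames
    # Forward only the arguments that f2 knows about.
    return f2(**{x: y for x, y in sss.items() if x in f2_params})

print(f1())  # prints 6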

Edit: if you want to use **kwargs, try this:

def f2(b=5, c=7):
    return b * c

def f1(a=1, **kwargs):
    sss = locals()['kwargs'].copy()
    f2_params = f2.__code__.co_varnames
    return f2(**{x: y for x, y in sss.items() if x in f2_params})

print(f1(b=10, c=3))
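One caveat with co_varnames is that it also lists the target function's ordinary local variables, not only its parameters. As a possible alternative (a swapped-in technique, not what the answer above uses), the standard-library inspect.signature reports only the parameters. The helper name call_with_matching_kwargs below is hypothetical:

```python
import inspect


def f2(b=5, c=7):
    return b * c


def call_with_matching_kwargs(func, available):
    """Call func with only the entries of `available` that it accepts by keyword."""
    params = inspect.signature(func).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {
        name: value
        for name, value in available.items()
        if name in params and params[name].kind in keyword_kinds
    }
    return func(**accepted)


def f1(a=1, **kwargs):
    # Keys that f2 does not declare (such as `unused`) are dropped silently.
    return call_with_matching_kwargs(f2, kwargs)


print(f1(b=10, c=3, unused=99))  # prints 30
```

Like the locals() approach, this only filters arguments at runtime; it does not give type checkers or IDEs any information about which keywords f1 accepts.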
