根据输入返回函数的类型推断



我想为基类定义一个函数,并为派生类调用获得正确的返回类型。例如

# Module 1:
from typing import TypeVar
class Food:
pass
class Animal:
def __init__(self, food: Food) -> None:
self.food=food
T = TypeVar("T", bound=Food)
S = TypeVar("S", bound=Animal)
def get_food(animal: S) -> T:  # Illustrates what I want but not working.
return animal.food
food = get_food(Animal(Food()))
reveal_type(food)  # Food.
# Module 2:
class Carrot(Food):
pass
class Rabbit(Animal):
def __init__(self, food: Carrot) -> None:
self.food=food
food = get_food(Rabbit(Carrot()))
reveal_type(food)  # Food. Want Carrot.

我知道的选项有:

  1. 使用@overload装饰器,但这意味着模块1需要知道模块2中的继承类型-这是一个问题
  2. 在模块2中有一个新的get_food,它委托给模块1并显式地强制转换返回类型:
def get_food(rabbit: Rabbit) -> Carrot:
return cast(Carrot, get_food(rabbit))

有更好的方法吗?

您需要使您的Animal类通用食品类型。这基本上意味着任何[非严格]Animal子类都有某种食物(Food的[非严格]子类)与之相关联。

from typing import Generic, TypeVar
class Food:
pass

_F = TypeVar("_F", bound=Food)

class Animal(Generic[_F]):
def __init__(self, food: _F) -> None:
self.food = food

def get_food(animal: Animal[_F]) -> _F:
return animal.food

food = get_food(Animal(Food()))
reveal_type(food)  # N: Revealed type is "__main__.Food"

class Carrot(Food):
pass
class Rabbit(Animal[Carrot]):
pass

food = get_food(Rabbit(Carrot()))
reveal_type(food)  # N: Revealed type is "__main__.Carrot"

这里有一个playground链接和一个关于泛型类的相关文档。

最新更新