SQLAlchemy:提前在关系列上动态应用过滤器



我有两个SQLAlchemy类表示多对一关系,比如:

class Person
id = Column(Integer(), primary_key=True)
name = Column(String(30))
known_addresses = relationship('Address', backref='person')
class Address:
id = Column(Integer(), primary_key=True)
person_id = Column(Integer(), ForeignKey(Person.id, ondelete='cascade'))
city = Column(String(30))
zip_code = Column(String(10))

现在,假设我有一个函数,它返回按邮政编码过滤的Person查询集(Select对象(:

def get_persons_in_zip_code(zip_code):
return session.query(Person).
join(Address).
where(Address.zip_code == zip_code)

一旦我返回查询集,我就无法控制它,预计这将封装我正在使用的框架(在我的例子中是Django/DRF(的所有数据,以呈现人员列表以及他们的地址(因此代码迭代查询集,为每个人调用.addresses并进行呈现(。

问题是:我想确保调用.addresses将只返回原始zip_code过滤查询中匹配的地址,而不是所有与此人相关的地址。

有没有一种方法可以在SQLAlchemy中实现这一点,而无需访问后期返回的Person对象也就是说,我只能修改我的get_persons_in_zip_code函数或原始的SQLAlchemy类,但不能访问从查询返回的Person对象,因为这发生在框架呈现代码的深处。

EDIT:同样重要的是,对返回的查询对象调用count()会产生预期的Person对象数,而不是Address对象数。

您想要的似乎是contains_eager

EDIT:一个更新版本,它对.count()函数进行猴痘修补,只返回不同的Person计数。

from sqlalchemy import Integer, Column, String, ForeignKey
from sqlalchemy import create_engine, func, distinct
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker, contains_eager
from types import MethodType
engine = create_engine('sqlite:///:memory:', echo=True)
Session = sessionmaker(bind=engine)
session = Session()
Base = declarative_base()

class Person(Base):
__tablename__ = "person"
id = Column(Integer(), primary_key=True)
name = Column(String(30))
known_addresses = relationship('Address', backref='person')
def __repr__(self):
return "<Person {}>".format(self.name)

class Address(Base):
__tablename__ = "address"
id = Column(Integer(), primary_key=True)
person_id = Column(Integer(), ForeignKey(Person.id, ondelete='cascade'))
city = Column(String(30))
zip_code = Column(String(10))
def __repr__(self):
return "<Address {}>".format(self.zip_code)

Base.metadata.create_all(engine)
p1 = Person(name="P1")
session.add(p1)
p2 = Person(name="P2")
session.add(p2)
session.commit()
a1 = Address(person_id=p1.id, zip_code="123")
session.add(a1)
a2 = Address(person_id=p1.id, zip_code="345")
session.add(a2)
a3 = Address(person_id=p2.id, zip_code="123")
session.add(a3)
a4 = Address(person_id=p1.id, zip_code="123")
session.add(a4)
session.commit()
def get_persons_in_zip_code(zip_code):
return session.query(Person).
join(Person.known_addresses).
filter(Address.zip_code == zip_code).
options(contains_eager(Person.known_addresses))
def distinct_person_count(q):
count_q = q.statement.with_only_columns([func.count(distinct(Person.id))])
return q.session.execute(count_q).scalar()

results = get_persons_in_zip_code("123")
results.count = MethodType(distinct_person_count, results)
print(results.count())

for person in results:
print(person)
for address in person.known_addresses:
print(address)

输出:

2
<Person P1>
<Address 123>
<Address 123>
<Person P2>
<Address 123>

最新更新