我有一个程序(ASGI服务器),其结构大致如下:
import asyncio
import contextvars
ctxvar = contextvars.ContextVar("ctx")
async def lifepsan():
ctxvar.set("spam")
async def endpoint():
assert ctxvar.get() == "spam"
async def main():
ctx = contextvars.copy_context()
task = asyncio.create_task(lifepsan())
await task
task = asyncio.create_task(endpoint())
await task
asyncio.run(main())
因为生命周期事件/端点在任务中运行,它们不能共享上下文变量。这是经过设计的:任务在执行之前复制上下文,因此lifespan
不能正确设置ctxvar
。这是端点所期望的行为,但我希望执行看起来像这样(从用户的角度来看):
async def lifespan():
ctxvar.set("spam")
await endpoint()
换句话说,端点在它们自己独立的上下文中执行,但是在生命周期的上下文中执行。
我试图通过使用contextlib.copy_context()
:
import asyncio
import contextvars
ctxvar = contextvars.ContextVar("ctx")
async def lifepsan():
ctxvar.set("spam")
print("set")
async def endpoint():
print("get")
assert ctxvar.get() == "spam"
async def main():
ctx = contextvars.copy_context()
task = ctx.run(asyncio.create_task, lifepsan())
await task
endpoint_ctx = ctx.copy()
task = endpoint_ctx.run(asyncio.create_task, endpoint())
await task
asyncio.run(main())
以及:
async def main():
ctx = contextvars.copy_context()
task = asyncio.create_task(ctx.run(lifespan))
await task
endpoint_ctx = ctx.copy()
task = asyncio.create_task(endpoint_ctx.run(endpoint))
await task
然而,contextvars.Context.run
似乎不以这种方式工作(我猜上下文是绑定时,协程被创建,但不是当它被执行)。
是否有一种简单的方法来实现期望的行为,而不需要重组任务的创建方式或诸如此类?
灵感来自PEP 555和asgiref:
from contextvars import Context, ContextVar, copy_context
from typing import Any
def _set_cvar(cvar: ContextVar, val: Any):
cvar.set(val)
class CaptureContext:
def __init__(self) -> None:
self.context = Context()
def __enter__(self) -> "CaptureContext":
self._outer = copy_context()
return self
def sync(self):
final = copy_context()
for cvar in final:
if cvar not in self._outer:
# new contextvar set
self.context.run(_set_cvar, cvar, final.get(cvar))
else:
final_val = final.get(cvar)
if self._outer.get(cvar) != final_val:
# value changed
self.context.run(_set_cvar, cvar, final_val)
def __exit__(self, *args: Any):
self.sync()
def restore_context(context: Context) -> None:
"""Restore `context` to the current Context"""
for cvar in context.keys():
try:
cvar.set(context.get(cvar))
except LookupError:
cvar.set(context.get(cvar))
用法:
import asyncio
import contextvars
ctxvar = contextvars.ContextVar("ctx")
async def lifepsan(cap: CaptureContext):
with cap:
ctxvar.set("spam")
async def endpoint():
assert ctxvar.get() == "spam"
async def main():
cap = CaptureContext()
await asyncio.create_task(lifepsan(cap))
restore_context(cap.context)
task = asyncio.create_task(endpoint())
await task
asyncio.run(main())
sync()
方法是在任务长时间运行并且需要在任务完成之前捕获上下文的情况下提供的。一个有点做作的例子:
import asyncio
import contextvars
ctxvar = contextvars.ContextVar("ctx")
async def lifepsan(cap: CaptureContext, event: asyncio.Event):
with cap:
ctxvar.set("spam")
cap.sync()
event.set()
await asyncio.sleep(float("inf"))
async def endpoint():
assert ctxvar.get() == "spam"
async def main():
cap = CaptureContext()
event = asyncio.Event()
asyncio.create_task(lifepsan(cap, event))
await event.wait()
restore_context(cap.context)
task = asyncio.create_task(endpoint())
await task
asyncio.run(main())
我认为如果contextvars.Context.run
和协程一起工作的话会更好。
此功能将在Python 3.11中得到支持:https://github.com/python/cpython/issues/91150
你可以这样写:
async def main():
ctx = contextvars.copy_context()
task = asyncio.create_task(lifepsan(), context=ctx)
await task
endpoint_ctx = ctx.copy()
task = asyncio.create_task(endpoint(), context=endpoint_ctx)
await task
与此同时,在当前的Python版本中,您将需要此特性的后端口。我想不出一个好的,但有一个坏的。