"""
Virtual context managers are created to replace contextlib.contextmanager
and contextlib.asynccontextmanager decorators.
They are fully typed and distinguishable by FunDIs `scan(...)` function
"""
import types
import typing
import inspect
import warnings
from dataclasses import replace
from collections.abc import Generator, AsyncGenerator
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from .scan import scan
from .types import CallableInfo
from fundi.logging import get_logger
from .exceptions import GeneratorExitedTooEarly
__all__ = ["VirtualContextProvider", "AsyncVirtualContextProvider", "virtual_context"]
logger = get_logger("virtual_context")
T = typing.TypeVar("T")
P = typing.ParamSpec("P")
F = typing.TypeVar("F", bound=types.FunctionType)
@typing.final
class _VirtualContextManager(typing.Generic[T], AbstractContextManager[T]):
"""
Virtual context manager implementation
"""
def __init__(self, generator: Generator[T, None, None], origin: types.FunctionType) -> None:
self.generator = generator
self.origin = origin
def __enter__(self) -> T: # pyright: ignore[reportMissingSuperCall, reportImplicitOverride]
try:
logger.debug("Entering %r", self.origin)
return self.generator.send(None)
except StopIteration as exc:
raise GeneratorExitedTooEarly(self.origin, self.generator) from exc
def __exit__( # pyright: ignore[reportImplicitOverride]
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool:
try:
if exc_type is not None:
logger.debug(
"Raising %s in %r: %r",
exc_type.__name__ if exc_type else "Unknown",
self.origin,
exc_value,
)
self.generator.throw(exc_type, exc_value, traceback)
else:
logger.debug("Exiting %r", self.origin)
self.generator.send(None)
except StopIteration:
logger.debug("Generator %r exited cleanly", self.origin)
except Exception as exc:
if exc is exc_value:
logger.debug(
"Generator created by %r re-raised exception %r, suppressing traceback",
self.origin,
exc_type,
)
return False
raise exc
else:
logger.warning("Generator %r did not exit properly", self.origin)
warnings.warn("Generator not exited", UserWarning)
return False
@typing.final
class _VirtualAsyncContextManager(typing.Generic[T], AbstractAsyncContextManager[T]):
"""
Virtual context manager implementation
"""
def __init__(self, generator: AsyncGenerator[T, None], origin: types.FunctionType) -> None:
self.generator = generator
self.origin = origin
async def __aenter__(self) -> T: # pyright: ignore[reportImplicitOverride]
try:
logger.debug("Entering %r", self.origin)
return await self.generator.asend(None)
except StopAsyncIteration as exc:
raise GeneratorExitedTooEarly(self.origin, self.generator) from exc
async def __aexit__( # pyright: ignore[reportImplicitOverride]
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool:
assert self.generator is not None, "Generator not initialized, call __call__ method first"
try:
if exc_type is not None:
logger.debug(
"Raising %s in %r: %r",
exc_type.__name__ if exc_type else "Unknown",
self.origin,
exc_value,
)
await self.generator.athrow(exc_type, exc_value, traceback)
else:
logger.debug("Exiting %r", self.origin)
await self.generator.asend(None)
except StopAsyncIteration:
logger.debug("Generator %r exited cleanly", self.origin)
except Exception as exc:
if exc is exc_value:
logger.debug(
"Generator created by %r re-raised exception %r, suppressing traceback",
self.origin,
exc_type,
)
return False
raise exc
else:
logger.warning("Generator %r did not exit properly", self.origin)
warnings.warn("Generator not exited", UserWarning)
return False
class VirtualContextProvider(typing.Generic[T, P]):
"""
Synchronous virtual context manager
"""
def __init__(self, function: typing.Callable[P, Generator[T, None, None]]):
self.__fundi_info__: CallableInfo[typing.Any] = replace(
scan(function, generator=False, context=True), call=self
)
self.__wrapped__: typing.Callable[P, Generator[T, None, None]] = function
def __call__(self, *args: P.args, **kwargs: P.kwargs):
logger.debug("Creating virtual context manager for %r", self.__wrapped__)
return _VirtualContextManager(self.__wrapped__(*args, **kwargs), self.__wrapped__)
class AsyncVirtualContextProvider(typing.Generic[T, P]):
"""
Asynchronous virtual context manager
"""
def __init__(self, function: typing.Callable[P, AsyncGenerator[T]]):
self.__fundi_info__: CallableInfo[typing.Any] = replace(
scan(function, generator=False, context=True), call=self
)
self.__wrapped__: typing.Callable[P, AsyncGenerator[T]] = function
def __call__(self, *args: P.args, **kwargs: P.kwargs):
logger.debug("Creating virtual context manager for %r", self.__wrapped__)
return _VirtualAsyncContextManager(self.__wrapped__(*args, **kwargs), self.__wrapped__)
@typing.overload
def virtual_context(
function: typing.Callable[P, Generator[T, None, None]],
) -> VirtualContextProvider[T, P]: ...
@typing.overload
def virtual_context(
function: typing.Callable[P, AsyncGenerator[T]],
) -> AsyncVirtualContextProvider[T, P]: ...
[docs]
def virtual_context(
function: typing.Callable[P, Generator[T, None, None] | AsyncGenerator[T]],
) -> VirtualContextProvider[T, P] | AsyncVirtualContextProvider[T, P]:
"""
Define virtual context manager using decorator
Example::
@virtual_context
def file(name: str):
file_ = open(name, "r")
try:
yield file_
finally:
file_.close()
with file("dontreadthis.txt") as f:
print(f.read())
@virtual_context
async def lock(name: str):
lock_ = locks[name]
lock_.acquire()
try:
yield
finally:
lock_.release()
async with lock("socket-send"):
await socket.send("wtf")
"""
if inspect.isasyncgenfunction(function):
return AsyncVirtualContextProvider(function)
elif inspect.isgeneratorfunction(function):
return VirtualContextProvider(function)
raise ValueError(
f"@virtual_context expects a generator or async generator function, got {function!r}"
)