import typing
import inspect
from types import BuiltinFunctionType, FunctionType, MethodType
from collections.abc import AsyncGenerator, Awaitable, Generator
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from fundi.logging import get_logger
from fundi.types import R, CallableInfo, Parameter, TypeResolver
from fundi.util import is_configured, get_configuration, normalize_annotation
logger = get_logger("scan")
def _transform_parameter(parameter: inspect.Parameter) -> Parameter:
    """
    Convert an ``inspect.Parameter`` into a FunDI ``Parameter``.

    A dependency is taken from the parameter default when that default is a
    ``CallableInfo``; otherwise from the first ``CallableInfo`` found among the
    arguments of an ``Annotated[...]`` annotation. By-type resolution is enabled
    when the annotation is a ``TypeResolver`` instance, or when ``TypeResolver``
    appears among the ``Annotated[...]`` arguments.

    If the dependency defines a graph hook, the hook is called with a copy of
    the dependency (made with ``deep=True``) and a copy of the parameter; the
    returned parameter then references that dependency copy.

    :param parameter: parameter taken from the callable's signature
    :return: FunDI parameter description
    """
    name = parameter.name
    logger.debug("Transforming parameter %r into FunDI parameter", name)

    kind = parameter.kind
    default = parameter.default

    dependency: CallableInfo[typing.Any] | None = None
    by_type = False

    if isinstance(default, CallableInfo):
        # A CallableInfo default is a dependency definition, not a real default value.
        logger.debug("Parameter %r is a dependency definition", name)
        dependency = typing.cast(CallableInfo[typing.Any], default)
        keeps_default = False
    else:
        keeps_default = default is not inspect.Parameter.empty

    annotation = parameter.annotation

    if isinstance(annotation, TypeResolver):
        logger.debug("Parameter %r marked to resolve by type via TypeResolver", name)
        annotation = annotation.annotation
        by_type = True

    elif dependency is None and typing.get_origin(annotation) is typing.Annotated:
        # Annotated metadata is only consulted when the default did not
        # already provide a dependency.
        metadata = typing.get_args(annotation)

        if TypeResolver in metadata:
            by_type = True
            logger.debug("Parameter %r marked to resolve by type via FromType", name)
        else:
            dependency = next(
                (item for item in metadata if isinstance(item, CallableInfo)),
                None,
            )
            if dependency is not None:
                logger.debug("Parameter %r is a dependency definition", name)

    transformed = Parameter(
        name,
        annotation,
        from_=dependency,
        default=default if keeps_default else None,
        has_default=keeps_default,
        resolve_by_type=by_type,
        positional_varying=kind == inspect.Parameter.VAR_POSITIONAL,
        positional_only=kind == inspect.Parameter.POSITIONAL_ONLY,
        keyword_varying=kind == inspect.Parameter.VAR_KEYWORD,
        keyword_only=kind == inspect.Parameter.KEYWORD_ONLY,
    )

    # Nothing more to do unless the dependency carries a graph hook.
    if dependency is None or dependency.graphhook is None:
        return transformed

    logger.debug(
        "Calling graph hook defined for %r on parameter %r", dependency.call, name
    )

    # The hook works on copies; only the dependency copy is kept afterwards.
    hooked = dependency.copy(deep=True)
    dependency.graphhook(hooked, transformed.copy())
    return transformed.copy(from_=hooked)
def _is_context(call: typing.Any):
    """Return True if ``call`` is a context manager class or a context manager instance."""
    # Classes are tested with issubclass, everything else with isinstance.
    matches = issubclass if isinstance(call, type) else isinstance
    return matches(call, AbstractContextManager)
def _is_async_context(call: typing.Any):
    """Return True if ``call`` is an async context manager class or an async context manager instance."""
    # Classes are tested with issubclass, everything else with isinstance.
    matches = issubclass if isinstance(call, type) else isinstance
    return matches(call, AbstractAsyncContextManager)
# [docs] -- Sphinx "viewcode" link residue from HTML extraction; not part of the module.
def scan(
    call: typing.Callable[..., R],
    caching: bool = True,
    async_: bool | None = None,
    generator: bool | None = None,
    context: bool | None = None,
    use_return_annotation: bool = True,
    side_effects: tuple[typing.Callable[..., typing.Any], ...] = (),
) -> CallableInfo[R]:
    r"""
    Get callable information

    When ``call`` already carries a ``__fundi_info__`` attribute, that info is
    reused and a copy with the requested overrides is returned. Otherwise the
    callable's signature is inspected, the result is stored on the callable as
    ``__fundi_info__`` (when the callable accepts attributes) and a copy with
    the scanned side effects is returned. The stored info never contains the
    side effects passed to this call.

    :param call: callable to get information from
    :param caching: whether to use cached result of this callable or not
    :param async_: Override "async\_" attribute value
    :param generator: Override "generator" attribute value
    :param context: Override "context" attribute value
    :param use_return_annotation: Whether to use call's return
        annotation to define its type
    :param side_effects: functions that will be injected before this dependant
    :return: callable information
    :raises ValueError: if ``call`` is not callable
    """
    logger.debug(
        "Scanning %r (async=%s, generator=%s, context=%s, caching=%s)",
        call,
        async_,
        generator,
        context,
        caching,
    )

    # Side effects are callables themselves, so each one is scanned too.
    _side_effects: list[CallableInfo[typing.Any]] = [
        scan(side_effect) for side_effect in side_effects
    ]

    if hasattr(call, "__fundi_info__"):
        logger.debug("Reusing cached CallableInfo for %r", call)
        info = typing.cast(CallableInfo[typing.Any], getattr(call, "__fundi_info__"))

        # Only explicitly provided values override the cached ones.
        overrides: dict[str, typing.Any] = {"use_cache": caching}
        if async_ is not None:
            overrides["async_"] = async_

        if generator is not None:
            overrides["generator"] = generator

        if context is not None:
            overrides["context"] = context

        if side_effects:
            # Newly requested side effects come first; the ones already on the
            # cached info are appended unless they are already present.
            for inherited in info.side_effects:
                if inherited not in _side_effects:
                    _side_effects.append(inherited)

            overrides["side_effects"] = tuple(_side_effects)

        logger.debug(
            "Overriding cached CallableInfo for %r with values: %r",
            call,
            list(overrides.keys()),
        )
        return info.copy(**overrides)

    if not callable(call):
        raise ValueError(f"Callable expected, got {type(call)!r}")  # pyright: ignore[reportUnreachable]

    # Functions, builtins, bound methods and classes are inspected directly;
    # any other callable object is inspected through its ``__call__``.
    truecall = (
        call
        if isinstance(call, (FunctionType, BuiltinFunctionType, MethodType, type))
        else call.__call__
    )

    signature = inspect.signature(truecall)

    # ``type`` is a placeholder that matches none of the ABCs tested below.
    # It is kept whenever the return annotation is missing or does not
    # normalize to a class.
    return_: type[typing.Any] = type
    if signature.return_annotation is not signature.empty:
        annotation = normalize_annotation(signature.return_annotation)[0]
        if isinstance(annotation, type):
            return_ = annotation

    _generator: bool = inspect.isgeneratorfunction(truecall)
    _agenerator: bool = inspect.isasyncgenfunction(truecall)
    _context: bool = _is_context(call)
    _acontext: bool = _is_async_context(call)

    # "generator": from the return typehint or from the function kind
    if generator is None:
        generator = (
            use_return_annotation and issubclass(return_, (Generator, AsyncGenerator))
        ) or (_generator or _agenerator)

    # "context": from the return typehint or from the callable's own type
    if context is None:
        context = (
            use_return_annotation
            and issubclass(return_, (AbstractContextManager, AbstractAsyncContextManager))
        ) or (_context or _acontext)

    # "async_": from the return typehint, the function kind, or the flags above
    if async_ is None:
        async_ = (
            use_return_annotation
            and issubclass(return_, (AsyncGenerator, AbstractAsyncContextManager, Awaitable))
        ) or (_agenerator or _acontext or inspect.iscoroutinefunction(truecall))

    parameters = [_transform_parameter(parameter) for parameter in signature.parameters.values()]

    hooks = getattr(call, "__fundi_hooks__", {})

    # NOTE: the stored info keeps the resolved async_/generator/context values
    # of this call (explicit overrides included) and an empty side_effects tuple.
    info = CallableInfo(
        call=call,
        use_cache=caching,
        async_=async_,
        context=context,
        graphhook=hooks.get("graph"),
        scopehook=hooks.get("scope"),
        side_effects=(),
        generator=generator,
        parameters=parameters,
        return_annotation=signature.return_annotation,
        configuration=get_configuration(call) if is_configured(call) else None,
    )

    # Best effort: some callables do not accept new attributes.
    try:
        setattr(call, "__fundi_info__", info)
    except (AttributeError, TypeError):
        logger.debug("Unable to cache scan result in %r", call)

    return info.copy(side_effects=tuple(_side_effects))