Commit 58660ddc authored by Xavier Barbosa's avatar Xavier Barbosa

class injection

parent 8566a72b
Pipeline #1208 passed with stage
in 33 seconds
......@@ -189,5 +189,28 @@ Injector has a mapping interface, which allows to register arbitrary values::
services["foo"] = "yes"
assert await services["foo"] == "yes"
Injection works for class too. Injector is already dataclass ready.
These 3 examples are equivalent and performs the same::
from dataclasses import dataclass, field
from knighted import attr, annotate, KNIGHTED_NAMESPACE
@annotate
class Foo:
my_attr = attr("foo")
@dataclass
class Bar:
my_attr = attr("foo")
@dataclass
class Baz:
my_attr: Any = field(namespace={KNIGHTED_NAMESPACE:"foo"})
the my_attr will be resolved like the function way.
.. _asyncio: https://pypi.python.org/pypi/asyncio
.. _jeni: https://pypi.python.org/pypi/jeni
from .bases import Injector, annotate, attr, current_injector
from .bases import (
Injector,
annotate,
attr,
current_injector,
AnnotationError,
attr_lazy,
)
from ._version import get_versions
......
......@@ -11,8 +11,9 @@ from functools import wraps
from inspect import signature, unwrap
from itertools import chain
from types import MappingProxyType
from typing import Callable, Optional, cast
from typing import Callable, Optional, cast, Any
from weakref import WeakKeyDictionary
from dataclasses import Field, MISSING, dataclass, is_dataclass, fields
from cached_property import cached_property
......@@ -20,7 +21,8 @@ logger = logging.getLogger("knighted")
MaybeInjector = Optional["Injector"]
ANNOTATIONS: WeakKeyDictionary[Callable, "Annotation"] = WeakKeyDictionary()
TAINTED: WeakKeyDictionary[Any, "Injector"] = WeakKeyDictionary()
Missing = object()
current_injector_var: ContextVar[MaybeInjector] = ContextVar("current_injector")
......@@ -29,12 +31,24 @@ def current_injector() -> MaybeInjector:
return current_injector_var.get(None)
class AnnotationError(Exception):
...
def annotate(*pos_notes, **kw_notes):
def wrapper(func):
func = unwrap(func)
if isinstance(func, type) and pos_notes:
raise AnnotationError("Did you added services to class?")
ANNOTATIONS[func] = Annotation(func, pos_notes, kw_notes)
return func
if pos_notes and len(pos_notes) == 1 and isinstance(pos_notes[0], type):
cls = pos_notes[0]
if not is_dataclass(cls):
logger.warning("annotating a class converts it to a dataclass")
return dataclass(cls)
return cls
for arg in chain(pos_notes, kw_notes.values()):
if not isinstance(arg, str):
raise ValueError("Notes must be strings")
......@@ -42,6 +56,9 @@ def annotate(*pos_notes, **kw_notes):
return wrapper
KNIGHTED_NAMESPACE = "knighted"
class Annotation:
def __init__(self, func, pos_notes, kw_notes):
self.bind_partial = signature(func).bind_partial
......@@ -179,11 +196,23 @@ class Injector(metaclass=ABCMeta):
with self.auto():
func, *args = args # type: ignore
orig = unwrap(func)
anno = ANNOTATIONS.get(orig)
if anno:
anno = ANNOTATIONS.get(orig, Missing)
if anno is Missing and isinstance(orig, type) and is_dataclass(orig):
# late resolution of annotation
kws = {
f.name: (f.metadata or {})[KNIGHTED_NAMESPACE]
for f in fields(orig)
if KNIGHTED_NAMESPACE in (f.metadata or {})
}
anno = ANNOTATIONS[orig] = Annotation(orig, [], kw_notes=kws)
if isinstance(anno, Annotation):
return self.do_apply(func, anno, args, kwargs)
result = func(*args, **kwargs)
if isinstance(func, type):
TAINTED[result] = self
fut: asyncio.Future = asyncio.Future()
fut.set_result(func(*args, **kwargs))
fut.set_result(result)
return fut
def do_apply(self, func, anno, args, kwargs):
......@@ -227,14 +256,42 @@ class Injector(metaclass=ABCMeta):
current_injector_var.reset(token)
class attr:
def __init__(self, name):
self.service_name = name
def attr(service, *, init=True, repr=True, hash=None, compare=True, metadata=None):
metadata = (metadata or {}).copy()
metadata[KNIGHTED_NAMESPACE] = service
return Field(
default=MISSING,
default_factory=MISSING,
init=True,
repr=True,
hash=None,
compare=True,
metadata=metadata,
)
def attr_lazy(service):
return Attr(service)
class Attr:
def __init__(self, service):
self.service = service
def __get__(self, obj, objtype):
if obj is None:
raise AttributeError("attr applies to instances only")
return current_injector_var.get().get(self.service_name)
return self
fut = obj.__dict__[self.field_name] = asyncio.Future()
task = asyncio.create_task(self.load(obj))
task.add_done_callback(lambda x: fut.set_result(x.result()))
return fut
def __set_name__(self, owner, name):
self.field_name = name
async def load(self, obj):
fut = (TAINTED.get(obj) or current_injector_var.get()).get(self.service)
return await fut
def note_loop(note):
......
import pytest
from knighted import Injector, annotate, attr, AnnotationError
from typing import Any
@pytest.fixture
def services():
class MyInjector(Injector):
pass
return MyInjector()
@pytest.mark.asyncio
async def test_annotate_class(services):
@services.factory("foo")
def foo_factory():
return "I am foo"
@annotate
class Tic:
foo: Any = attr("foo")
def __call__(self):
return {"foo": self.foo}
result = await services.apply(Tic)
assert result() == {"foo": "I am foo"}
@pytest.mark.asyncio
async def test_annotate_parent_class_is_illegal(services):
with pytest.raises(AnnotationError):
@annotate("foo")
class Tic:
def __init__(self, foo):
self.foo = foo
def __call__(self):
return {"foo": self.foo}
@pytest.mark.asyncio
async def test_annotate_subclassing(services):
@services.factory("foo")
def foo_factory():
return "I am foo"
@annotate
class Tic:
foo: Any = attr("foo")
def __init__(self, foo):
self.foo = foo
def __call__(self):
return {"foo": self.foo}
class Tac(Tic):
...
result = await services.apply(Tac)
assert result() == {"foo": "I am foo"}
@pytest.mark.asyncio
async def test_annotated_class_attribute(services):
@services.factory("foo")
def foo_factory():
return "I am foo"
@annotate
class Tic:
foo: Any = attr("foo")
def __call__(self):
return {"foo": self.foo}
result = await services.apply(Tic)
assert result() == {"foo": "I am foo"}
@pytest.mark.asyncio
async def test_annotated_class_attribute_2(services):
"""With an __init__
"""
@services.factory("foo")
def foo_factory():
return "I am foo"
@annotate
class Tic:
bar: Any = attr("foo")
def __init__(self, foo, bar):
self.foo = foo
self.bar = bar
def __call__(self):
return {"foo": self.foo, "bar": self.bar}
result = await services.apply(Tic, "value")
assert result() == {"foo": "value", "bar": "I am foo"}
import pytest
from knighted import Injector, attr, attr_lazy
from typing import Any
@pytest.fixture
def services():
class MyInjector(Injector):
pass
return MyInjector()
@pytest.mark.asyncio
async def test_descriptor_1_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache: Any = attr("foo")
toto = Toto
with services.auto(), pytest.raises(Exception):
await toto.cache
@pytest.mark.asyncio
async def test_descriptor_2_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache: Any = attr_lazy("foo")
toto = Toto()
with services.auto():
cache = await toto.cache
assert cache == "I am foo"
@pytest.mark.asyncio
async def test_annotated_class_attribute(services):
@services.factory("foo")
def foo_factory():
return "I am foo"
class Tic:
foo: Any = attr_lazy("foo")
async def __call__(self):
return {"foo": await self.foo}
result = await services.apply(Tic)
assert (await result()) == {"foo": "I am foo"}
import pytest
from knighted import Injector, annotate, current_injector, attr
from knighted import Injector, annotate, current_injector
@pytest.fixture
......@@ -112,63 +112,6 @@ async def test_noauto_partial_async_async(services):
await fun()
@pytest.mark.asyncio
async def test_descriptor_1_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache = attr("foo")
toto = Toto
with services.auto(), pytest.raises(Exception):
await toto.cache
@pytest.mark.asyncio
async def test_descriptor_2_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache = attr("foo")
toto = Toto()
with services.auto():
cache = await toto.cache
assert cache == "I am foo"
@pytest.mark.asyncio
async def test_descriptor_1_not_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache = attr("foo")
toto = Toto
with pytest.raises(Exception):
toto.cache # Must fails
@pytest.mark.asyncio
async def test_descriptor_2_not_decorated(services):
@services.factory("foo")
async def foo_factory():
return "I am foo"
class Toto:
cache = attr("foo")
toto = Toto()
with pytest.raises(LookupError):
toto.cache
@pytest.mark.asyncio
async def test_partial_sync_async(services):
@services.factory("foo")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment