bases.py 5.57 KB
Newer Older
xa's avatar
xa committed
1
from __future__ import annotations
xa's avatar
xa committed
2

3
import asyncio
xa's avatar
xa committed
4
import concurrent.futures
5 6
import logging
from abc import ABCMeta
xa's avatar
xa committed
7
from collections import OrderedDict, defaultdict
8
from functools import wraps
xa's avatar
xa committed
9
from inspect import signature
xa's avatar
xa committed
10 11 12
from itertools import chain
from weakref import WeakKeyDictionary

xa's avatar
xa committed
13
from cached_property import cached_property
14

xa's avatar
xa committed
15
# Module logger: every record emitted here goes to the "knighted" logger.
logger = logging.getLogger(name="knighted")
16 17 18 19 20 21 22 23 24


class Factory:
    """Registers factory callables into ``target.factories``.

    Usable both directly (``factory(note, func)``) and as a decorator
    (``@factory(note)``); either way the callable is returned unchanged.
    """

    def __init__(self, target):
        self.target = target

    def __call__(self, note, func=None):
        def register(fn):
            # Store the factory under its note on whatever owns this Factory.
            self.target.factories[note] = fn
            return fn

        return register(func) if func else register


class FactoryMethod:
    """Descriptor returning a ``Factory`` bound to the accessing owner.

    Accessed on a class, factories are registered on that class; accessed on
    an instance, they are registered on that instance.
    """

    def __get__(self, obj, objtype=None):
        # Identity test, not truthiness: an owner instance that defines
        # __len__ or __bool__ may be falsy and must not fall back to its class.
        target = obj if obj is not None else objtype
        return Factory(target)


class DataProxy:
    """Descriptor that lazily creates a container attribute on its owner.

    On first access, ``type()`` is instantiated and stored under ``name`` on
    the owner: the instance when accessed through one, otherwise the class.
    Later accesses return that stored object.

    NOTE: the existence check uses ``hasattr``, which also sees attributes
    inherited from base classes. A subclass accessed after its base already
    created the container therefore gets (and shares) the base's object.
    """

    def __init__(self, name, type):
        self.name = name
        self.type = type

    def __get__(self, obj, objtype=None):
        # Identity test, not truthiness: a falsy owner instance (one defining
        # __len__ or __bool__) must get its own container, not the class's.
        target = obj if obj is not None else objtype
        if not hasattr(target, self.name):
            setattr(target, self.name, self.type())
        return getattr(target, self.name)


class CloseHandler:
    """Closes mounted services.

    Objects are registered with one or more *reactions* (callables taking the
    object). Calling the handler runs every reaction on its object, then
    clears the injector's ``services`` cache. Objects are held in a
    ``WeakKeyDictionary``, so a registration vanishes with its object.
    """

    def __init__(self, injector):
        self.injector = injector
        self.registry = WeakKeyDictionary()

    def register(self, obj, reaction=None):
        """Register a callback to run on ``obj`` when the handler is called.

        Args:
            obj: object to react on; must support weak references.
            reaction: callable taking ``obj``; defaults to ``close_reaction``.
        """
        reaction = reaction or close_reaction
        self.registry.setdefault(obj, set()).add(reaction)

    def unregister(self, obj, reaction=None):
        """Unregister callbacks so they are not run on close.

        With ``reaction``, only that callback is removed, and ``obj`` is
        forgotten once it has none left. Without it, every callback for
        ``obj`` is dropped (a no-op if ``obj`` is unknown).

        Raises:
            KeyError: if ``reaction`` is given but not registered for ``obj``.
        """
        if not reaction:
            self.registry.pop(obj, None)
            return
        reactions = self.registry.get(obj)
        if reactions is None:
            # Fail without first inserting an empty entry for an unknown obj.
            raise KeyError(reaction)
        reactions.remove(reaction)
        if not reactions:
            self.registry.pop(obj, None)

    def __call__(self):
        """Run every registered reaction, then clear the injector's services."""
        # Snapshot first: a reaction may register/unregister objects, and
        # mutating a dict or set while iterating it raises RuntimeError.
        pending = [
            (obj, list(reactions)) for obj, reactions in self.registry.items()
        ]
        for obj, reactions in pending:
            for reaction in reactions:
                reaction(obj)
        self.injector.services.clear()


class Injector(metaclass=ABCMeta):
    """Collects dependencies and reads annotations to inject them.

    Factories are registered under string *notes* (through ``factory``) and
    built services are cached in ``services``. Each instance starts from
    copies of its class's ``services`` and ``factories`` mappings.
    """

    factory = FactoryMethod()
    services = DataProxy('_services', OrderedDict)
    factories = DataProxy('_factories', OrderedDict)

    def __init__(self):
        # Shallow copies: registrations made on this instance stay local, and
        # later class-level registrations are not seen by this instance.
        self.services = self.__class__.services.copy()
        self.factories = self.__class__.factories.copy()
        # NOTE(review): not read by any code visible in this module chunk.
        self.reactions = defaultdict(WeakKeyDictionary)
        self.close = CloseHandler(self)

    @cached_property
    def executor(self):
        """Thread pool (10 workers) used to run synchronous factories."""
        # NOTE(review): nothing visible here shuts this pool down on close.
        return concurrent.futures.ThreadPoolExecutor(max_workers=10)

    async def get(self, note):
        """Return the service for ``note``, building and caching it if needed.

        Candidate factory names are tried from the longest ``:``-separated
        prefix of ``note`` down to its first segment. The first registered one
        is called with the remaining segments as positional arguments.
        Coroutine functions are awaited; plain callables run in ``executor``.
        The result is cached in ``services`` under the full ``note``.

        Raises:
            ValueError: if no prefix of ``note`` has a registered factory.
        """
        if note in self.services:
            return self.services[note]

        for fact, args in note_loop(note):
            if fact not in self.factories:
                continue
            func = self.factories[fact]
            if asyncio.iscoroutinefunction(func):
                instance = await func(*args)
            else:
                loop = asyncio.get_running_loop()
                instance = await loop.run_in_executor(self.executor, func, *args)
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info('loaded service %s', note)
            self.services[note] = instance
            return instance
        raise ValueError('%r is not defined' % note)

    async def apply(self, *args, **kwargs):
        """Call ``args[0]`` with the remaining arguments, injecting services.

        Shorthand for ``await self.partial(func)(*rest, **kwargs)``.
        """
        func, *args = args
        return await self.partial(func)(*args, **kwargs)

    def partial(self, func):
        """Resolve ``func``'s annotated dependencies lazily, at call time.

        The returned coroutine function fills every parameter marked through
        ``annotate`` that the caller did not pass. Each distinct note is
        loaded once, concurrently, through ``get``. If ``func`` returns a
        coroutine, it is awaited whether or not ``func`` is annotated.

        Returns:
            callable: an async wrapper around ``func``.
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            if func in ANNOTATIONS:
                annotation = ANNOTATIONS[func]
                given = annotation.given(*args, **kwargs)
                # Marked parameters the caller left out -> their notes.
                missing = {
                    key: note
                    for key, note in annotation.marked.items()
                    if key not in given
                }
                # One get() per distinct note: two parameters sharing a note
                # must not build (and receive) two different instances.
                notes = list(dict.fromkeys(missing.values()))
                # gather() retrieves every child's outcome, so a failing load
                # does not leave sibling tasks with unretrieved exceptions.
                services = await asyncio.gather(*(self.get(note) for note in notes))
                loaded = dict(zip(notes, services))
                kwargs.update((key, loaded[note]) for key, note in missing.items())
            else:
                logger.warning('%r is not annoted', func)
            result = func(*args, **kwargs)
            # Await on both paths so callers always receive the final value.
            if asyncio.iscoroutine(result):
                result = await result
            return result

        return wrapper


class Annotation:
    """Binds dependency notes to the parameters of one function.

    ``pos_notes``/``kw_notes`` are bound against ``func``'s signature exactly
    like call arguments would be, so each marked parameter maps to its note.
    """

    def __init__(self, pos_notes, kw_notes, func):
        self.pos_notes, self.kw_notes = pos_notes, kw_notes
        self.bind_partial = signature(func).bind_partial

    @cached_property
    def marked(self):
        """Mapping of parameter name -> note, computed once."""
        bound = self.bind_partial(*self.pos_notes, **self.kw_notes)
        return bound.arguments

    def given(self, *args, **kwargs):
        """Names of the parameters filled by this call's arguments."""
        bound = self.bind_partial(*args, **kwargs)
        return [name for name in bound.arguments]


# Registry of functions decorated with ``annotate`` -> their Annotation.
# Keys are the decorated callables themselves (not strings), held weakly.
ANNOTATIONS: WeakKeyDictionary[object, Annotation] = WeakKeyDictionary()
170 171 172 173 174 175 176 177 178


def close_reaction(obj):
    """Default ``CloseHandler`` reaction: call ``obj.close()``, return None."""
    close = obj.close
    close()


def annotate(*args, **kwargs):
    """Decorator factory marking which parameters receive injected services.

    Positional and keyword *notes* are bound to the decorated function's
    parameters and recorded in ``ANNOTATIONS``; the function is returned
    unchanged.

    Raises:
        ValueError: if any note is not a ``str``.
    """
    notes = (*args, *kwargs.values())
    if not all(isinstance(note, str) for note in notes):
        raise ValueError('Notes must be strings')

    def decorate(func):
        ANNOTATIONS[func] = Annotation(args, kwargs, func)
        return func

    return decorate


def note_loop(note):
    """Yield ``(factory_name, args)`` candidates for ``note``, longest first.

    ``note`` is split on ``:``. Each leading run of segments is a candidate
    factory name, and the segments after it are its arguments. For
    ``'a:b:c'`` this yields ``('a:b:c', [])``, ``('a:b', ['c'])``, then
    ``('a', ['b', 'c'])``.
    """
    segments = note.split(':')
    candidates = [
        (':'.join(segments[:cut]), segments[cut:])
        for cut in range(1, len(segments) + 1)
    ]
    # Each candidate name is a strict prefix of the next, so the list is
    # already ascending; walking it backwards gives longest-first order.
    yield from reversed(candidates)