Commit 7c6bdd65 authored by xa's avatar xa

close sockets

parent 02fa3df4
from .client import *
from .http import *
from .lookup import *
from .tcp import *
__all__ = (http.__all__
__all__ = (client.__all__
+ http.__all__
+ lookup.__all__
+ tcp.__all__)
__version__ = '0.1'
import asyncio
__all__ = ['NSQ']
class NSQ:
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
@asyncio.coroutine
def publish(self, topic, message):
raise NotImplementedError()
@asyncio.coroutine
def multi_publish(self, topic, messages):
raise NotImplementedError()
def register(self, topic, channel):
raise NotImplementedError()
from .client import NSQClient
from .requests import RequestHandler
__all__ = ['NSQClient', 'RequestHandler']
__all__ = ['NSQClient']
......@@ -3,11 +3,31 @@ import struct
from .requests import RequestHandler, ok, exc, unwrapped_json
class NSQClient:
class HTTPHandler:
def __init__(self, addr, *, loop=None):
self.req_handler = RequestHandler(addr)
self.loop = loop or asyncio.get_event_loop()
self.req_handler = RequestHandler(addr, loop=self.loop)
self._closed = asyncio.Event(loop=self.loop)
def close(self):
if not self._closed.is_set():
self.req_handler.close()
self._closed.set()
@asyncio.coroutine
def wait_closed(self):
yield from self._closed.wait()
return True
def __repr__(self):
return '<%s(addr=%r)>' % (
self.__class__.__name__,
self.req_handler.addr
)
class NSQClient(HTTPHandler):
@asyncio.coroutine
def stats(self):
......
import asyncio
import atexit
from aiohttp import ClientSession, TCPConnector
from .exceptions import TopicNotFound, ChannelNotFound
from .exceptions import HTTPInternalServerError, HTTPNotFound
......@@ -9,11 +8,12 @@ __all__ = ['RequestHandler']
class RequestHandler:
def __init__(self, addr):
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
connector = TCPConnector(verify_ssl=False)
self.session = session = ClientSession(connector=connector)
atexit.register(session.close)
self.session = ClientSession(connector=connector, loop=self.loop)
self._closed = False
@asyncio.coroutine
def request(self, method, path, **kwargs):
......@@ -21,6 +21,12 @@ class RequestHandler:
response = yield from self.session.request(method, url, **kwargs)
return response
def close(self):
if not self._closed:
self.session.close()
self.session = None
self._closed = True
__call__ = request
......
import asyncio
from aionsq.http.requests import RequestHandler, unwrapped_json, ok
from aionsq.http.client import HTTPHandler
from aionsq.http.requests import unwrapped_json, ok
__all__ = ['NSQLookup']
class NSQLookup:
def __init__(self, addr):
self.req_handler = RequestHandler(addr)
class NSQLookup(HTTPHandler):
@asyncio.coroutine
def lookup(self, *, topic):
......@@ -115,9 +113,3 @@ class NSQLookup:
response = yield from self.req_handler('GET', path)
result = yield from unwrapped_json(response)
return result
def __repr__(self):
return '<%s(addr=%r)>' % (
self.__class__.__name__,
self.req_handler.addr
)
......@@ -5,7 +5,7 @@ from .exceptions import NotStartedError
from weakref import WeakKeyDictionary, WeakSet
class NSQClient:
class TCPHandler:
def __init__(self, addr, *, loop=None):
......@@ -13,6 +13,7 @@ class NSQClient:
self.loop = loop or asyncio.get_event_loop()
self._started = False
self._connection = NSQConnection(self.addr)
self._openened = asyncio.Lock(loop=self.loop)
@property
def started(self):
......@@ -21,6 +22,7 @@ class NSQClient:
@asyncio.coroutine
def start(self):
if not self._started:
yield from self._openened.acquire()
yield from self._connection.start()
yield from self._connection.identify()
yield from self.start_warmup()
......@@ -35,6 +37,7 @@ class NSQClient:
def close_connection(fut):
self._connection.close()
self._openened.release()
if self._started:
task = self.loop.create_task(self.pre_close())
......@@ -42,6 +45,11 @@ class NSQClient:
task.add_done_callback(close_connection)
self._started = False
@asyncio.coroutine
def wait_closed(self):
with (yield from self._openened):
return True
@asyncio.coroutine
def pre_close(self):
return None
......@@ -53,25 +61,37 @@ class NSQClient:
self.log.exception(error)
class NSQWriter(NSQClient):
class NSQWriter(TCPHandler):
@asyncio.coroutine
def publish(self, topic, data):
def publish(self, topic, message):
"""Publish a message.
Parameters:
topic (str): the topic to publish to
message (str): raw message bytes
"""
if not self._started:
raise NotStartedError()
response = yield from self._connection.pub(topic, data)
response = yield from self._connection.pub(topic, message)
return response == 'OK'
@asyncio.coroutine
def multi_publish(self, topic, data):
assert isinstance(data, (list, set))
def multi_publish(self, topic, messages):
"""Publish multiple messages in one roundtrip.
Parameters:
topic (str): the topic to publish to
messages (list): raw message bytes
"""
assert isinstance(messages, (list, set))
if not self._started:
raise NotStartedError()
response = yield from self._connection.mpub(topic, data)
response = yield from self._connection.mpub(topic, messages)
return response == 'OK'
class NSQReader(NSQClient):
class NSQReader(TCPHandler):
def __init__(self, addr, topic, channel, *, loop=None):
super().__init__(addr, loop=loop)
......@@ -101,14 +121,24 @@ class NSQReader(NSQClient):
yield from self._connection.sub(self.topic, self.channel)
yield from self.ask_messages(1)
def register(self, callback):
cb = asyncio.coroutine(callback)
self._consumers.setdefault(callback, cb)
return callback
def register(self, message_handler):
"""Register an handler that will be executed for each message received
Parameters:
message_handler (callable): the handler to add
"""
cb = asyncio.coroutine(message_handler)
self._consumers.setdefault(message_handler, cb)
return message_handler
def unregister(self, callback):
self._consumers.pop(callback, None)
return callback
def unregister(self, message_handler):
"""Unregister the handler.
Parameters:
message_handler (callable): the handler to remove
"""
self._consumers.pop(message_handler, None)
return message_handler
def import_message(self, frame):
msg = NSQMessage(self._connection, frame)
......
# debug asyncio
import os
import logging
import warnings
os.environ['PYTHONASYNCIODEBUG'] = '1'
logging.basicConfig(level=logging.DEBUG)
warnings.simplefilter("always")
......@@ -8,6 +8,8 @@ def test_ping():
client = NSQClient('http://127.0.0.1:4151')
data = yield from client.ping()
assert data is True
client.close()
yield from client.wait_closed()
@pytest.mark.asyncio
......
......@@ -8,6 +8,8 @@ def test_ping():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.ping()
assert data is True
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -15,6 +17,8 @@ def test_info():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.info()
assert data == {'version': '0.3.5'}
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -22,6 +26,8 @@ def test_lookup_ok():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.lookup(topic='test')
lookup_validator.validate(data)
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -29,6 +35,8 @@ def test_lookup_ko():
lookup = NSQLookup('http://127.0.0.1:4161')
with pytest.raises(HTTPInternalServerError):
yield from lookup.lookup(topic='foo')
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -36,6 +44,8 @@ def test_topics():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.topics()
topics_validator.validate(data)
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -43,6 +53,8 @@ def test_channels_ok():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.channels(topic='test')
assert data == {'channels': []}
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -50,6 +62,8 @@ def test_channels_ko():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.channels(topic='foo')
assert data == {'channels': []}
lookup.close()
yield from lookup.wait_closed()
@pytest.mark.asyncio
......@@ -57,6 +71,8 @@ def test_nodes():
lookup = NSQLookup('http://127.0.0.1:4161')
data = yield from lookup.nodes()
nodes_validator.validate(data)
lookup.close()
yield from lookup.wait_closed()
from jsonspec.validators import load
......
import asyncio
import pytest
from aionsq.tcp import NSQClient, NSQReader, NSQWriter
from aionsq.tcp import TCPHandler, NSQReader, NSQWriter
from aionsq.tcp import exceptions
import logging
......@@ -9,11 +9,12 @@ logging.basicConfig(level=logging.DEBUG)
@pytest.mark.asyncio
def test_client(event_loop):
client = NSQClient('tcp://127.0.0.1:4150', loop=event_loop)
client = TCPHandler('tcp://127.0.0.1:4150', loop=event_loop)
assert not client.started
yield from client.start()
assert client.started
client.close()
yield from client.wait_closed()
assert not client.started
......@@ -29,6 +30,7 @@ def test_publish(event_loop):
assert result is True
client.close()
yield from client.wait_closed()
@pytest.mark.asyncio
......@@ -43,6 +45,7 @@ def test_multi_publish(event_loop):
assert result is True
client.close()
yield from client.wait_closed()
@pytest.mark.asyncio
......@@ -57,10 +60,11 @@ def test_subscribe(event_loop):
yield from asyncio.sleep(1)
client.close()
yield from client.wait_closed()
def test_consumers_1(event_loop):
client = NSQReader(':4150', 'topic1', 'chan1', loop=event_loop)
def test_consumers_1():
client = NSQReader(':4150', 'topic1', 'chan1')
assert not client.consumers
......@@ -78,8 +82,8 @@ def test_consumers_1(event_loop):
assert consumer1 in client.consumers
def test_consumers_2(event_loop):
client = NSQReader(':4150', 'topic1', 'chan1', loop=event_loop)
def test_consumers_2():
client = NSQReader(':4150', 'topic1', 'chan1')
def consumer1(msg):
return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment