Commit 02fa3df4 authored by xa's avatar xa

refactored TCP reader / writer

parent 8638ca49
......@@ -15,19 +15,26 @@ Installation
Usage
-----
Publish message to a nsq:
.. code-block:: python
from aionsq.http import NSQWriter
writer = NSQWriter('http://120.0.0.1:4567')
yield from writer.publish('test', 'hello world 2')
Connect directly to a nsq daemon:
.. code-block:: python
from aionsq.tcp import NSQClient
nsq = NSQClient('http://120.0.0.1:4567')
chan = client.subscribe('topic1', 'chan1')
from aionsq.tcp import NSQReader
reader = NSQReader('http://120.0.0.1:4567', 'topic1', 'chan1')
@chan.subscribe('test')
@reader.subscribe('test')
def consumer(msg):
msg.success()
yield from chan.start()
yield from reader.start()
Request nsq lookup:
......@@ -37,12 +44,4 @@ Request nsq lookup:
lookup = NSQLookup('http://120.0.0.1:4567')
info = yield from lookup.info()
Publish message to a nsq:
.. code-block:: python
from aionsq.http import NSQClient
client = NSQClient('http://120.0.0.1:4567')
yield from client.publish('test', 'hello world 2')
.. _nsq: http://nsq.io
......@@ -2,4 +2,4 @@ from .client import *
from .connection import *
from .protocol import *
__all__ = []
__all__ = ['NSQReader', 'NSQWriter']
......@@ -2,6 +2,7 @@ import asyncio
import logging
from .connection import NSQConnection
from .exceptions import NotStartedError
from weakref import WeakKeyDictionary, WeakSet
class NSQClient:
......@@ -11,7 +12,6 @@ class NSQClient:
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._started = False
self._channels = {}
self._connection = NSQConnection(self.addr)
@property
......@@ -23,9 +23,38 @@ class NSQClient:
if not self._started:
yield from self._connection.start()
yield from self._connection.identify()
yield from self.start_warmup()
self._started = True
return True
@asyncio.coroutine
def start_warmup(self):
pass
def close(self):
def close_connection(fut):
self._connection.close()
if self._started:
task = self.loop.create_task(self.pre_close())
task.add_done_callback(self._report_errors)
task.add_done_callback(close_connection)
self._started = False
@asyncio.coroutine
def pre_close(self):
return None
def _report_errors(self, fut):
try:
fut.result()
except Exception as error:
self.log.exception(error)
class NSQWriter(NSQClient):
@asyncio.coroutine
def publish(self, topic, data):
if not self._started:
......@@ -41,101 +70,70 @@ class NSQClient:
response = yield from self._connection.mpub(topic, data)
return response == 'OK'
def subscribe(self, topic, channel):
k = topic, channel
if k not in self._channels:
channel = NSQChannel(self.addr, topic, channel, loop=self.loop)
self._channels[k] = channel
return channel
def close(self):
if self._started:
self._connection.close()
self._started = False
def close_all(self):
for channel in self._channels.values():
channel.close()
self.close()
class NSQChannel:
class NSQReader(NSQClient):
def __init__(self, addr, topic, channel, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
super().__init__(addr, loop=loop)
self.topic = topic
self.channel = channel
self._started = False
self._connection = NSQConnection(self.addr)
self.triggers = set()
self._consumers = WeakKeyDictionary()
self._msg_counter = 0
self._msg_rdy = 0
self._msgs = WeakSet()
self.log = logging.getLogger(__name__)
def register(self, callback):
callback = asyncio.coroutine(callback)
self.triggers.add(callback)
return callback
@property
def messages(self):
"""Returns pending messages
"""
return set(self._msgs)
@property
def started(self):
return self._started
def consumers(self):
"""Returns consumers
"""
return set(self._consumers.keys())
@asyncio.coroutine
def start(self):
if not self._started:
yield from self._connection.start()
self._connection.add_message_listener(self.import_message)
yield from self._connection.sub(self.topic, self.channel)
yield from self.ask_messages(1)
self._started = True
return True
def start_warmup(self):
self._connection.add_message_listener(self.import_message)
yield from self._connection.sub(self.topic, self.channel)
yield from self.ask_messages(1)
def register(self, callback):
cb = asyncio.coroutine(callback)
self._consumers.setdefault(callback, cb)
return callback
def unregister(self, callback):
self._consumers.pop(callback, None)
return callback
def import_message(self, frame):
msg = Message(self._connection, frame)
msg = NSQMessage(self._connection, frame)
self.log.info('msg %s', msg)
for trigger in self.triggers:
fut = self.loop.create_task(trigger(msg))
self._msgs.add(msg)
for consumer in self._consumers.values():
fut = self.loop.create_task(consumer(msg))
fut.add_done_callback(self._report_errors)
self._msg_rdy -= 1
if self._msg_rdy <= 0:
self.loop.create_task(self.ask_messages(10))
def _report_errors(self, fut):
try:
fut.result()
except Exception as error:
self.log.exception(error)
@asyncio.coroutine
def ask_messages(self, count):
yield from self._connection.rdy(count)
self._msg_rdy = count
def close(self):
@asyncio.coroutine
def stop_listening():
response = yield from self._connection.cls()
return response == 'CLOSE_WAIT'
def close_connection(fut):
try:
fut.result()
except Exception as error:
self.log.exception(error)
self._connection.close()
if self._started:
task = self.loop.create_task(stop_listening)
task.add_done_callback(close_connection)
self._started = False
@asyncio.coroutine
def pre_close(self):
response = yield from self._connection.cls()
return response == 'CLOSE_WAIT'
class Message:
class NSQMessage:
def __init__(self, connection, frame):
self._connection = connection
......
......@@ -41,6 +41,11 @@ class NSQConnection:
self.protocol.send_cmd('NOP')
return
if frame.type == 1 and frame.error in ('E_FIN_FAILED',
'E_REQ_FAILED',
'E_TOUCH_FAILED'):
self.log.warn('failure %s', frame)
return
if frame.type in (0, 1):
self.loop.create_task(self.pending_responses.put(frame))
elif frame.type == 2:
......
......@@ -111,7 +111,7 @@ def format_body_lists(obj):
RawResponse = namedtuple('RawResponse', 'type body')
RawError = namedtuple('RawError', 'type body')
RawError = namedtuple('RawError', 'type error body')
RawMessage = namedtuple('RawMessage', 'type body timestamp attempts id')
......@@ -121,7 +121,8 @@ def parse_protocol(data):
if frame_type == 0:
return RawResponse(frame_type, data[4:].decode('utf-8'))
elif frame_type == 1:
return RawError(frame_type, data[4:].decode('utf-8'))
error, body = data[4:].decode('utf-8').split(None, 1)
return RawError(frame_type, error, body)
elif frame_type == 2:
timestamp = struct.unpack('>q', data[4:12])[0]
attempts = struct.unpack('>h', data[12:14])[0]
......
import asyncio
import pytest
from aionsq.tcp import NSQClient
from aionsq.tcp import NSQClient, NSQReader, NSQWriter
from aionsq.tcp import exceptions
import logging
......@@ -19,7 +19,7 @@ def test_client(event_loop):
@pytest.mark.asyncio
def test_publish(event_loop):
client = NSQClient('tcp://127.0.0.1:4150', loop=event_loop)
client = NSQWriter('tcp://127.0.0.1:4150', loop=event_loop)
with pytest.raises(exceptions.NotStartedError):
yield from client.publish('topic1', 'msg1')
......@@ -33,7 +33,7 @@ def test_publish(event_loop):
@pytest.mark.asyncio
def test_multi_publish(event_loop):
client = NSQClient('tcp://127.0.0.1:4150', loop=event_loop)
client = NSQWriter(':4150', loop=event_loop)
with pytest.raises(exceptions.NotStartedError):
yield from client.multi_publish('topic1', ['msg1'])
......@@ -47,13 +47,51 @@ def test_multi_publish(event_loop):
@pytest.mark.asyncio
def test_subscribe(event_loop):
client = NSQClient('tcp://127.0.0.1:4150', loop=event_loop)
client = NSQReader(':4150', 'topic1', 'chan1', loop=event_loop)
chan = client.subscribe('topic1', 'chan1')
@chan.register
@client.register
def listener(msg):
yield from msg.success()
yield from chan.start()
yield from client.start()
yield from asyncio.sleep(1)
client.close()
def test_consumers_1(event_loop):
client = NSQReader(':4150', 'topic1', 'chan1', loop=event_loop)
assert not client.consumers
@client.register
def consumer1(msg):
return
assert consumer1 in client.consumers
client.unregister(consumer1)
assert consumer1 not in client.consumers
assert not client.consumers
client.register(consumer1)
assert consumer1 in client.consumers
def test_consumers_2(event_loop):
client = NSQReader(':4150', 'topic1', 'chan1', loop=event_loop)
def consumer1(msg):
return
assert not client.consumers
client.register(consumer1)
assert consumer1 in client.consumers
client.unregister(consumer1)
assert consumer1 not in client.consumers
assert not client.consumers
client.register(consumer1)
assert consumer1 in client.consumers
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment