Commit 0512f0ec authored by Xavier Barbosa's avatar Xavier Barbosa

implement tcp protocol

parent 6b631ce3
......@@ -19,28 +19,15 @@ Connect directly to a nsq daemon:
.. code-block:: python
from aionsq import NSQConsumer
nsq = NSQConsumer('tcp://120.0.0.1:4567')
from aionsq.tcp import NSQClient
nsq = NSQClient(lookup='http://120.0.0.1:4567')
chan = client.subscribe('topic1', 'chan1')
@nsq.subscribe('test')
@chan.subscribe('test')
def consumer(msg):
pass
msg.success()
yield from nsq.start()
Or thru a nsq lookup:
.. code-block:: python
from aionsq import NSQConsumer
nsq = NSQConsumer(lookup='http://120.0.0.1:4567')
@nsq.subscribe('test')
def consumer(msg):
pass
yield from nsq.start()
yield from chan.start()
Request nsq lookup:
......@@ -54,7 +41,7 @@ Publish message to a nsq:
.. code-block:: python
from aionsq import NSQClient
from aionsq.http import NSQClient
client = NSQClient('http://120.0.0.1:4567')
yield from client.publish('test', 'hello world 2')
......
from .client import *
from .connection import *
from .http import *
from .lookup import *
from .protocols import *
from .tcp import *
__all__ = (client.__all__
+ connection.__all__
__all__ = (http.__all__
+ lookup.__all__
+ protocols.__all__)
+ tcp.__all__)
__version__ = '0.1'
import asyncio
__all__ = ['connect']
@asyncio.coroutine
def connect(addr):
pass
from .client import NSQClient
from .requests import RequestHandler
__all__ = ['NSQClient', 'RequestHandler']
import asyncio
import struct
from .requests import RequestHandler, ok, exc, unwrapped_json
from collections import defaultdict
__all__ = ['NSQConsumer', 'NSQClient']
class NSQConsumer:
def __init__(self, addr=None, lookup=None, *, loop=None):
if addr and lookup:
raise ValueError('addr and lookup are mutually exclusive')
self.req_handler = RequestHandler(addr)
self.lookup = lookup
self.loop = loop or asyncio.get_event_loop()
self.consumers = defaultdict(set)
def subscribe(self, topic, *, func=None):
"""Subcribe to a topic. can be used a decorator.
"""
if func is None:
return lambda x: self.subscribe(topic, func=x)
self.consumers[topic].add(func)
return func
def unsubscribe(self, topic, *, func):
"""Subcribe to a topic. can be used a decorator.
"""
self.consumers[topic].remove(func)
return func
@asyncio.coroutine
def start(self):
raise NotImplementedError()
@asyncio.coroutine
def stop(self):
raise NotImplementedError()
class NSQClient:
......@@ -79,7 +43,7 @@ class NSQClient:
return result
@asyncio.coroutine
def pub(self, topic, message):
def publish(self, topic, message):
"""Publish a message.
Parameters:
......@@ -95,7 +59,7 @@ class NSQClient:
return result
@asyncio.coroutine
def mpub(self, topic, messages):
def multi_publish(self, topic, messages):
"""Publish multiple messages in one roundtrip.
Parameters:
......
import asyncio
import atexit
from aiohttp import ClientSession, TCPConnector
from aionsq.exceptions import TopicNotFound, ChannelNotFound
from aionsq.exceptions import HTTPInternalServerError, HTTPNotFound
from .exceptions import TopicNotFound, ChannelNotFound
from .exceptions import HTTPInternalServerError, HTTPNotFound
__all__ = ['RequestHandler']
......@@ -52,7 +52,6 @@ def ok(response):
@asyncio.coroutine
def unwrapped_json(response):
print(response.headers)
if response.status == 200:
result = yield from response.json()
return result['data']
......
import asyncio
from .requests import RequestHandler, unwrapped_json, ok
from aionsq.http.requests import RequestHandler, unwrapped_json, ok
__all__ = ['NSQLookup']
......
import asyncio
__all__ = ['NSQProtocol']
class NSQProtocol(asyncio.Protocol):
def __init__(self, message, loop):
self.message = message
self.loop = loop
def connection_made(self, transport):
transport.write(self.message.encode())
print('Data sent: {!r}'.format(self.message))
def data_received(self, data):
print('Data received: {!r}'.format(data.decode()))
def connection_lost(self, exc):
print('The server closed the connection')
print('Stop the event lop')
self.loop.stop()
from .client import *
from .connection import *
from .protocol import *
__all__ = []
import asyncio
import logging
from .connection import NSQConnection
from .exceptions import NotStartedError
class NSQClient:
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._started = False
self._channels = {}
self._connection = NSQConnection(self.addr)
@property
def started(self):
return self._started
@asyncio.coroutine
def start(self):
if not self._started:
yield from self._connection.start()
yield from self._connection.identify()
self._started = True
return True
@asyncio.coroutine
def publish(self, topic, data):
if not self._started:
raise NotStartedError()
response = yield from self._connection.pub(topic, data)
return response == 'OK'
@asyncio.coroutine
def multi_publish(self, topic, data):
assert isinstance(data, (list, set))
if not self._started:
raise NotStartedError()
response = yield from self._connection.mpub(topic, data)
return response == 'OK'
def subscribe(self, topic, channel):
k = topic, channel
if k not in self._channels:
channel = NSQChannel(self.addr, topic, channel, loop=self.loop)
self._channels[k] = channel
return channel
def close(self):
if self._started:
self._connection.close()
self._started = False
def close_all(self):
for channel in self._channels.values():
channel.close()
self.close()
class NSQChannel:
def __init__(self, addr, topic, channel, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self.topic = topic
self.channel = channel
self._started = False
self._connection = NSQConnection(self.addr)
self.triggers = set()
self._msg_counter = 0
self._msg_rdy = 0
self.log = logging.getLogger(__name__)
def register(self, callback):
callback = asyncio.coroutine(callback)
self.triggers.add(callback)
return callback
@property
def started(self):
return self._started
@asyncio.coroutine
def start(self):
if not self._started:
yield from self._connection.start()
self._connection.add_message_listener(self.import_message)
yield from self._connection.sub(self.topic, self.channel)
yield from self.ask_messages(1)
self._started = True
return True
def import_message(self, frame):
msg = Message(self._connection, frame)
self.log.info('msg %s', msg)
for trigger in self.triggers:
fut = self.loop.create_task(trigger(msg))
fut.add_done_callback(self._report_errors)
self._msg_rdy -= 1
if self._msg_rdy <= 0:
self.loop.create_task(self.ask_messages(10))
def _report_errors(self, fut):
try:
fut.result()
except Exception as error:
self.log.exception(error)
@asyncio.coroutine
def ask_messages(self, count):
yield from self._connection.rdy(count)
self._msg_rdy = count
def close(self):
@asyncio.coroutine
def stop_listening():
response = yield from self._connection.cls()
return response == 'CLOSE_WAIT'
def close_connection(fut):
try:
fut.result()
except Exception as error:
self.log.exception(error)
self._connection.close()
if self._started:
task = self.loop.create_task(stop_listening)
task.add_done_callback(close_connection)
self._started = False
class Message:
def __init__(self, connection, frame):
self._connection = connection
self.frame = frame
@property
def body(self):
return self.frame.body
@property
def timestamp(self):
return int(self.frame.timestamp)
@property
def attempts(self):
return int(self.frame.attempts)
@property
def id(self):
return self.frame.id
@asyncio.coroutine
def success(self):
yield from self._connection.fin(self.id)
@asyncio.coroutine
def error(self):
yield from self._connection.req(self.id, 0)
import asyncio
import json
import socket
from .protocol import NSQProtocol
from aionsq.util import parse_addr
@asyncio.coroutine
def connect(addr, timeout=None, *, loop=None):
addr = parse_addr(addr, proto='tcp')
loop = loop or asyncio.get_event_loop()
task = loop.create_connection(NSQConnectionFactory(), *addr)
try:
transport, protocol = yield from asyncio.wait_for(task, timeout or 10.0)
transport.write(b' V2')
return transport, protocol
except asyncio.TimeoutError as error:
raise Exception('Timeout exceeded') from error
class NSQConnection:
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._started = False
self.pending_responses = asyncio.Queue()
self.message_listeners = set()
@asyncio.coroutine
def start(self):
transport, protocol = yield from connect(self.addr, loop=self.loop)
self.transport = transport
self.protocol = protocol
self.protocol.add_listener(self.import_frame)
self._started = True
def import_frame(self, frame):
if frame.type == 0 and frame.body == '_heartbeat_':
self.protocol.send_cmd('NOP')
return
if frame.type in (0, 1):
self.loop.create_task(self.pending_responses.put(frame))
elif frame.type == 2:
for listener in self.message_listeners:
listener(frame)
@asyncio.coroutine
def identify(self, **opts):
"""Identify to the NSQ.
"""
from aionsq import __version__
opts.setdefault('user_agent', 'aionsq/%s' % __version__)
opts.setdefault('client_id', socket.getfqdn())
opts.setdefault('hostname', socket.gethostname())
opts.setdefault('feature_negotiation', True)
data = json.dumps(opts)
self.protocol.send_cmd('IDENTIFY', body=data)
response = yield from self.next_response()
return response
@asyncio.coroutine
def pub(self, topic, data):
"""Publish a message to a topic
Parameters:
topic (str): a valid string (optionally having #ephemeral suffix)
message (str): raw message bytes
"""
self.protocol.send_cmd('PUB', topic, body=data)
response = yield from self.next_response()
return response
@asyncio.coroutine
def mpub(self, topic, data):
"""Publish multiple messages to a topic (atomically)
Parameters:
topic (str): a valid string (optionally having #ephemeral suffix)
messages (str): raw messages bytes
"""
self.protocol.send_cmd('MPUB', topic, body=data)
response = yield from self.next_response()
return response
@asyncio.coroutine
def sub(self, topic, channel):
"""Subscribe to a topic/channel
Parameters:
topic (str): a valid string (optionally having #ephemeral suffix)
channel (str): a valid string (optionally having #ephemeral suffix)
"""
self.protocol.send_cmd('SUB', topic, channel)
response = yield from self.next_response()
return response
@asyncio.coroutine
def rdy(self, count):
"""Update RDY state (indicate you are ready to receive N messages)
Parameters:
count (int): where 0 < count <= configured_max
"""
self.protocol.send_cmd('RDY', count)
@asyncio.coroutine
def fin(self, message_id):
"""Finish a message (indicate successful processing).
Parameters:
message_id (str): message id as 16-byte hex string
"""
self.protocol.send_cmd('FIN', message_id)
@asyncio.coroutine
def req(self, message_id, timeout):
"""Re-queue a message (indicate failure to process).
Parameters:
message_id (str): message id as 16-byte hex string
timeout (int): N where N <= configured max timeout
0 is a special case that will not defer re-queueing
"""
self.protocol.send_cmd('REQ', message_id, timeout)
@asyncio.coroutine
def touch(self, message_id):
"""Reset the timeout for an in-flight message.
Parameters:
message_id (str): message id as 16-byte hex string
"""
self.protocol.send_cmd('TOUCH', message_id)
@asyncio.coroutine
def nop(self):
"""No-op
"""
self.protocol.send_cmd('NOP')
@asyncio.coroutine
def next_response(self):
response = yield from self.pending_responses.get()
if response.type == 0:
return response.body
elif response.type == 1:
raise Exception(response.body)
def add_message_listener(self, callback):
self.message_listeners.add(callback)
def close(self):
if self._started:
self.transport.close()
self.transport = None
self.protocol = None
self._started = False
class NSQConnectionFactory:
def __call__(self):
return NSQProtocol()
class ClientError(Exception):
pass
class NotStartedError(ClientError):
pass
class MessageError(ValueError):
pass
import asyncio
import logging
import struct
from .exceptions import MessageError
from collections import namedtuple
from functools import singledispatch
class NSQProtocol(asyncio.Protocol):
def __init__(self):
self.listeners = set()
self.buf = bytearray()
self.log = logging.getLogger(__name__)
def connection_made(self, transport):
self.transport = transport
def connection_lost(self, exc):
self.transport = None
def data_received(self, data):
self.buf.extend(data)
self.process_buffer()
def process_buffer(self):
while len(self.buf) > 4:
length, *_ = struct.unpack('>l', self.buf[:4])
part = length + 4
if len(self.buf) < part:
break
data, self.buf[0:part] = self.buf[4:part], []
frame = parse_protocol(data)
self.log.info('got %s', frame)
self.notify_listeners(frame)
def notify_listeners(self, frame):
for listener in self.listeners:
listener(frame)
def send_cmd(self, cmd, *params, body=None):
msg = prepare_cmd(cmd, *params, body=body)
self.log.info('send %s', msg)
self.transport.write(msg)
def add_listener(self, callback):
self.listeners.add(callback)
def extend(output, data):
if not isinstance(data, (bytes, bytearray)):
data = str(data).encode('utf-8')
output.extend(data)
def prepare_cmd(cmd, *params, body=None):
output = bytearray()
extend(output, cmd)
for param in params:
extend(output, ' %s' % param)
extend(output, '\n')
if body is not None:
extend(output, format_body(body))
return output
@singledispatch
def format_body(obj):
msg = "Don't know how to handle this %r" % obj.__class__.__name__
raise NotImplementedError(msg)
@format_body.register(int)
@format_body.register(float)
def format_body_number(obj):
obj = str(obj).encode('utf-8')
out = struct.pack('>l', len(obj))
out += obj
return out
@format_body.register(str)
def format_body_str(obj):
out = bytearray()
out.extend(struct.pack('>l', len(obj)))
out.extend(obj.encode('utf-8'))
return out
@format_body.register(bytes)
def format_body_bytes(obj):
out = bytearray()
out.extend(struct.pack('>l', len(obj)))
out.extend(obj)
return out
@format_body.register(list)
@format_body.register(set)
def format_body_lists(obj):
body = bytearray()
for o in obj:
body.extend(format_body(o))
out = bytearray()
out.extend(struct.pack('>l', len(body)))
out.extend(struct.pack('>l', len(obj)))
out.extend(body)
return out
RawResponse = namedtuple('RawResponse', 'type body')
RawError = namedtuple('RawError', 'type body')
RawMessage = namedtuple('RawMessage', 'type body timestamp attempts id')
def parse_protocol(data):
frame_type, *_ = struct.unpack('>l', data[:4])
try:
if frame_type == 0:
return RawResponse(frame_type, data[4:].decode('utf-8'))
elif frame_type == 1:
return RawError(frame_type, data[4:].decode('utf-8'))
elif frame_type == 2:
timestamp = struct.unpack('>q', data[4:12])[0]
attempts = struct.unpack('>h', data[12:14])[0]
message_id = data[14:30].decode('utf-8')
body = data[30:].decode('utf-8')
return RawMessage(frame_type, body, timestamp, attempts, message_id)
except UnicodeDecodeError as error:
raise MessageError(data) from error
class Address(tuple):
"""Defines what is a net address.
"""
def __new__(cls, proto, host, port):
return super().__new__(cls, (host, port))
def __init__(self, proto, host, port):
self.proto = proto
self.host = host
self.port = port
def __hash__(self):
return id((self.proto, self.host, self.port))
def __eq__(self, other):
return other == (self.proto, self.host, self.port)
def __str__(self):
return '%s://%s:%s' % (self.proto, self.host, self.port)
def parse_addr(addr, *, proto=None, host=None):
"""Parses an address;
Returns:
Address: the parsed address
"""
port = None
if isinstance(addr, Address):
return addr
elif isinstance(addr, str):
if addr.startswith('udp://'):
proto, addr = 'udp', addr[6:]
elif addr.startswith('tcp://'):
proto, addr = 'tcp', addr[6:]
elif addr.startswith('unix://'):
proto, addr = 'unix', addr[7:]
a, _, b = addr.partition(':')
host = a or host
port = b or port
elif isinstance(addr, (tuple, list)):
# list is not good
a, b = addr
host = a or host
port = b or port