Commit b5833bc5 authored by xa's avatar xa

play with layers

parent 7c6bdd65
import asyncio
from aionsq.http import NSQClient
from aionsq.lookup import NSQLookup
from aionsq.tcp import NSQReader, NSQWriter
from aionsq.util import parse_addr
__all__ = ['NSQ']
class NSQ:
def __init__(self, addr, *, loop=None):
self.addr = addr
def __init__(self, addr=None, lookup=None, *, loop=None):
self.loop = loop or asyncio.get_event_loop()
if addr and lookup:
raise ValueError('mutually exclusive addr and lookup')
if addr:
addr = parse_addr(addr)
if addr.proto == 'tcp':
self.adapter = TCPAdapter(addr, loop=self.loop)
elif addr.proto == 'http':
self.adapter = HTTPAdapter(addr, loop=self.loop)
else:
raise ValueError('unkown proto %s' % addr.proto)
elif lookup:
self.adapter = LookupAdapter(lookup, loop=self.loop)
else:
raise ValueError('addr or lookup required')
@asyncio.coroutine
def publish(self, topic, message):
raise NotImplementedError()
"""Publish a message.
Parameters:
topic (str): the topic to publish to
message (str): raw message bytes
"""
client = yield from self.adapter.writer()
response = yield from client.publish(topic, message)
return response
@asyncio.coroutine
def multi_publish(self, topic, messages):
"""Publish multiple messages in one roundtrip.
Parameters:
topic (str): the topic to publish to
messages (list): raw message bytes
"""
client = yield from self.adapter.writer()
response = yield from client.multi_publish(topic, messages)
return response
def register(self, topic, channel, *, handler=None):
"""Register an handler that will be executed for each message received
Parameters:
handler (callable): the handler to add
"""
if not handler:
return lambda x: self.register(topic, channel, handler=x)
client = yield from self.adapter.reader(topic, channel)
return client.register(handler)
def close(self):
self.adapter.close()
@asyncio.coroutine
def wait_closed(self):
yield from self.adapter.wait_closed()
class Adapter:
@asyncio.coroutine
def reader(self, topic, channel):
raise NotImplementedError()
def register(self, topic, channel):
@asyncio.coroutine
def writer(self):
raise NotImplementedError()
def close(self):
raise NotImplementedError()
@asyncio.coroutine
def wait_closed(self):
raise NotImplementedError()
class HTTPAdapter(Adapter):
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._writer = None
@asyncio.coroutine
def reader(self, topic, channel):
raise AttributeError('does not implements this')
@asyncio.coroutine
def writer(self):
if not self._writer:
self._writer = NSQClient(self.addr, loop=self.loop)
return self._writer
def close(self):
if self._writer:
self._writer.close()
@asyncio.coroutine
def wait_closed(self):
if self._writer:
yield from self._writer.wait_closed()
class LookupAdapter(Adapter):
def __init__(self, addr, *, loop=None):
self.loop = loop or asyncio.get_event_loop()
self.client = NSQLookup(addr, loop=self.loop)
self._writer = None
@asyncio.coroutine
def reader(self, topic, channel):
raise Exception()
@asyncio.coroutine
def writer(self):
if not self._writer:
nodes = yield from self.client.nodes()
node = nodes['producers'][0]
addr = 'tcp://%s:%s' % (node['broadcast_address'], node['tcp_port'])
writer = NSQWriter(addr, loop=self.loop)
yield from writer.start()
self._writer = writer
return self._writer
def close(self):
if self._writer:
self._writer.close()
@asyncio.coroutine
def wait_closed(self):
if self._writer:
yield from self._writer.wait_closed()
class TCPAdapter(Adapter):
def __init__(self, addr, *, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._readers = {}
self._writer = None
@asyncio.coroutine
def reader(self, topic, channel):
key = topic, channel
if key not in self._readers:
self._readers[key] = NSQReader(self.addr, loop=self.loop)
yield from self._readers[key].start()
return self._readers[key]
@asyncio.coroutine
def writer(self):
if not self._writer:
self._writer = NSQWriter(self.addr, loop=self.loop)
yield from self._writer.start()
return self._writer
def close(self):
if self._writer:
self._writer.close()
for reader in self._readers.values():
self._writer.close()
@asyncio.coroutine
def wait_closed(self):
if self._writer:
yield from self._writer.wait_closed()
for reader in self._readers.values():
yield from self._writer.wait_closed()
import asyncio
import json
import logging
import socket
from .protocol import NSQProtocol
from aionsq.util import parse_addr
......@@ -27,6 +28,7 @@ class NSQConnection:
self._started = False
self.pending_responses = asyncio.Queue()
self.message_listeners = set()
self.log = logging.getLogger(__name__)
@asyncio.coroutine
def start(self):
......@@ -44,9 +46,9 @@ class NSQConnection:
if frame.type == 1 and frame.error in ('E_FIN_FAILED',
'E_REQ_FAILED',
'E_TOUCH_FAILED'):
self.log.warn('failure %s', frame)
self.log.warning('failure %s', frame)
return
if frame.type in (0, 1):
elif frame.type in (0, 1):
self.loop.create_task(self.pending_responses.put(frame))
elif frame.type == 2:
for listener in self.message_listeners:
......@@ -64,7 +66,13 @@ class NSQConnection:
data = json.dumps(opts)
self.protocol.send_cmd('IDENTIFY', body=data)
response = yield from self.next_response()
return response
features = json.loads(response)
self.log.info('features %s' % features)
if features['deflate']:
raise NotImplementedError('Cannot use deflate')
if features['snappy']:
raise NotImplementedError('Cannot use snappy')
return features
@asyncio.coroutine
def pub(self, topic, data):
......
import struct
import importlib
from .exceptions import MessageError
from collections import namedtuple
if importlib.find_loader('snappy'):
import snappy
if importlib.find_loader('zlib'):
import zlib
RawResponse = namedtuple('RawResponse', 'type body')
RawError = namedtuple('RawError', 'type error body')
RawMessage = namedtuple('RawMessage', 'type body timestamp attempts id')
class DeflateLayer:
def __init__(self, level):
self.level = level
def write(self, data):
return zlib.compress(data, self.level)
def read(self, data):
return zlib.decompress(data)
class SnappyLayer:
def write(self, data):
return snappy.compress(data)
def read(self, data):
return snappy.decompress(data)
class SpecLayer:
def __init__(self):
pass
def parse_protocol(self, data):
frame_type, *_ = struct.unpack('>l', data[:4])
try:
if frame_type == 0:
return RawResponse(frame_type, data[4:].decode('utf-8'))
elif frame_type == 1:
error, body = data[4:].decode('utf-8').split(None, 1)
return RawError(frame_type, error, body)
elif frame_type == 2:
timestamp = struct.unpack('>q', data[4:12])[0]
attempts = struct.unpack('>h', data[12:14])[0]
message_id = data[14:30].decode('utf-8')
body = data[30:].decode('utf-8')
return RawMessage(frame_type, body, timestamp,
attempts, message_id)
except UnicodeDecodeError as error:
raise MessageError(data) from error
def prepare_cmd(self, cmd, *params, body=None):
output = bytearray()
output += scalar(cmd)
for param in params:
output += b' '
output += scalar(param)
output += b'\n'
if body is not None:
if isinstance(body, (list, set)):
buf = bytearray()
buf += scalar_len(body)
for o in body:
a = scalar(o)
buf += scalar_len(a)
buf += a
else:
buf = scalar(body)
output += scalar_len(buf)
output += buf
return output
def scalar(obj):
orig = obj
if isinstance(obj, str):
return bytearray(obj, encoding='utf-8')
if isinstance(obj, (int, float)):
return bytearray('%s' % obj, encoding='utf-8')
if isinstance(obj, bytes):
return bytearray(obj)
if isinstance(obj, bytearray):
return obj.copy()
msg = "Don't know how to handle this %r" % orig.__class__.__name__
raise NotImplementedError(msg)
def scalar_len(obj):
return struct.pack('>l', len(obj))
import asyncio
import logging
import struct
from .exceptions import MessageError
from collections import namedtuple
from functools import singledispatch
from .layers import SpecLayer
class NSQProtocol(asyncio.Protocol):
......@@ -12,6 +10,7 @@ class NSQProtocol(asyncio.Protocol):
self.listeners = set()
self.buf = bytearray()
self.log = logging.getLogger(__name__)
self.spec = SpecLayer()
def connection_made(self, transport):
self.transport = transport
......@@ -24,13 +23,16 @@ class NSQProtocol(asyncio.Protocol):
self.process_buffer()
def process_buffer(self):
while len(self.buf) > 4:
length, *_ = struct.unpack('>l', self.buf[:4])
part = length + 4
if len(self.buf) < part:
break
data, self.buf[0:part] = self.buf[4:part], []
frame = parse_protocol(data)
frame = self.spec.parse_protocol(data)
self.log.info('got %s', frame)
self.notify_listeners(frame)
......@@ -39,95 +41,9 @@ class NSQProtocol(asyncio.Protocol):
listener(frame)
def send_cmd(self, cmd, *params, body=None):
msg = prepare_cmd(cmd, *params, body=body)
msg = self.spec.prepare_cmd(cmd, *params, body=body)
self.log.info('send %s', msg)
self.transport.write(msg)
def add_listener(self, callback):
self.listeners.add(callback)
def extend(output, data):
if not isinstance(data, (bytes, bytearray)):
data = str(data).encode('utf-8')
output.extend(data)
def prepare_cmd(cmd, *params, body=None):
output = bytearray()
extend(output, cmd)
for param in params:
extend(output, ' %s' % param)
extend(output, '\n')
if body is not None:
extend(output, format_body(body))
return output
@singledispatch
def format_body(obj):
msg = "Don't know how to handle this %r" % obj.__class__.__name__
raise NotImplementedError(msg)
@format_body.register(int)
@format_body.register(float)
def format_body_number(obj):
obj = str(obj).encode('utf-8')
out = struct.pack('>l', len(obj))
out += obj
return out
@format_body.register(str)
def format_body_str(obj):
out = bytearray()
out.extend(struct.pack('>l', len(obj)))
out.extend(obj.encode('utf-8'))
return out
@format_body.register(bytes)
def format_body_bytes(obj):
out = bytearray()
out.extend(struct.pack('>l', len(obj)))
out.extend(obj)
return out
@format_body.register(list)
@format_body.register(set)
def format_body_lists(obj):
body = bytearray()
for o in obj:
body.extend(format_body(o))
out = bytearray()
out.extend(struct.pack('>l', len(body)))
out.extend(struct.pack('>l', len(obj)))
out.extend(body)
return out
RawResponse = namedtuple('RawResponse', 'type body')
RawError = namedtuple('RawError', 'type error body')
RawMessage = namedtuple('RawMessage', 'type body timestamp attempts id')
def parse_protocol(data):
frame_type, *_ = struct.unpack('>l', data[:4])
try:
if frame_type == 0:
return RawResponse(frame_type, data[4:].decode('utf-8'))
elif frame_type == 1:
error, body = data[4:].decode('utf-8').split(None, 1)
return RawError(frame_type, error, body)
elif frame_type == 2:
timestamp = struct.unpack('>q', data[4:12])[0]
attempts = struct.unpack('>h', data[12:14])[0]
message_id = data[14:30].decode('utf-8')
body = data[30:].decode('utf-8')
return RawMessage(frame_type, body, timestamp, attempts, message_id)
except UnicodeDecodeError as error:
raise MessageError(data) from error
......@@ -32,6 +32,8 @@ def parse_addr(addr, *, proto=None, host=None):
return addr
elif isinstance(addr, str):
if addr.startswith('http://'):
proto, addr = 'http', addr[7:]
if addr.startswith('udp://'):
proto, addr = 'udp', addr[6:]
elif addr.startswith('tcp://'):
......
from aionsq.tcp.layers import DeflateLayer
def test_deflate():
layer = DeflateLayer(6)
resp = layer.write(b'foobar')
assert resp == b'x\x9cK\xcb\xcfOJ,\x02\x00\x08\xab\x02z'
resp = layer.write(bytearray('foobar', encoding='utf-8'))
assert resp == b'x\x9cK\xcb\xcfOJ,\x02\x00\x08\xab\x02z'
resp = layer.read(b'x\x9cK\xcb\xcfOJ,\x02\x00\x08\xab\x02z')
assert resp == b'foobar'
import pytest
from aionsq import NSQ
@pytest.mark.asyncio
def test_tcp(event_loop):
nsq = NSQ('tcp://127.0.0.1:4150', loop=event_loop)
yield from nsq.publish('topic2', 'msg1')
yield from nsq.multi_publish('topic2', ['msg2', 'msg3'])
nsq.close()
yield from nsq.wait_closed()
@pytest.mark.asyncio
def test_http(event_loop):
nsq = NSQ('http://127.0.0.1:4151', loop=event_loop)
yield from nsq.publish('topic2', 'msg1')
yield from nsq.multi_publish('topic2', ['msg2', 'msg3'])
nsq.close()
yield from nsq.wait_closed()
@pytest.mark.asyncio
def test_lookup(event_loop):
nsq = NSQ(lookup='http://127.0.0.1:4161', loop=event_loop)
yield from nsq.publish('topic2', 'msg1')
yield from nsq.multi_publish('topic2', ['msg2', 'msg3'])
nsq.close()
yield from nsq.wait_closed()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment