Commit 7a879df9 authored by xa's avatar xa

Merge branch 'unix-socket' into 'master'

Connect to unix socket



See merge request !6
parents eca1ec60 7674a9a4
Pipeline #528 passed with stages
......@@ -23,6 +23,12 @@ Usage::
client = Disque()
job_id = await client.sendjob('queue', 'body')
``client`` accepts a tcp or unix address::
client = Disque(address='127.0.0.1:7711')
client = Disque(address=('127.0.0.1', 7711))
client = Disque(address='/path/to/socket')
API Reference
-------------
......
......@@ -54,6 +54,18 @@ def render_jobs(response):
class Disque:
"""
``client`` accepts a tcp or unix address::
client = Disque(address='127.0.0.1:7711')
client = Disque(address=('127.0.0.1', 7711))
client = Disque(address='/path/to/socket')
Parameters:
client (Address): a tcp or unix address
loop (EventLoop): asyncio loop
"""
def __init__(self, address, *, loop=None):
self.current_connection = Connection(address, loop=loop)
......
import asyncio
import hiredis
import logging
from .util import parse_address, encode_command
__all__ = ['Connection', 'ConnectionError']
......@@ -23,7 +22,7 @@ class ProtocolError(ConnectionError):
class Connection:
def __init__(self, address, *, loop=None):
self.address = parse_address(address, port=7711)
self.address = parse_address(address, host='127.0.0.1', port=7711)
self.loop = loop
self.reader = None
self.writer = None
......@@ -33,12 +32,10 @@ class Connection:
async def send_command(self, *args):
await self.connect()
message = encode_command(*args)
logging.debug("REQ", message)
self.writer.write(message)
data = await self.reader.read(65536)
self.parser.feed(data)
logging.debug("RES", data)
response = self.parser.gets()
if isinstance(response, ProtocolError):
......@@ -49,9 +46,17 @@ class Connection:
return response
async def connect(self):
if not self.connected:
reader, writer = await asyncio.open_connection(*self.address,
loop=self.loop)
self.reader = reader
self.writer = writer
self.connected = True
if self.connected:
return
if self.address.proto == 'tcp':
host, port = self.address.address
future = asyncio.open_connection(host=host, port=port,
loop=self.loop)
elif self.address.proto == 'unix':
path = self.address.address
future = asyncio.open_unix_connection(path=path, loop=self.loop)
reader, writer = await future
self.reader = reader
self.writer = writer
self.connected = True
from .addresses_util import *
from itertools import zip_longest
__all__ = ['parse_address', 'encode_command']
......@@ -9,24 +10,6 @@ def grouper(n, iterable, fillvalue=None):
return zip_longest(fillvalue=fillvalue, *args)
def parse_address(address, *, host=None, port=None):
host, port = host or 'localhost', port
if isinstance(address, (list, tuple)):
host, port = address
if isinstance(address, int):
port = address
elif isinstance(address, str):
if ':' in address:
a, _, b = address.partition(':')
host = a or host
port = b or port
elif address.isdigit():
port = int(address)
else:
host = address or host
return host, int(port) if port else None
_converters = {
bytes: lambda val: val,
bytearray: lambda val: val,
......@@ -41,7 +24,8 @@ def _bytes_len(sized):
def encode_command(*args):
"""Encodes arguments into redis bulk-strings array.
"""Encodes arguments into redis bulk-strings array
Raises TypeError if any of args not of bytes, str, int or float type.
"""
buf = bytearray()
......
from functools import singledispatch
__all__ = ['Address', 'AddressError', 'parse_address']
class Address:
def __init__(self, proto, address):
self.proto = proto
self.address = address
def __eq__(self, other):
if isinstance(other, Address):
return self.proto == other.proto and self.address == other.address
def __repr__(self):
return '<Address(proto=%r, address=%r)>' % (self.proto, self.address)
class TCPAddress(Address):
proto = 'tcp'
def __init__(self, address):
self.address = address
class UnixAddress(Address):
proto = 'unix'
def __init__(self, address):
self.address = address
class AddressError(ValueError):
def __init__(self, address):
self.address = address
super().__init__('do not know how to handle %r' % [address])
@singledispatch
def parse_address(address, **kwargs):
raise AddressError(address)
@parse_address.register(Address)
def parse_addr_instance(address, **kwargs):
return address
@parse_address.register(str)
def parse_addr_str(address, *, proto=None, host=None, port=None, **kwargs):
if '://' in address:
proto, _, address = address.partition('://')
if ':' in address:
proto = proto or 'tcp'
a, _, b = address.partition(':')
host = a or host
port = b or port
address = host, int(port)
elif address.isdigit():
proto = proto or 'tcp'
port = int(address)
address = host, int(port)
elif address.startswith('/'):
proto = proto or 'unix'
else:
proto = proto or 'tcp'
host = address or host
address = host, port
if proto == 'unix':
return UnixAddress(address=address)
elif proto == 'tcp':
return TCPAddress(address=address)
@parse_address.register(int)
def parse_addr_int(address, *, host=None, **kwargs):
proto = 'tcp'
address = host, address
return Address(proto=proto, address=address)
@parse_address.register(list)
@parse_address.register(tuple)
def parse_addr_tuple(address, *, host=None, port=None, **kwargs):
proto = 'tcp'
try:
a, b = address
host = a or host
port = b or port
address = host, port
except Exception as error:
raise AddressError(address) from error
return Address(proto=proto, address=address)
import os.path
from pytest import fixture
from tempfile import TemporaryDirectory
from subprocess import Popen, PIPE, run
......@@ -16,12 +17,15 @@ class DisqueNode:
self.port = port
self.dir = dir
self.proc = None
self.socket = os.path.join(dir, 'disque.sock')
def start(self):
if not self.proc:
cmd = ["disque-server",
"--port", str(self.port),
"--dir", self.dir]
"--dir", self.dir,
"--unixsocket", self.socket,
"--unixsocketperm", "755"]
self.proc = Popen(cmd, stdout=PIPE, stderr=PIPE)
cmd = ['disque', '-p', str(self.port), 'info']
......@@ -39,7 +43,7 @@ class DisqueNode:
@property
def configuration(self):
return Configuration(port=self.port, dir=self.dir)
return Configuration(port=self.port, dir=self.dir, socket=self.socket)
@fixture(scope='function')
......
import pytest
from aiodisque.util import parse_address, Address, AddressError
ok = [
('1.2.3.4', Address(proto='tcp', address=('1.2.3.4', 7711))),
('tcp://1.2.3.4', Address(proto='tcp', address=('1.2.3.4', 7711))),
('1237.0.0.1:', Address(proto='tcp', address=('1237.0.0.1', 7711))),
(('1237.0.0.1', None), Address(proto='tcp', address=('1237.0.0.1', 7711))),
(['1237.0.0.1', None], Address(proto='tcp', address=('1237.0.0.1', 7711))),
(':', Address(proto='tcp', address=('127.0.0.1', 7711))),
(':12', Address(proto='tcp', address=('127.0.0.1', 12))),
('12', Address(proto='tcp', address=('127.0.0.1', 12))),
('errorist.xyz', Address(proto='tcp', address=('errorist.xyz', 7711))),
(12, Address(proto='tcp', address=('127.0.0.1', 12))),
('/tmp/disque.sock', Address(proto='unix', address='/tmp/disque.sock')),
('unix:///tmp/disque.sock', Address(proto='unix', address='/tmp/disque.sock')),
(Address(proto='foo', address='bar'), Address(proto='foo', address='bar')),
]
fail = [
('a',),
('a', 'b', 'c'),
['a',],
['a', 'b', 'c'],
{},
]
@pytest.mark.parametrize("input,expected", ok)
def test_parse_ok(input, expected):
assert parse_address(input, host='127.0.0.1', port=7711) == expected
@pytest.mark.parametrize("input", fail)
def test_parse_fail(input):
with pytest.raises(AddressError):
parse_address(input)
......@@ -12,6 +12,16 @@ async def test_hello(node, event_loop):
assert 'id' in response
@pytest.mark.asyncio
async def test_unix(node, event_loop):
client = Disque(node.socket, loop=event_loop)
response = await client.hello()
assert isinstance(response, dict)
assert 'format' in response
assert 'nodes' in response
assert 'id' in response
@pytest.mark.asyncio
async def test_info(node, event_loop):
client = Disque(node.port, loop=event_loop)
......
from aiodisque.util import parse_address
def test_address():
assert parse_address('1237.0.0.1:') == ('1237.0.0.1', None)
assert parse_address(('1237.0.0.1', None)) == ('1237.0.0.1', None)
assert parse_address(['1237.0.0.1', None]) == ('1237.0.0.1', None)
assert parse_address(':') == ('localhost', None)
assert parse_address(':12') == ('localhost', 12)
assert parse_address('12') == ('localhost', 12)
assert parse_address('errorist.xyz') == ('errorist.xyz', None)
assert parse_address(12) == ('localhost', 12)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment