Commit 8c17250e authored by Xavier Barbosa's avatar Xavier Barbosa

let auto reconnection

parent 7a879df9
Pipeline #529 passed with stages
......@@ -60,5 +60,13 @@ In addition to the changes above, it implements some async sugar:
job = await queue.get()
assert job.id == job_id
* client can reconnect automatically when a connection lost::
from aiodisque import Disque
client = Disque(auto_reconnect=True)
await client.hello()
# ... connection has been lost here...
await client.hello() # this not fails
.. _Disque: https://github.com/antirez/disque
.. _`official Disque command documentation`: https://github.com/antirez/disque#main-api
from .connections import Connection
from .connections import connect, ClosedConnectionError
from .iterators import JobsIterator
from .scanners import JobsScanner, QueuesScanner
from .util import grouper
......@@ -64,10 +64,15 @@ class Disque:
Parameters:
client (Address): a tcp or unix address
loop (EventLoop): asyncio loop
auto_reconnect (bool): automatically reconnect after connection lost
"""
def __init__(self, address, *, loop=None):
self.current_connection = Connection(address, loop=loop)
def __init__(self, address, *, auto_reconnect=None, loop=None):
self.address = address
self.loop = loop
self.auto_reconnect = auto_reconnect
self._connection = None
self._closed = False
async def addjob(self, queue, job, ms_timeout=0, *, replicate=None,
delay=None, retry=None, ttl=None,
......@@ -304,6 +309,7 @@ class Disque:
Returns:
dict
"""
response = await self.execute_command('HELLO')
result = {k: v for k, v in zip(['format', 'id', 'nodes'], response)}
nodes = []
......@@ -519,7 +525,6 @@ class Disque:
params.extend(('STATE', state))
if reply is not None:
params.extend(('REPLY', reply))
print(params)
cursor, items = await self.execute_command(*params)
if reply == 'all':
result = []
......@@ -556,4 +561,52 @@ class Disque:
Returns:
object: the server response
"""
return await self.current_connection.send_command(*args)
try:
# naively assume that connection is still ok
connection = await self.connect()
return await connection.send_command(*args)
except ClosedConnectionError:
# resend to a freshly opened connection
connection = await self.connect(force=True)
return await connection.send_command(*args)
async def connect(self, *, force=False):
"""Connect to the disque server
Parameters:
force (bool): exchange to a fresh connection
Returns:
Connection
"""
if force or not self._connection:
if self._connection:
self._connection.close()
if self._closed:
raise RuntimeError('Connection already closed')
listeners = set()
if self.auto_reconnect:
listeners.add(self.reset_connection)
connection = await connect(self.address,
loop=self.loop,
closed_listeners=listeners)
self._connection = connection
return self._connection
def close(self):
"""Close the current connection
"""
self._closed = True
if self._connection:
self._connection.close()
self._connection = None
def reset_connection(self):
"""Reset the current connection
"""
if self._connection:
self._connection.close()
self._connection = None
......@@ -2,7 +2,7 @@ import asyncio
import hiredis
from .util import parse_address, encode_command
__all__ = ['Connection', 'ConnectionError']
__all__ = ['connect', 'Connection', 'ConnectionError']
def parser():
......@@ -15,48 +15,90 @@ class ConnectionError(RuntimeError):
pass
class ClosedConnectionError(ConnectionError):
pass
class ProtocolError(ConnectionError):
pass
async def connect(address, *, loop=None, closed_listeners=None):
"""Open a connection to Disque server.
"""
address = parse_address(address, host='127.0.0.1', port=7711)
if address.proto == 'tcp':
host, port = address.address
future = asyncio.open_connection(host=host, port=port, loop=loop)
elif address.proto == 'unix':
path = address.address
future = asyncio.open_unix_connection(path=path, loop=loop)
reader, writer = await future
return Connection(reader, writer,
loop=loop,
closed_listeners=closed_listeners)
class Connection:
def __init__(self, address, *, loop=None):
self.address = parse_address(address, host='127.0.0.1', port=7711)
self.loop = loop
self.reader = None
self.writer = None
def __init__(self, reader, writer, *, loop=None, closed_listeners=None):
self._loop = loop
self._reader = reader
self._writer = writer
self.parser = parser()
self.connected = False
self._closed = False
self._closing = None
self._closed_listeners = closed_listeners or []
async def send_command(self, *args):
await self.connect()
"""Send command to server
"""
if self.closed:
raise ClosedConnectionError('closed connection')
message = encode_command(*args)
self.writer.write(message)
data = await self.reader.read(65536)
self._writer.write(message)
data = await self._reader.read(65536)
if self._reader.at_eof():
self._closing = True
self._loop.call_soon(self._do_close, None)
raise ClosedConnectionError('Half closed connection')
self.parser.feed(data)
response = self.parser.gets()
if isinstance(response, ProtocolError):
self._closing = True
self._loop.call_soon(self._do_close, response)
self.parser = parser()
raise response
if isinstance(response, Exception):
raise response
if self._reader.at_eof():
self._closing = True
self._loop.call_soon(self._do_close, None)
return response
async def connect(self):
if self.connected:
return
if self.address.proto == 'tcp':
host, port = self.address.address
future = asyncio.open_connection(host=host, port=port,
loop=self.loop)
elif self.address.proto == 'unix':
path = self.address.address
future = asyncio.open_unix_connection(path=path, loop=self.loop)
reader, writer = await future
self.reader = reader
self.writer = writer
self.connected = True
def close(self):
"""Close connection."""
self._do_close(None)
@property
def closed(self):
"""True if connection is closed."""
closed = self._closing or self._closed
if not closed and self._reader and self._reader.at_eof():
self._closing = closed = True
self._loop.call_soon(self._do_close, None)
return closed
def _do_close(self, exc):
if not self._closed:
self._closed = True
self._closing = False
self._writer.transport.close()
self._writer = None
self._reader = None
for listener in self._closed_listeners:
listener()
......@@ -28,4 +28,30 @@ Some changes must be noticed:
* commands are coroutines and thay names are lowered.
* ``async`` is a reserved word in Python, everyfields are renamed asynchronous
Other goodies
-------------
``padding`` with ``count`` ensure that iteration will returns the same number
of slots:
.. code-block:: python
from aiodisque import Disque
client = Disque()
await client.addjob('my-queue', 'job-1')
jobs = client.getjob_iter('my-queue', nohang=True, count=2, padding=True)
await for job1, job2 in jobs:
print('- job1:', job1.id, job1.body)
print('- job2 is null:', job2 is None)
``auto_reconnect`` tries to handle half-closed connection, lost and back
connection...
.. code-block:: python
from aiodisque import Disque
client = Disque(auto_reconnect=True)
.. _`original API`: https://github.com/antirez/disque#main-api
......@@ -14,18 +14,13 @@ ok = [
(12, Address(proto='tcp', address=('127.0.0.1', 12))),
('/tmp/disque.sock', Address(proto='unix', address='/tmp/disque.sock')),
('unix:///tmp/disque.sock', Address(proto='unix', address='/tmp/disque.sock')),
('unix:///foo/bar.sock', Address(proto='unix', address='/foo/bar.sock')),
(Address(proto='foo', address='bar'), Address(proto='foo', address='bar')),
]
fail = [
('a',),
('a', 'b', 'c'),
['a',],
['a', 'b', 'c'],
{},
]
fail = [('a',), ('a', 'b', 'c'), ['a'], ['a', 'b', 'c'], {}]
@pytest.mark.parametrize("input,expected", ok)
def test_parse_ok(input, expected):
......
import asyncio
import pytest
from aiodisque import Disque, ConnectionError
......@@ -12,16 +13,6 @@ async def test_hello(node, event_loop):
assert 'id' in response
@pytest.mark.asyncio
async def test_unix(node, event_loop):
client = Disque(node.socket, loop=event_loop)
response = await client.hello()
assert isinstance(response, dict)
assert 'format' in response
assert 'nodes' in response
assert 'id' in response
@pytest.mark.asyncio
async def test_info(node, event_loop):
client = Disque(node.port, loop=event_loop)
......@@ -303,3 +294,79 @@ async def test_qscan(node, event_loop):
if not cursor:
break
assert found_queues == queues
@pytest.mark.asyncio
async def test_close_connection(node, event_loop):
client = Disque(node.port, loop=event_loop)
await client.hello()
client.close()
with pytest.raises(RuntimeError):
await client.hello()
@pytest.mark.asyncio
async def test_close_reconnection(node, event_loop):
client = Disque(node.port, loop=event_loop, auto_reconnect=True)
await client.hello()
client.close()
with pytest.raises(RuntimeError):
await client.hello()
@pytest.mark.asyncio
async def test_close_autoreconnection(node, event_loop):
server = Wrapper(5678, node.port)
await server.start()
client = Disque(server.port, loop=event_loop, auto_reconnect=True)
await client.hello()
server.close()
# really really closed connection
with pytest.raises(ConnectionRefusedError):
await client.hello()
# connection get back
await server.start()
await client.hello()
# half-closed connection
server.close()
await server.start()
await client.hello()
class Wrapper:
def __init__(self, port, redirect_port, *, loop=None):
self.redirect_port = redirect_port
self.port = port
self.loop = loop
async def start(self):
self.server = await asyncio.start_server(self.handle,
'127.0.0.1',
self.port,
loop=self.loop)
sockets = await asyncio.open_connection('127.0.0.1',
self.redirect_port,
loop=self.loop)
self.reader, self.writer = sockets
def close(self):
if self.writer:
self.writer.close()
self.writer = None
if self.server:
self.server.close()
self.server = None
self.reader = None
async def handle(self, reader, writer):
data = await reader.read(10000)
self.writer.write(data)
data = await self.reader.read(10000)
writer.write(data)
await writer.drain()
writer.close()
import pytest
from aiodisque import connect, Connection, ConnectionError
from unittest.mock import Mock
@pytest.mark.asyncio
async def test_tcp(node, event_loop):
connection = await connect(node.port, loop=event_loop)
response = await connection.send_command('HELLO')
assert isinstance(response, list)
assert len(response) == 3
@pytest.mark.asyncio
async def test_unix(node, event_loop):
connection = await connect(node.socket, loop=event_loop)
assert isinstance(connection, Connection)
response = await connection.send_command('HELLO')
assert isinstance(response, list)
assert len(response) == 3
@pytest.mark.asyncio
async def test_tcp_recover(node, event_loop):
connection = await connect(node.port, loop=event_loop)
assert isinstance(connection, Connection)
response = await connection.send_command('HELLO')
assert isinstance(response, list)
assert len(response) == 3
@pytest.mark.asyncio
async def test_closed(node, event_loop):
connection = await connect(node.port, loop=event_loop)
assert not connection.closed
await connection.send_command('HELLO')
connection.close()
assert connection.closed
with pytest.raises(ConnectionError):
await connection.send_command('HELLO')
@pytest.mark.asyncio
async def test_close_callback(node, event_loop):
spy = Mock()
connection = await connect(node.port,
loop=event_loop,
closed_listeners=[spy])
assert not spy.called
assert not connection.closed
connection.close()
assert spy.called
assert connection.closed
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment