Commit 3b0ee487 authored by xa's avatar xa

implement auth layer and an authd

parent 13204f45
[run]
omit = */__main__.py
[report]
exclude_lines =
pragma: no cover
def __repr__
if self.debug:
if settings.DEBUG
raise AssertionError
raise NotImplementedError
if 0:
if __name__ == .__main__.:
......@@ -44,4 +44,8 @@ Request nsq lookup:
lookup = NSQLookup('http://120.0.0.1:4567')
info = yield from lookup.info()
Start an authd::
python -m aionsq.authd
.. _nsq: http://nsq.io
from .databases import *
from .server import Application
__all__ = ['Application']
import asyncio
import logging
from . import Application, TestDB
from aionsq.util import parse_addr
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
logging.basicConfig(level=logging.DEBUG)
parser.add_argument('listen', default=':4181', nargs='?')
parser.add_argument('--db', default='test', choices={'test'})
args = parser.parse_args()
args.listen = parse_addr(args.listen, host='0.0.0.0', proto='http')
print(args)
if args.db == 'test':
db = TestDB()
else:
raise Exception('db not loaded')
app = Application(db=db)
loop = asyncio.get_event_loop()
handler = app.make_handler()
f = loop.create_server(handler, *args.listen)
srv = loop.run_until_complete(f)
print('serving on', srv.sockets[0].getsockname())
try:
loop.run_forever()
except KeyboardInterrupt:
pass
finally:
loop.run_until_complete(handler.finish_connections(1.0))
srv.close()
loop.run_until_complete(srv.wait_closed())
loop.run_until_complete(app.finish())
loop.close()
import asyncio
class DBHandler:
@asyncio.coroutine
def match(self, remote_ip, secret, tls):
raise NotImplementedError()
class TestDB(DBHandler):
@asyncio.coroutine
def match(self, remote_ip, secret, tls):
if secret in ('allow', '-'):
return {
'ttl': 3600,
'identity': 'username',
'authorizations': [
{
'permissions': ['subscribe', 'publish'],
'topic': '.*',
'channels': ['.*']
}
]
}
else:
raise Exception('lol')
import asyncio
import json
import logging
from aiohttp import web
class Application(web.Application):
def __init__(self, *args, db, **kwargs):
super().__init__(*args, **kwargs)
self.router.add_route('GET', '/ping', self.ping_controller)
self.router.add_route('GET', '/auth', self.auth_controller)
self.log = logging.getLogger(__name__)
self.db = db
@asyncio.coroutine
def ping_controller(self, request):
return web.Response(body=b'OK')
@asyncio.coroutine
def auth_controller(self, request):
self.log.info('incoming %s %s', request, request.GET)
try:
remote_ip = request.GET['remote_ip']
secret = request.GET['secret']
tls = request.GET['tls'] == 'true'
except KeyError:
raise web.HTTPBadRequest()
try:
response = yield from self.db.match(remote_ip, secret, tls)
except Exception as error:
raise web.HTTPForbidden() from error
body = json.dumps(response).encode()
headers = {
'Content-Type': 'application/json',
}
return web.Response(body=body, headers=headers)
from .client import *
from .connection import *
from .protocol import *
from .exceptions import *
__all__ = ['NSQReader', 'NSQWriter']
......@@ -7,13 +7,18 @@ from weakref import WeakKeyDictionary, WeakSet
class TCPHandler:
def __init__(self, addr, *, loop=None):
self.addr = addr
def __init__(self, addr, *, auth_secret=None, loop=None):
"""
Parameters:
addr (str): NSQ server address.
auth_secret (str): if auth is required by NSQ server.
loop (EventLoop): asyncio event loop.
"""
self.loop = loop or asyncio.get_event_loop()
self._started = False
self._connection = NSQConnection(self.addr)
self._connection = NSQConnection(addr, auth_secret=auth_secret)
self._openened = asyncio.Lock(loop=self.loop)
self.log = logging.getLogger(__name__)
@property
def started(self):
......@@ -22,10 +27,10 @@ class TCPHandler:
@asyncio.coroutine
def start(self, **opts):
if not self._started:
yield from self._openened.acquire()
yield from self._connection.start()
yield from self._connection.identify(**opts)
yield from self._connection.connect(**opts)
yield from self.start_warmup()
yield from self._openened.acquire() # must be the last stmt
self._started = True
return True
......@@ -34,17 +39,19 @@ class TCPHandler:
pass
def close(self):
def close_connection(fut):
self._connection.close()
self._openened.release()
self.log.debug('close')
if self._started:
task = self.loop.create_task(self.pre_close())
task.add_done_callback(self._report_errors)
task.add_done_callback(close_connection)
task.add_done_callback(lambda x: self.force_close())
self._started = False
def force_close(self):
self.log.debug('force close')
self._connection.close()
self._openened.release()
@asyncio.coroutine
def wait_closed(self):
with (yield from self._openened):
......@@ -101,7 +108,6 @@ class NSQReader(TCPHandler):
self._msg_counter = 0
self._msg_rdy = 0
self._msgs = WeakSet()
self.log = logging.getLogger(__name__)
@property
def messages(self):
......
......@@ -22,13 +22,14 @@ def connect(addr, timeout=None, *, loop=None):
class NSQConnection:
def __init__(self, addr, *, loop=None):
def __init__(self, addr, *, auth_secret=None, loop=None):
self.addr = addr
self.loop = loop or asyncio.get_event_loop()
self._started = False
self.pending_responses = asyncio.Queue()
self.message_listeners = set()
self.log = logging.getLogger(__name__)
self.auth_secret = auth_secret
@asyncio.coroutine
def start(self):
......@@ -38,6 +39,16 @@ class NSQConnection:
self.protocol.add_listener(self.import_frame)
self._started = True
def close(self):
if self._started:
self.force_close()
def force_close(self):
self.transport.close()
self.transport = None
self.protocol = None
self._started = False
def import_frame(self, frame):
if frame.type == 0 and frame.body == '_heartbeat_':
self.protocol.send_cmd('NOP')
......@@ -55,9 +66,18 @@ class NSQConnection:
listener(frame)
@asyncio.coroutine
def identify(self, **opts):
def connect(self, **opts):
"""Identify to the NSQ.
"""
try:
response = yield from self._connect(**opts)
return response
except Exception as error:
self.force_close()
raise error
@asyncio.coroutine
def _connect(self, **opts):
from aionsq import __version__
opts.setdefault('user_agent', 'aionsq/%s' % __version__)
opts.setdefault('client_id', socket.getfqdn())
......@@ -79,6 +99,13 @@ class NSQConnection:
self.protocol.upgrade_snappy()
response = mini.read()
self.log.info('snappy status %s' % response)
if features['auth_required']:
if not self.auth_secret:
self.log.warning('auth required but auth_secret not set')
auth_secret = self.auth_secret or '-'
self.protocol.send_cmd('AUTH', body=auth_secret)
response = mini.read()
self.log.info('auth status %s' % response)
return features
@asyncio.coroutine
......@@ -178,13 +205,6 @@ class NSQConnection:
def add_message_listener(self, callback):
self.message_listeners.add(callback)
def close(self):
if self._started:
self.transport.close()
self.transport = None
self.protocol = None
self._started = False
class NSQConnectionFactory:
......
......@@ -8,3 +8,13 @@ class NotStartedError(ClientError):
class MessageError(ValueError):
pass
class ServerError(Exception):
pass
class AuthFailedError(ServerError):
pass
import asyncio
import logging
import struct
from . import exceptions
from .layers import SpecLayer, DeflateLayer, SnappyLayer
......@@ -108,4 +109,12 @@ class BlockingProtocol:
self.log.info('got %s', frame)
if frame.type == 0:
return frame.body
if frame.type == 1:
raise exc(frame)
raise Exception(frame)
def exc(frame):
if frame.error == 'E_AUTH_FAILED':
return exceptions.AuthFailedError(frame.body)
return exceptions.ServerError(frame.error, frame.body)
import asyncio
import pytest
from aionsq.tcp.layers import DeflateLayer, SnappyLayer
from aionsq.tcp import TCPHandler, AuthFailedError
@pytest.mark.asyncio
def test_auth_allow(event_loop):
client = TCPHandler(':4150', auth_secret='allow', loop=event_loop)
yield from client.start()
client.close()
yield from client.wait_closed()
@pytest.mark.asyncio
def test_auth_deny(event_loop):
client = TCPHandler(':4150', auth_secret='deny', loop=event_loop)
with pytest.raises(AuthFailedError):
yield from client.start()
client.close()
yield from client.wait_closed()
import aiohttp
import asyncio
import pytest
import socket
from aionsq.authd import Application, TestDB
def find_unused_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port
@pytest.mark.asyncio
def create_server(event_loop):
db = TestDB()
app = Application(db=db)
handler = app.make_handler()
port = find_unused_port()
srv = yield from event_loop.create_server(handler, '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port)
return app, srv, url, handler
@pytest.mark.asyncio
def test_ping(event_loop):
app, srv, url, handler = yield from create_server(event_loop)
response = yield from aiohttp.get(url + '/ping')
result = yield from response.text()
assert response.status == 200
assert result == 'OK'
yield from handler.finish_connections()
@pytest.mark.asyncio
def test_auth_bad(event_loop):
app, srv, url, handler = yield from create_server(event_loop)
response = yield from aiohttp.get(url + '/auth?secret=allowed')
assert response.status == 400
yield from handler.finish_connections()
@pytest.mark.asyncio
def test_auth_ok(event_loop):
app, srv, url, handler = yield from create_server(event_loop)
response = yield from aiohttp.get(url + '/auth?secret=allow&remote_ip=.&tls=false')
yield from handler.finish_connections()
assert response.status == 200
result = yield from response.json()
assert result == {
'authorizations': [{'channels': ['.*'],
'permissions': ['subscribe', 'publish'],
'topic': '.*'}],
'identity': 'username',
'ttl': 3600
}
@pytest.mark.asyncio
def test_auth_allow(event_loop):
app, srv, url, handler = yield from create_server(event_loop)
response = yield from aiohttp.get(url + '/auth?secret=allow&remote_ip=.&tls=false')
yield from handler.finish_connections()
assert response.status == 200
result = yield from response.json()
assert result == {
'authorizations': [{'channels': ['.*'],
'permissions': ['subscribe', 'publish'],
'topic': '.*'}],
'identity': 'username',
'ttl': 3600
}
@pytest.mark.asyncio
def test_auth_forbidden(event_loop):
app, srv, url, handler = yield from create_server(event_loop)
response = yield from aiohttp.get(url + '/auth?secret=deny&remote_ip=.&tls=false')
yield from handler.finish_connections()
assert response.status == 403
result = yield from response.text()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment