diff --git a/faunadb/client.py b/faunadb/client.py index 65cfcb5c..4b8897b9 100644 --- a/faunadb/client.py +++ b/faunadb/client.py @@ -11,8 +11,9 @@ from faunadb.query import _wrap from faunadb.request_result import RequestResult from faunadb._json import parse_json_or_none, to_json +from faunadb.streams import Subscription -API_VERSION = "3" +API_VERSION = "4" class _LastTxnTime(object): """Wraps tracking the last transaction time supplied from the database.""" @@ -187,6 +188,30 @@ def query(self, expression, timeout_millis=None): """ return self._execute("POST", "", _wrap(expression), with_txn_time=True, query_timeout_ms=timeout_millis) + def stream(self, expression, options=None, on_start=None, on_error=None, on_version=None, on_history=None): + """ + Creates a stream Subscription to the result of the given read-only expression. When + executed. + + The subscription returned by this method does not issue any requests until + the subscription's start method is called. Make sure to + subscribe to the events of interest, otherwise the received events are simply + ignored. + + :param expression: A read-only expression. + :param options: Object that configures the stream subscription. E.g set fields to return + :param on_start: Callback for the stream's start event. + :param on_error: Callback for the stream's error event. + :param on_version: Callback for the stream's version events. + :param on_history: Callback for the stream's history_rewrite events. + """ + subscription = Subscription(self, expression, options) + subscription.on('start', on_start) + subscription.on('error', on_error) + subscription.on('version', on_version) + subscription.on('history_rewrite', on_history) + return subscription + def ping(self, scope=None, timeout=None): """ Ping FaunaDB. 
@@ -264,3 +289,7 @@ def _perform_request(self, action, path, data, query, headers): url = self.base_url + "/" + path req = Request(action, url, params=query, data=to_json(data), auth=self.auth, headers=headers) return self.session.send(self.session.prepare_request(req)) + + def _auth_header(self): + """Returns the HTTP authentication header""" + return "Bearer {}".format(self.auth.username) diff --git a/faunadb/streams/__init__.py b/faunadb/streams/__init__.py new file mode 100644 index 00000000..8a666df3 --- /dev/null +++ b/faunadb/streams/__init__.py @@ -0,0 +1,3 @@ +from .client import Connection +from .dispatcher import EventDispatcher +from .subscription import Subscription \ No newline at end of file diff --git a/faunadb/streams/client.py b/faunadb/streams/client.py new file mode 100644 index 00000000..9bd77a02 --- /dev/null +++ b/faunadb/streams/client.py @@ -0,0 +1,108 @@ +from time import time + +try: + #python2 + from urllib import urlencode +except ImportError: + #python3 + from urllib.parse import urlencode + +from hyper import HTTP20Connection +from faunadb._json import to_json, parse_json_or_none +from faunadb.request_result import RequestResult +from .events import parse_stream_request_result_or_none, Error +from .errors import StreamError + +VALID_FIELDS = {"diff", "prev", "document", "action"} + + +class Connection(object): + """ + The internal stream client connection interface. + This class handles the network side of a stream + subscription. 
+ + Current limitations: + Python requests module uses HTTP1; hyper is used for HTTP/2 + """ + def __init__(self, client, expression, options): + self._client = client + self.options = options + self.conn = None + self._fields = None + if isinstance(self.options, dict): + self._fields = self.options.get("fields", None) + elif hasattr(self.options, "fields"): + self._fields = self.options.fields + if isinstance(self._fields, list): + union = set(self._fields).union(VALID_FIELDS) + if union != VALID_FIELDS: + raise Exception("Valid fields options are %s, provided %s."%(VALID_FIELDS, self._fields)) + self._state = "idle" + self._query = expression + self._data = to_json(expression).encode() + try: + self.conn = HTTP20Connection( + self._client.domain, port=self._client.port, enable_push=True) + except Exception as e: + raise StreamError(e) + + def close(self): + """ + Closes the stream subscription by aborting its underlying http request. + """ + if self.conn is None: + raise StreamError('Cannot close inactive stream subscription.') + self.conn.close() + self._state = 'closed' + + def subscribe(self, on_event): + """Initiates the stream subscription.""" + if self._state != "idle": + raise StreamError('Stream subscription already started.') + try: + self._state = 'connecting' + headers = self._client.session.headers + headers["Authorization"] = self._client._auth_header() + if self._client._query_timeout_ms is not None: + headers["X-Query-Timeout"] = str(self._client._query_timeout_ms) + headers["X-Last-Seen-Txn"] = str(self._client.get_last_txn_time()) + start_time = time() + url_params = '' + if isinstance(self._fields, list): + url_params= "?%s"%(urlencode({'fields': ",".join(self._fields)})) + id = self.conn.request("POST", "/stream%s"%(url_params), body=self._data, headers=headers) + self._state = 'open' + self._event_loop(id, on_event, start_time) + except Exception as e: + if callable(on_event): + on_event(Error(e), None) + + def _event_loop(self, stream_id, 
on_event, start_time): + """ Event loop for the stream. """ + response = self.conn.get_response(stream_id) + if 'x-txn-time' in response.headers: + self._client.sync_last_txn_time(int(response.headers['x-txn-time'][0].decode())) + try: + for push in response.read_chunked(): # all pushes promised before response headers + raw = push.decode() + request_result = self._stream_chunk_to_request_result(response, raw, start_time, time()) + event = parse_stream_request_result_or_none(request_result) + if event is not None and hasattr(event, 'txn'): + self._client.sync_last_txn_time(int(event.txn)) + on_event(event, request_result) + if self._client.observer is not None: + self._client.observer(request_result) + except Exception as e: + self.error = e + self.close() + on_event(Error(e), None) + + def _stream_chunk_to_request_result(self, response, raw_chunk, start_time, end_time): + """ Converts a stream chunk to a RequestResult. """ + response_content = parse_json_or_none(raw_chunk) + return RequestResult( + "POST", "/stream", self._query, self._data, + raw_chunk, response_content, None, response.headers, + start_time, end_time) + diff --git a/faunadb/streams/dispatcher.py b/faunadb/streams/dispatcher.py new file mode 100644 index 00000000..db0a322a --- /dev/null +++ b/faunadb/streams/dispatcher.py @@ -0,0 +1,35 @@ +import logging + +class EventDispatcher(object): + """ + Event dispatch interface for stream subscription. + """ + def __init__(self): + self.callbacks = {} + + def on(self, event_type, callback): + """ + Subscribe to an event. + """ + if callable(callback): + self.callbacks[event_type] = callback + elif callback is not None: + raise Exception("Callback for event `%s` is not callable."%(event_type)) + + def _noop(self, event, request_result): + """ + Default callback for unregistered event types. 
+ """ + logging.debug("Unhandled stream event %s; %s"%(event, request_result)) + pass + + def dispatch(self, event, request_result): + """ + Dispatch the given event to the appropriate listeners. + """ + fn = self.callbacks.get(event.type, None) + if fn is None: + return self._noop(event, request_result) + return fn(event) + + diff --git a/faunadb/streams/errors.py b/faunadb/streams/errors.py new file mode 100644 index 00000000..e783ba3f --- /dev/null +++ b/faunadb/streams/errors.py @@ -0,0 +1,6 @@ +from faunadb.errors import FaunaError + +class StreamError(FaunaError): + """Stream Error""" + def __init__(self, error, request_result = None): + super(StreamError, self).__init__(error, request_result) diff --git a/faunadb/streams/events.py b/faunadb/streams/events.py new file mode 100644 index 00000000..fe1db272 --- /dev/null +++ b/faunadb/streams/events.py @@ -0,0 +1,132 @@ + +from faunadb._json import parse_json_or_none +from faunadb.errors import BadRequest, PermissionDenied + + +def parse_stream_request_result_or_none(request_result): + """ + Parses a stream RequestResult into a stream Event type. + """ + event = None + parsed = request_result.response_content + if parsed is None: + return UnknownEvent(request_result) + evt_type = parsed.get('type', None) + if evt_type == "start": + event = Start(parsed) + elif evt_type is None and 'errors' in parsed: + event = Error(BadRequest(request_result)) + elif evt_type == 'error': + event = Error(parsed) + elif evt_type == 'version': + event = Version(parsed) + elif evt_type == 'history_rewrite': + event = HistoryRewrite(parsed) + else: + event = UnknownEvent(request_result) + + return event + + +class Event(object): + """ + A stream event. + """ + def __init__(self, event_type): + self.type = event_type + +class ProtocolEvent(Event): + """ + Stream protocol event. + """ + def __init__(self, event_type): + super(ProtocolEvent, self).__init__(event_type) + + +class Start(ProtocolEvent): + """ + Stream's start event. 
A stream subscription always begins with a start event. + Upcoming events are guaranteed to have transaction timestamps equal to or greater than + the stream's start timestamp. + + :param data: Data + :param txn: Timestamp + """ + def __init__(self, parsed): + super(Start, self).__init__('start') + self.event = parsed['event'] + self.txn = parsed['txn'] + + def __repr__(self): + return "stream:event:Start(event=%s, txn=%d)"%(self.event, self.txn) + +class Error(ProtocolEvent): + """ + An error event is fired both for client and server errors that may occur as + a result of a subscription. + """ + def __init__(self, parsed): + super(Error, self).__init__('error') + self.error = None + self.code = None + self.description = None + if isinstance(parsed, dict): + if 'event' in parsed: + self.error = parsed['event'] + if isinstance(parsed['event'], dict): + self.code = parsed['event'].get('code', None) + self.description = parsed['event'].get('description', None) + elif 'errors' in parsed: + self.error = parsed['errors'] + else: + self.error = parsed + else: + self.error = parsed + + def __repr__(self): + return "stream:event:Error(%s)"%(self.error) + +class HistoryRewrite(Event): + """ + A history rewrite event occurs upon any modifications to the history of the + subscribed document. + + :param data: Data + :param txn: Timestamp + """ + def __init__(self, parsed): + super(HistoryRewrite, self).__init__('history_rewrite') + if isinstance(parsed, dict): + self.event = parsed.get('event', None) + self.txn = parsed.get('txn') + + def __repr__(self): + return "stream:event:HistoryRewrite(event=%s, txn=%s)" % (self.event, self.txn) + +class Version(Event): + """ + A version event occurs upon any modifications to the current state of the + subscribed document. 
+ + :param data: Data + :param txn: Timestamp + """ + def __init__(self, parsed): + super(Version, self).__init__('version') + if isinstance(parsed, dict): + self.event = parsed.get('event', None) + self.txn = parsed.get('txn') + + def __repr__(self): + return "stream:event:Version(event=%s, txn=%s)" % (self.event, self.txn) + + +class UnknownEvent(Event): + """ + Unknown stream event. + """ + def __init__(self, parsed): + super(UnknownEvent, self).__init__(None) + self.type = 'unknown' + self.event = parsed + diff --git a/faunadb/streams/subscription.py b/faunadb/streams/subscription.py new file mode 100644 index 00000000..7df4fcff --- /dev/null +++ b/faunadb/streams/subscription.py @@ -0,0 +1,35 @@ +from .client import Connection +from .dispatcher import EventDispatcher + + +class Subscription(object): + """ + A stream subscription which dispatches events received to the registered + listener functions. This class must be constructed via the FaunaClient stream + method. + """ + def __init__(self, client, expression, options=None): + self._client = Connection(client, expression, options) + self._dispatcher = EventDispatcher() + + def start(self): + """ + Initiates the underlying subscription network calls. + """ + self._client.subscribe(self._dispatcher.dispatch) + + def on(self, event_type, callback): + """ + Registers a callback for a specific event type. + """ + self._dispatcher.on(event_type, callback) + + def close(self): + """ + Stops the current subscription and closes the underlying network connection. 
+ """ + self._client.close() + + def __repr__(self): + return "stream:Subscription(state=%s, expression=%s, options=%s)"%(self._client._state, + self._client._query,self._client._options) diff --git a/setup.py b/setup.py index b1d218f4..30960fb0 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "iso8601", "requests", "future", + "hyper" ] tests_requires = [ diff --git a/tests/test_query.py b/tests/test_query.py index f26237c1..c7fa4be1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -66,7 +66,7 @@ def test_query(self): body212 = self._q(versioned212) body3 = self._q(query.query(lambda a, b: query.concat([a, b], "/"))) - versioned3 = Query({"api_version": "3", "lambda": ["a", "b"], "expr": { + versioned3 = Query({"api_version": "4", "lambda": ["a", "b"], "expr": { "concat": [{"var": "a"}, {"var": "b"}], "separator": "/"}}) self.assertEqual(body212, versioned212) diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 00000000..516adacb --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,193 @@ +from __future__ import division +from datetime import date, datetime +from time import sleep, time +from threading import Thread + +from faunadb.errors import BadRequest, NotFound, FaunaError +from faunadb.objects import FaunaTime, Ref, SetRef, _Expr, Native, Query +from faunadb import query +from faunadb.streams import Connection, Subscription, EventDispatcher +from tests.helpers import FaunaTestCase + +def _on_unhandled_error(event): + if hasattr(event, "data") and isinstance(event.event, Exception): + raise event.event + else: + raise Exception(event) + +class StreamTest(FaunaTestCase): + @classmethod + def setUpClass(cls): + super(StreamTest, cls).setUpClass() + cls.collection_ref = cls._q(query.create_collection({"name":"stream_test_coll"}))["ref"] + + #region Helpers + + @classmethod + def _create(cls, n=0, **data): + data["n"] = n + return cls._q(query.create(cls.collection_ref, {"data": data})) + + @classmethod + def 
_q(cls, query_json): + return cls.client.query(query_json) + + @classmethod + def stream_sync(cls, expression, options=None, + on_start=None, on_error=None, on_version=None, + on_history=None): + if on_error is None: + on_error = _on_unhandled_error + return cls.client.stream(expression, options, + on_start, on_error, on_version, on_history) + + #endregion + + def test_stream_on_document_reference(self): + ref = self._create(None)["ref"] + stream = None + + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(isinstance(event.event, int)) + stream.close() + + stream = self.stream_sync(ref, None, on_start=on_start) + stream.start() + + def test_stream_max_open_streams(self): + m = 101 + expected = [i for i in range(m)] + actual = [] + def threadFn(n): + ref = self._create(n)["ref"] + stream = None + + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(isinstance(event.event, int)) + self._q(query.update(ref, {"data": {"k": n}})) + + def on_version(event): + self.assertEqual(event.type, 'version') + actual.append(n) + self.assertTrue(isinstance(event.event, dict)) + while(len(actual) != m): + sleep(0.1) + stream.close() + + stream = self.stream_sync(ref, None, on_start=on_start, on_version=on_version) + stream.start() + threads = [] + for i in range(m): + th = Thread(target=threadFn, args=[i]) + th.start() + threads.append(th) + for th in threads: + th.join() + actual.sort() + self.assertEqual(actual, expected) + + def test_stream_reject_non_readonly_query(self): + q = query.create_collection({"name": "c"}) + stream = None + def on_error(error): + self.assertEqual(error.type, 'error') + self.assertTrue(isinstance(error.error, BadRequest)) + self.assertEqual(error.error._get_description(), + 'Write effect in read-only query expression.') + stream.close() + stream= self.stream_sync(q, on_error=on_error) + stream.start() + + def test_stream_select_fields(self): + ref = self._create()["ref"] + stream = None + fields = 
{"document", "diff"} + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(isinstance(event.event, int)) + self._q(query.update(ref, {"data":{"k": "v"}})) + + def on_version(event): + self.assertEqual(event.type, 'version') + self.assertTrue(isinstance(event.event, dict)) + self.assertTrue(isinstance(event.txn, int)) + keys = set(event.event.keys()) + self.assertEqual(keys, {"document", "diff"}) + stream.close() + options = {"fields": list(fields)} + stream = self.stream_sync(ref, options, on_start=on_start, on_version=on_version) + stream.start() + + + def test_stream_update_last_txn_time(self): + ref = self._create()["ref"] + last_txn_time = self.client.get_last_txn_time() + stream = None + + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(self.client.get_last_txn_time() > last_txn_time) + #for start event, last_txn_time maybe be updated to response X-Txn-Time header + # or event.txn. What is guaranteed is the most recent is used- hence >=. 
+ self.assertTrue(self.client.get_last_txn_time() >= event.txn) + self._q(query.update(ref, {"data": {"k": "v"}})) + + def on_version(event): + self.assertEqual(event.type, 'version') + self.assertEqual(event.txn, self.client.get_last_txn_time()) + stream.close() + + stream = self.stream_sync(ref, on_start=on_start, on_version=on_version) + stream.start() + + def test_stream_handle_request_failures(self): + stream=None + def on_error(event): + self.assertEqual(event.type, 'error') + self.assertTrue(isinstance(event.error, BadRequest)) + self.assertEqual(event.error._get_description(), + 'Expected a Document Ref or Version, got String.') + stream=self.stream_sync('invalid stream', on_error=on_error ) + stream.start() + + def test_start_active_stream(self): + ref = self._create(None)["ref"] + stream = None + + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(isinstance(event.event, int)) + self.assertRaises(FaunaError, lambda: stream.start()) + stream.close() + + stream = self.stream_sync(ref, None, on_start=on_start) + stream.start() + + + def test_stream_auth_revalidation(self): + ref = self._create()["ref"] + stream = None + + server_key = self.root_client.query( + query.create_key({"database": self.db_ref, "role": "server"})) + client = self.root_client.new_session_client( + secret=server_key["secret"]) + + def on_start(event): + self.assertEqual(event.type, 'start') + self.assertTrue(isinstance(event.event, int)) + self.root_client.query(query.delete(server_key["ref"])) + self.client.query(query.update(ref, {"data": {"k": "v"}})) + + def on_error(event): + self.assertEqual(event.type, 'error') + self.assertEqual(event.code, 'permission denied') + self.assertEqual(event.description, + 'Authorization lost during stream evaluation.') + stream.close() + + + stream = client.stream(ref, on_start=on_start, on_error=on_error) + stream.start()