From 8d0d424443e0473b2045198889ce0b0e655c40a9 Mon Sep 17 00:00:00 2001 From: turly221 Date: Fri, 6 Dec 2024 08:37:28 +0000 Subject: [PATCH 1/2] commit patch 27968932 --- changelog.d/8776.bugfix | 1 + synapse/federation/federation_server.py | 23 +- synapse/federation/federation_server.py.orig | 906 ++++++ synapse/federation/transport/server.py | 68 +- synapse/federation/transport/server.py.orig | 1530 +++++++++ synapse/handlers/federation.py | 10 +- synapse/handlers/federation.py.orig | 3017 ++++++++++++++++++ tests/handlers/test_federation.py | 1 - 8 files changed, 5502 insertions(+), 54 deletions(-) create mode 100644 changelog.d/8776.bugfix create mode 100644 synapse/federation/federation_server.py.orig create mode 100644 synapse/federation/transport/server.py.orig create mode 100644 synapse/handlers/federation.py.orig diff --git a/changelog.d/8776.bugfix b/changelog.d/8776.bugfix new file mode 100644 index 00000000000..dd7ebbeb865 --- /dev/null +++ b/changelog.d/8776.bugfix @@ -0,0 +1 @@ +Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 11c5d63298e..8eb970e56f2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -51,6 +51,7 @@ from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -336,7 +337,7 @@ async def _process_edu(edu_dict): TRANSACTION_CONCURRENCY_LIMIT, ) - async def on_context_state_request( + async def on_room_state_request( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) @@ -459,11 +460,12 @@ async def on_invite_request( return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, origin: str, content: JsonDict ) -> Dict[str, Any]: logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -492,12 +494,11 @@ async def on_make_leave_request( time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request( - self, origin: str, content: JsonDict, room_id: str - ) -> dict: + async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -693,12 +694,8 @@ async def exchange_third_party_invite( ) return ret - async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: Dict - ): - ret = await self.handler.on_exchange_third_party_invite_request( - room_id, event_dict - ) + async def on_exchange_third_party_invite_request(self, event_dict: Dict): + ret = await self.handler.on_exchange_third_party_invite_request(event_dict) return ret async def check_server_matches_acl(self, server_name: str, room_id: str): diff --git a/synapse/federation/federation_server.py.orig b/synapse/federation/federation_server.py.orig new file mode 100644 index 00000000000..11c5d63298e --- /dev/null +++ b/synapse/federation/federation_server.py.orig @@ -0,0 +1,906 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Match, + Optional, + Tuple, + Union, +) + +from canonicaljson import json +from prometheus_client import Counter, Histogram + +from twisted.internet import defer +from twisted.internet.abstract import isIPAddress +from twisted.python import failure + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import ( + AuthError, + Codes, + FederationError, + IncompatibleRoomVersionError, + NotFoundError, + SynapseError, + UnsupportedRoomVersionError, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.federation.persistence import TransactionActions +from synapse.federation.units import Edu, Transaction +from synapse.http.endpoint import parse_server_name +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) +from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace +from synapse.logging.utils import log_function +from synapse.replication.http.federation import ( + ReplicationFederationSendEduRestServlet, + ReplicationGetQueryRestServlet, +) +from synapse.types import JsonDict, get_domain_from_id +from synapse.util import glob_to_regex, unwrapFirstError +from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.caches.response_cache import ResponseCache + +if TYPE_CHECKING: + from synapse.server import HomeServer + +# when processing incoming transactions, we try to handle multiple rooms in +# parallel, up to this limit. +TRANSACTION_CONCURRENCY_LIMIT = 10 + +logger = logging.getLogger(__name__) + +received_pdus_counter = Counter("synapse_federation_server_received_pdus", "") + +received_edus_counter = Counter("synapse_federation_server_received_edus", "") + +received_queries_counter = Counter( + "synapse_federation_server_received_queries", "", ["type"] +) + +pdu_process_time = Histogram( + "synapse_federation_server_pdu_process_time", "Time taken to process an event", +) + + +class FederationServer(FederationBase): + def __init__(self, hs): + super(FederationServer, self).__init__(hs) + + self.auth = hs.get_auth() + self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() + + self.device_handler = hs.get_device_handler() + + self._server_linearizer = Linearizer("fed_server") + self._transaction_linearizer = Linearizer("fed_txn_handler") + + self.transaction_actions = TransactionActions(self.store) + + self.registry = hs.get_federation_registry() + + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_ids_resp_cache = ResponseCache( + hs, "state_ids_resp", timeout_ms=30000 + ) + + async def on_backfill_request( + self, origin: str, room_id: str, versions: List[str], limit: int + ) -> Tuple[int, Dict[str, Any]]: + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + pdus = await self.handler.on_backfill_request( + origin, room_id, versions, limit + ) + + res = self._transaction_from_pdus(pdus).get_dict() + + return 200, res + + async def on_incoming_transaction( + self, origin: str, transaction_data: JsonDict + ) -> Tuple[int, Dict[str, Any]]: + # keep this as early as possible to make the calculated origin ts as + # accurate as possible. + request_time = self._clock.time_msec() + + transaction = Transaction(**transaction_data) + + if not transaction.transaction_id: # type: ignore + raise Exception("Transaction missing transaction_id") + + logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore + + # use a linearizer to ensure that we don't process the same transaction + # multiple times in parallel. + with ( + await self._transaction_linearizer.queue( + (origin, transaction.transaction_id) # type: ignore + ) + ): + result = await self._handle_incoming_transaction( + origin, transaction, request_time + ) + + return result + + async def _handle_incoming_transaction( + self, origin: str, transaction: Transaction, request_time: int + ) -> Tuple[int, Dict[str, Any]]: + """ Process an incoming transaction and return the HTTP response + + Args: + origin: the server making the request + transaction: incoming transaction + request_time: timestamp that the HTTP request arrived at + + Returns: + HTTP response code and body + """ + response = await self.transaction_actions.have_responded(origin, transaction) + + if response: + logger.debug( + "[%s] We've already responded to this request", + transaction.transaction_id, # type: ignore + ) + return response + + logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + + # Reject if PDU count > 50 or EDU count > 100 + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore + ): + + logger.info("Transaction PDU or EDU count too large. Returning 400") + + response = {} + await self.transaction_actions.set_response( + origin, transaction, 400, response + ) + return 400, response + + # We process PDUs and EDUs in parallel. This is important as we don't + # want to block things like to device messages from reaching clients + # behind the potentially expensive handling of PDUs. + pdu_results, _ = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._handle_pdus_in_txn, origin, transaction, request_time + ), + run_in_background(self._handle_edus_in_txn, origin, transaction), + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + response = {"pdus": pdu_results} + + logger.debug("Returning: %s", str(response)) + + await self.transaction_actions.set_response(origin, transaction, 200, response) + return 200, response + + async def _handle_pdus_in_txn( + self, origin: str, transaction: Transaction, request_time: int + ) -> Dict[str, dict]: + """Process the PDUs in a received transaction. + + Args: + origin: the server making the request + transaction: incoming transaction + request_time: timestamp that the HTTP request arrived at + + Returns: + A map from event ID of a processed PDU to any errors we should + report back to the sending server. + """ + + received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + + origin_host, _ = parse_server_name(origin) + + pdus_by_room = {} # type: Dict[str, List[EventBase]] + + for p in transaction.pdus: # type: ignore + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = request_time - int(p["age"]) + del p["age"] + + # We try and pull out an event ID so that if later checks fail we + # can log something sensible. We don't mandate an event ID here in + # case future event formats get rid of the key. + possible_event_id = p.get("event_id", "") + + # Now we get the room ID so that we can check that we know the + # version of the room. + room_id = p.get("room_id") + if not room_id: + logger.info( + "Ignoring PDU as does not have a room_id. Event ID: %s", + possible_event_id, + ) + continue + + try: + room_version = await self.store.get_room_version(room_id) + except NotFoundError: + logger.info("Ignoring PDU for unknown room_id: %s", room_id) + continue + except UnsupportedRoomVersionError as e: + # this can happen if support for a given room version is withdrawn, + # so that we still get events for said room. + logger.info("Ignoring PDU: %s", e) + continue + + event = event_from_pdu_json(p, room_version) + pdus_by_room.setdefault(room_id, []).append(event) + + pdu_results = {} + + # we can process different rooms in parallel (which is useful if they + # require callouts to other servers to fetch missing events), but + # impose a limit to avoid going too crazy with ram/cpu. + + async def process_pdus_for_room(room_id: str): + logger.debug("Processing PDUs for %s", room_id) + try: + await self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warning("Ignoring PDUs for room %s from banned server", room_id) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + with pdu_process_time.time(): + with nested_logging_context(event_id): + try: + await self._handle_received_pdu(origin, pdu) + pdu_results[event_id] = {} + except FederationError as e: + logger.warning("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s", + event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), + ) + + await concurrently_execute( + process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT + ) + + return pdu_results + + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + """Process the EDUs in a received transaction. + """ + + async def _process_edu(edu_dict): + received_edus_counter.inc() + + edu = Edu( + origin=origin, + destination=self.server_name, + edu_type=edu_dict["edu_type"], + content=edu_dict["content"], + ) + await self.registry.on_edu(edu.edu_type, origin, edu.content) + + await concurrently_execute( + _process_edu, + getattr(transaction, "edus", []), + TRANSACTION_CONCURRENCY_LIMIT, + ) + + async def on_context_state_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + in_room = await self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + # we grab the linearizer to protect ourselves from servers which hammer + # us. In theory we might already have the response to this query + # in the cache so we could return it without waiting for the linearizer + # - but that's non-trivial to get right, and anyway somewhat defeats + # the point of the linearizer. + with (await self._server_linearizer.queue((origin, room_id))): + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) + ) + + room_version = await self.store.get_room_version_id(room_id) + resp["room_version"] = room_version + + return 200, resp + + async def on_state_ids_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: + if not event_id: + raise NotImplementedError("Specify an event") + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + in_room = await self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + resp = await self._state_ids_resp_cache.wrap( + (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, + ) + + return 200, resp + + async def _on_state_ids_request_compute(self, room_id, event_id): + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} + + async def _on_context_state_request_compute( + self, room_id: str, event_id: str + ) -> Dict[str, list]: + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + + return { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } + + async def on_pdu_request( + self, origin: str, event_id: str + ) -> Tuple[int, Union[JsonDict, str]]: + pdu = await self.handler.get_persisted_pdu(origin, event_id) + + if pdu: + return 200, self._transaction_from_pdus([pdu]).get_dict() + else: + return 404, "" + + async def on_query_request( + self, query_type: str, args: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: + received_queries_counter.labels(query_type).inc() + resp = await self.registry.on_query(query_type, args) + return 200, resp + + async def on_make_join_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Any]: + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + room_version = await self.store.get_room_version_id(room_id) + if room_version not in supported_versions: + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) + raise IncompatibleRoomVersionError(room_version=room_version) + + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) + time_now = self._clock.time_msec() + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + + async def on_invite_request( + self, origin: str, content: JsonDict, room_version_id: str + ) -> Dict[str, Any]: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + raise SynapseError( + 400, + "Homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + + pdu = event_from_pdu_json(content, room_version) + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version) + time_now = self._clock.time_msec() + return {"event": ret_pdu.get_pdu_json(time_now)} + + async def on_send_join_request( + self, origin: str, content: JsonDict, room_id: str + ) -> Dict[str, Any]: + logger.debug("on_send_join_request: content: %s", content) + + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + res_pdus = await self.handler.on_send_join_request(origin, pdu) + time_now = self._clock.time_msec() + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + } + + async def on_make_leave_request( + self, origin: str, room_id: str, user_id: str + ) -> Dict[str, Any]: + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) + + room_version = await self.store.get_room_version_id(room_id) + + time_now = self._clock.time_msec() + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + + async def on_send_leave_request( + self, origin: str, content: JsonDict, room_id: str + ) -> dict: + logger.debug("on_send_leave_request: content: %s", content) + + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + await self.handler.on_send_leave_request(origin, pdu) + return {} + + async def on_event_auth( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + time_now = self._clock.time_msec() + auth_pdus = await self.handler.on_event_auth(event_id) + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} + return 200, res + + @log_function + async def on_query_client_keys( + self, origin: str, content: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: + return await self.on_query_request("client_keys", content) + + async def on_query_user_devices( + self, origin: str, user_id: str + ) -> Tuple[int, Dict[str, Any]]: + keys = await self.device_handler.on_federation_query_user_devices(user_id) + return 200, keys + + @trace + async def on_claim_client_keys( + self, origin: str, content: JsonDict + ) -> Dict[str, Any]: + query = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + query.append((user_id, device_id, algorithm)) + + log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) + results = await self.store.claim_e2e_one_time_keys(query) + + json_result = {} # type: Dict[str, Dict[str, dict]] + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_str in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_str) + } + + logger.info( + "Claimed one-time-keys: %s", + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() + ) + ), + ) + + return {"one_time_keys": json_result} + + async def on_get_missing_events( + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> Dict[str, list]: + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + logger.debug( + "on_get_missing_events: earliest_events: %r, latest_events: %r," + " limit: %d", + earliest_events, + latest_events, + limit, + ) + + missing_events = await self.handler.on_get_missing_events( + origin, room_id, earliest_events, latest_events, limit + ) + + if len(missing_events) < 5: + logger.debug( + "Returning %d events: %r", len(missing_events), missing_events + ) + else: + logger.debug("Returning %d events", len(missing_events)) + + time_now = self._clock.time_msec() + + return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + + @log_function + async def on_openid_userinfo(self, token: str) -> Optional[str]: + ts_now_ms = self._clock.time_msec() + return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) + + def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + """Returns a new Transaction containing the given PDUs suitable for + transmission. + """ + time_now = self._clock.time_msec() + pdus = [p.get_pdu_json(time_now) for p in pdu_list] + return Transaction( + origin=self.server_name, + pdus=pdus, + origin_server_ts=int(time_now), + destination=None, + ) + + async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: + """ Process a PDU received in a federation /send/ transaction. + + If the event is invalid, then this method throws a FederationError. + (The error will then be logged and sent back to the sender (which + probably won't do anything with it), and other events in the + transaction will be processed as normal). + + It is likely that we'll then receive other events which refer to + this rejected_event in their prev_events, etc. When that happens, + we'll attempt to fetch the rejected event again, which will presumably + fail, so those second-generation events will also get rejected. + + Eventually, we get to the point where there are more than 10 events + between any new events and the original rejected event. Since we + only try to backfill 10 events deep on received pdu, we then accept the + new event, possibly introducing a discontinuity in the DAG, with new + forward extremities, so normal service is approximately returned, + until we try to backfill across the discontinuity. + + Args: + origin: server which sent the pdu + pdu: received pdu + + Raises: FederationError if the signatures / hash do not match, or + if the event was unacceptable for any other reason (eg, too large, + too many prev_events, couldn't find the prev_events) + """ + # check that it's actually being sent from a valid destination to + # workaround bug #1753 in 0.18.5 and 0.18.6 + if origin != get_domain_from_id(pdu.sender): + # We continue to accept join events from any server; this is + # necessary for the federation join dance to work correctly. + # (When we join over federation, the "helper" server is + # responsible for sending out the join event, rather than the + # origin. See bug #1893. This is also true for some third party + # invites). + if not ( + pdu.type == "m.room.member" + and pdu.content + and pdu.content.get("membership", None) + in (Membership.JOIN, Membership.INVITE) + ): + logger.info( + "Discarding PDU %s from invalid origin %s", pdu.event_id, origin + ) + return + else: + logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) + + # We've already checked that we know the room version by this point + room_version = await self.store.get_room_version(pdu.room_id) + + # Check signature. + try: + pdu = await self._check_sigs_and_hash(room_version, pdu) + except SynapseError as e: + raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) + + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + + def __str__(self): + return "" % self.server_name + + async def exchange_third_party_invite( + self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict + ): + ret = await self.handler.exchange_third_party_invite( + sender_user_id, target_user_id, room_id, signed + ) + return ret + + async def on_exchange_third_party_invite_request( + self, room_id: str, event_dict: Dict + ): + ret = await self.handler.on_exchange_third_party_invite_request( + room_id, event_dict + ) + return ret + + async def check_server_matches_acl(self, server_name: str, room_id: str): + """Check if the given server is allowed by the server ACLs in the room + + Args: + server_name: name of server, *without any port part* + room_id: ID of the room to check + + Raises: + AuthError if the server does not match the ACL + """ + state_ids = await self.store.get_current_state_ids(room_id) + acl_event_id = state_ids.get((EventTypes.ServerACL, "")) + + if not acl_event_id: + return + + acl_event = await self.store.get_event(acl_event_id) + if server_matches_acl_event(server_name, acl_event): + return + + raise AuthError(code=403, msg="Server is banned from room") + + +def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: + """Check if the given server is allowed by the ACL event + + Args: + server_name: name of server, without any port part + acl_event: m.room.server_acl event + + Returns: + True if this server is allowed by the ACLs + """ + logger.debug("Checking %s against acl %s", server_name, acl_event.content) + + # first of all, check if literal IPs are blocked, and if so, whether the + # server name is a literal IP + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warning("Ignoring non-bool allow_ip_literals flag") + allow_ip_literals = True + if not allow_ip_literals: + # check for ipv6 literals. These start with '['. + if server_name[0] == "[": + return False + + # check for ipv4 literals. We can just lift the routine from twisted. + if isIPAddress(server_name): + return False + + # next, check the deny list + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warning("Ignoring non-list deny ACL %s", deny) + deny = [] + for e in deny: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched deny rule %s", server_name, e) + return False + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warning("Ignoring non-list allow ACL %s", allow) + allow = [] + for e in allow: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched allow rule %s", server_name, e) + return True + + # everything else should be rejected. + # logger.info("%s fell through", server_name) + return False + + +def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: + if not isinstance(acl_entry, str): + logger.warning( + "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + ) + return False + regex = glob_to_regex(acl_entry) + return regex.match(server_name) + + +class FederationHandlerRegistry(object): + """Allows classes to register themselves as handlers for a given EDU or + query type for incoming federation traffic. + """ + + def __init__(self, hs: "HomeServer"): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + self._instance_name = hs.get_instance_name() + + # These are safe to load in monolith mode, but will explode if we try + # and use them. However we have guards before we use them to ensure that + # we don't route to ourselves, and in monolith mode that will always be + # the case. + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + self.edu_handlers = ( + {} + ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] + self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + + # Map from type to instance name that we should route EDU handling to. + self._edu_type_to_instance = {} # type: Dict[str, str] + + def register_edu_handler( + self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + ): + """Sets the handler callable that will be used to handle an incoming + federation EDU of the given type. + + Args: + edu_type: The type of the incoming EDU to register handler for + handler: A callable invoked on incoming EDU + of the given type. The arguments are the origin server name and + the EDU contents. + """ + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + logger.info("Registering federation EDU handler for %r", edu_type) + + self.edu_handlers[edu_type] = handler + + def register_query_handler( + self, query_type: str, handler: Callable[[dict], defer.Deferred] + ): + """Sets the handler callable that will be used to handle an incoming + federation query of the given type. + + Args: + query_type: Category name of the query, which should match + the string used by make_query. + handler: Invoked to handle + incoming queries of this type. The return will be yielded + on and the result used as the response to the query request. + """ + if query_type in self.query_handlers: + raise KeyError("Already have a Query handler for %s" % (query_type,)) + + logger.info("Registering federation query handler for %r", query_type) + + self.query_handlers[query_type] = handler + + def register_instance_for_edu(self, edu_type: str, instance_name: str): + """Register that the EDU handler is on a different instance than master. + """ + self._edu_type_to_instance[edu_type] = instance_name + + async def on_edu(self, edu_type: str, origin: str, content: dict): + if not self.config.use_presence and edu_type == "m.presence": + return + + # Check if we have a handler on this instance + handler = self.edu_handlers.get(edu_type) + if handler: + with start_active_span_from_edu(content, "handle_edu"): + try: + await handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) + return + + # Check if we can route it somewhere else that isn't us + route_to = self._edu_type_to_instance.get(edu_type, "master") + if route_to != self._instance_name: + try: + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) + return + + # Oh well, let's just log and move on. + logger.warning("No handler registered for EDU type %s", edu_type) + + async def on_query(self, query_type: str, args: dict): + handler = self.query_handlers.get(query_type) + if handler: + return await handler(args) + + # Check if we can route it somewhere else that isn't us + if self._instance_name == "master": + return await self._get_query_client(query_type=query_type, args=args) + + # Uh oh, no handler! Let's raise an exception so the request returns an + # error. + logger.warning("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e111aa9026..55ab6f7f9d9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -439,13 +439,13 @@ async def on_GET(self, origin, content, query, event_id): class FederationStateV1Servlet(BaseFederationServlet): - PATH = "/state/(?P[^/]*)/?" + PATH = "/state/(?P[^/]*)/?" - # This is when someone asks for all data for a given context. - async def on_GET(self, origin, content, query, context): - return await self.handler.on_context_state_request( + # This is when someone asks for all data for a given room. + async def on_GET(self, origin, content, query, room_id): + return await self.handler.on_room_state_request( origin, - context, + room_id, parse_string_from_args(query, "event_id", None, required=False), ) @@ -462,16 +462,16 @@ async def on_GET(self, origin, content, query, room_id): class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/(?P[^/]*)/?" + PATH = "/backfill/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, context): + async def on_GET(self, origin, content, query, room_id): versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: return 400, {"error": "Did not include limit param"} - return await self.handler.on_backfill_request(origin, context, versions, limit) + return await self.handler.on_backfill_request(origin, room_id, versions, limit) class FederationQueryServlet(BaseFederationServlet): @@ -486,9 +486,9 @@ async def on_GET(self, origin, content, query, query_type): class FederationMakeJoinServlet(BaseFederationServlet): - PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" + PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, _content, query, context, user_id): + async def on_GET(self, origin, _content, query, room_id, user_id): """ Args: origin (unicode): The authenticated server_name of the calling server @@ -510,16 +510,16 @@ async def on_GET(self, origin, _content, query, context, user_id): supported_versions = ["1"] content = await self.handler.on_make_join_request( - origin, context, user_id, supported_versions=supported_versions + origin, room_id, user_id, supported_versions=supported_versions ) return 200, content class FederationMakeLeaveServlet(BaseFederationServlet): - PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" + PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, context, user_id): - content = await self.handler.on_make_leave_request(origin, context, user_id) + async def on_GET(self, origin, content, query, room_id, user_id): + content = await self.handler.on_make_leave_request(origin, room_id, user_id) return 200, content @@ -527,7 +527,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, (200, content) @@ -537,43 +537,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, content class FederationEventAuthServlet(BaseFederationServlet): - PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" + PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, context, event_id): - return await self.handler.on_event_auth(origin, context, event_id) + async def on_GET(self, origin, content, query, room_id, event_id): + return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, (200, content) class FederationV2SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, content class FederationV1InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, room_id, event_id): # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing @@ -588,12 +588,12 @@ async def on_PUT(self, origin, content, query, context, event_id): class FederationV2InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content room_version = content["room_version"] @@ -615,9 +615,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id): - content = await self.handler.on_exchange_third_party_invite_request( - room_id, content - ) + content = await self.handler.on_exchange_third_party_invite_request(content) return 200, content diff --git a/synapse/federation/transport/server.py.orig b/synapse/federation/transport/server.py.orig new file mode 100644 index 00000000000..5e111aa9026 --- /dev/null +++ b/synapse/federation/transport/server.py.orig @@ -0,0 +1,1530 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import re +from typing import Optional, Tuple, Type + +import synapse +from synapse.api.errors import Codes, FederationDeniedError, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.api.urls import ( + FEDERATION_UNSTABLE_PREFIX, + FEDERATION_V1_PREFIX, + FEDERATION_V2_PREFIX, +) +from synapse.http.endpoint import parse_and_validate_server_name +from synapse.http.server import JsonResource +from synapse.http.servlet import ( + parse_boolean_from_args, + parse_integer_from_args, + parse_json_object_from_request, + parse_string_from_args, +) +from synapse.logging.context import run_in_background +from synapse.logging.opentracing import ( + start_active_span, + start_active_span_from_request, + tags, + whitelisted_homeserver, +) +from synapse.server import HomeServer +from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger(__name__) + + +class TransportLayerServer(JsonResource): + """Handles incoming federation HTTP requests""" + + def __init__(self, hs, servlet_groups=None): + """Initialize the TransportLayerServer + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs (synapse.server.HomeServer): homeserver + servlet_groups (list[str], optional): List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ + self.hs = hs + self.clock = hs.get_clock() + self.servlet_groups = servlet_groups + + super(TransportLayerServer, self).__init__(hs, canonical_json=False) + + self.authenticator = Authenticator(hs) + self.ratelimiter = FederationRateLimiter( + self.clock, config=hs.config.rc_federation + ) + + self.register_servlets() + + def register_servlets(self): + register_servlets( + self.hs, + resource=self, + ratelimiter=self.ratelimiter, + authenticator=self.authenticator, + servlet_groups=self.servlet_groups, + ) + + +class AuthenticationError(SynapseError): + """There was a problem authenticating the request""" + + pass + + +class NoAuthenticationError(AuthenticationError): + """The request had no authentication information""" + + pass + + +class Authenticator(object): + def __init__(self, hs: HomeServer): + self._clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifier = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() + + # A method just so we can pass 'self' as the authenticator to the Servlets + async def authenticate_request(self, request, content): + now = self._clock.time_msec() + json_request = { + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), + "destination": self.server_name, + "signatures": {}, + } + + if content is not None: + json_request["content"] = content + + origin = None + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if not auth_headers: + raise NoAuthenticationError( + 401, "Missing Authorization headers", Codes.UNAUTHORIZED + ) + + for auth in auth_headers: + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) + json_request["origin"] = origin + json_request["signatures"].setdefault(origin, {})[key] = sig + + if ( + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist + ): + raise FederationDeniedError(origin) + + if not json_request["signatures"]: + raise NoAuthenticationError( + 401, "Missing Authorization headers", Codes.UNAUTHORIZED + ) + + await self.keyring.verify_json_for_server( + origin, json_request, now, "Incoming request" + ) + + logger.debug("Request from %s", origin) + request.authenticated_entity = origin + + # If we get a valid signed request from the other side, its probably + # alive + retry_timings = await self.store.get_destination_retry_timings(origin) + if retry_timings and retry_timings["retry_last_ts"]: + run_in_background(self._reset_retry_timings, origin) + + return origin + + async def _reset_retry_timings(self, origin): + try: + logger.info("Marking origin %r as up", origin) + await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifier.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + + except Exception: + logger.exception("Error resetting retry timings on %s", origin) + + +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode("utf-8") + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith('"'): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + + # ensure that the origin is a valid server name + parse_and_validate_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warning( + "Error parsing auth header '%s': %s", + header_bytes.decode("ascii", "replace"), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED + ) + + +class BaseFederationServlet(object): + """Abstract base class for federation servlet classes. + + The servlet object should have a PATH attribute which takes the form of a regexp to + match against the request path (excluding the /federation/v1 prefix). + + The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match + the appropriate HTTP method. These methods must be *asynchronous* and have the + signature: + + on_(self, origin, content, query, **kwargs) + + With arguments: + + origin (unicode|None): The authenticated server_name of the calling server, + unless REQUIRE_AUTH is set to False and authentication failed. + + content (unicode|None): decoded json body of the request. None if the + request was a GET. + + query (dict[bytes, list[bytes]]): Query params from the request. url-decoded + (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded + yet. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Optional[Tuple[int, object]]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + + Raises: + SynapseError: to return an error code + + Exception: other exceptions will be caught, logged, and a 500 will be + returned. + """ + + PATH = "" # Overridden in subclasses, the regex to match against the path. + + REQUIRE_AUTH = True + + PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + + def __init__(self, handler, authenticator, ratelimiter, server_name): + self.handler = handler + self.authenticator = authenticator + self.ratelimiter = ratelimiter + + def _wrap(self, func): + authenticator = self.authenticator + ratelimiter = self.ratelimiter + + @functools.wraps(func) + async def new_func(request, *args, **kwargs): + """A callback which can be passed to HttpServer.RegisterPaths + + Args: + request (twisted.web.http.Request): + *args: unused? + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Tuple[int, object]|None: (response code, response object) as returned by + the callback method. None if the request has already been handled. + """ + content = None + if request.method in [b"PUT", b"POST"]: + # TODO: Handle other method types? other content types? + content = parse_json_object_from_request(request) + + try: + origin = await authenticator.authenticate_request(request, content) + except NoAuthenticationError: + origin = None + if self.REQUIRE_AUTH: + logger.warning( + "authenticate_request failed: missing authentication" + ) + raise + except Exception as e: + logger.warning("authenticate_request failed: %s", e) + raise + + request_tags = { + "request_id": request.get_request_id(), + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + tags.HTTP_METHOD: request.get_method(), + tags.HTTP_URL: request.get_redacted_uri(), + tags.PEER_HOST_IPV6: request.getClientIP(), + "authenticated_entity": origin, + "servlet_name": request.request_metrics.name, + } + + # Only accept the span context if the origin is authenticated + # and whitelisted + if origin and whitelisted_homeserver(origin): + scope = start_active_span_from_request( + request, "incoming-federation-request", tags=request_tags + ) + else: + scope = start_active_span( + "incoming-federation-request", tags=request_tags + ) + + with scope: + if origin: + with ratelimiter.ratelimit(origin) as d: + await d + if request._disconnected: + logger.warning( + "client disconnected before we started processing " + "request" + ) + return -1, None + response = await func( + origin, content, request.args, *args, **kwargs + ) + else: + response = await func( + origin, content, request.args, *args, **kwargs + ) + + return response + + return new_func + + def register(self, server): + pattern = re.compile("^" + self.PREFIX + self.PATH + "$") + + for method in ("GET", "PUT", "POST"): + code = getattr(self, "on_%s" % (method), None) + if code is None: + continue + + server.register_paths( + method, (pattern,), self._wrap(code), self.__class__.__name__, + ) + + +class FederationSendServlet(BaseFederationServlet): + PATH = "/send/(?P[^/]*)/?" + + def __init__(self, handler, server_name, **kwargs): + super(FederationSendServlet, self).__init__( + handler, server_name=server_name, **kwargs + ) + self.server_name = server_name + + # This is when someone is trying to send us a bunch of data. + async def on_PUT(self, origin, content, query, transaction_id): + """ Called on PUT /send// + + Args: + request (twisted.web.http.Request): The HTTP request. + transaction_id (str): The transaction_id associated with this + request. This is *not* None. + + Returns: + Tuple of `(code, response)`, where + `response` is a python dict to be converted into JSON that is + used as the response body. + """ + # Parse the request + try: + transaction_data = content + + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) + + logger.info( + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", + transaction_id, + origin, + len(transaction_data.get("pdus", [])), + len(transaction_data.get("edus", [])), + ) + + # We should ideally be getting this from the security layer. + # origin = body["origin"] + + # Add some extra data to the transaction dict that isn't included + # in the request body. + transaction_data.update( + transaction_id=transaction_id, destination=self.server_name + ) + + except Exception as e: + logger.exception(e) + return 400, {"error": "Invalid transaction"} + + try: + code, response = await self.handler.on_incoming_transaction( + origin, transaction_data + ) + except Exception: + logger.exception("on_incoming_transaction failed") + raise + + return code, response + + +class FederationEventServlet(BaseFederationServlet): + PATH = "/event/(?P[^/]*)/?" + + # This is when someone asks for a data item for a given server data_id pair. + async def on_GET(self, origin, content, query, event_id): + return await self.handler.on_pdu_request(origin, event_id) + + +class FederationStateV1Servlet(BaseFederationServlet): + PATH = "/state/(?P[^/]*)/?" + + # This is when someone asks for all data for a given context. + async def on_GET(self, origin, content, query, context): + return await self.handler.on_context_state_request( + origin, + context, + parse_string_from_args(query, "event_id", None, required=False), + ) + + +class FederationStateIdsServlet(BaseFederationServlet): + PATH = "/state_ids/(?P[^/]*)/?" + + async def on_GET(self, origin, content, query, room_id): + return await self.handler.on_state_ids_request( + origin, + room_id, + parse_string_from_args(query, "event_id", None, required=True), + ) + + +class FederationBackfillServlet(BaseFederationServlet): + PATH = "/backfill/(?P[^/]*)/?" + + async def on_GET(self, origin, content, query, context): + versions = [x.decode("ascii") for x in query[b"v"]] + limit = parse_integer_from_args(query, "limit", None) + + if not limit: + return 400, {"error": "Did not include limit param"} + + return await self.handler.on_backfill_request(origin, context, versions, limit) + + +class FederationQueryServlet(BaseFederationServlet): + PATH = "/query/(?P[^/]*)" + + # This is when we receive a server-server Query + async def on_GET(self, origin, content, query, query_type): + return await self.handler.on_query_request( + query_type, + {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, + ) + + +class FederationMakeJoinServlet(BaseFederationServlet): + PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" + + async def on_GET(self, origin, _content, query, context, user_id): + """ + Args: + origin (unicode): The authenticated server_name of the calling server + + _content (None): (GETs don't have bodies) + + query (dict[bytes, list[bytes]]): Query params from the request. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Tuple[int, object]: (response code, response object) + """ + versions = query.get(b"ver") + if versions is not None: + supported_versions = [v.decode("utf-8") for v in versions] + else: + supported_versions = ["1"] + + content = await self.handler.on_make_join_request( + origin, context, user_id, supported_versions=supported_versions + ) + return 200, content + + +class FederationMakeLeaveServlet(BaseFederationServlet): + PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" + + async def on_GET(self, origin, content, query, context, user_id): + content = await self.handler.on_make_leave_request(origin, context, user_id) + return 200, content + + +class FederationV1SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, content + + +class FederationEventAuthServlet(BaseFederationServlet): + PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" + + async def on_GET(self, origin, content, query, context, event_id): + return await self.handler.on_event_auth(origin, context, event_id) + + +class FederationV1SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, content + + +class FederationV1InviteServlet(BaseFederationServlet): + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT(self, origin, content, query, context, event_id): + # We don't get a room version, so we have to assume its EITHER v1 or + # v2. This is "fine" as the only difference between V1 and V2 is the + # state resolution algorithm, and we don't use that for processing + # invites + content = await self.handler.on_invite_request( + origin, content, room_version_id=RoomVersions.V1.identifier + ) + + # V1 federation API is defined to return a content of `[200, {...}]` + # due to a historical bug. + return 200, (200, content) + + +class FederationV2InviteServlet(BaseFederationServlet): + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + + room_version = content["room_version"] + event = content["event"] + invite_room_state = content["invite_room_state"] + + # Synapse expects invite_room_state to be in unsigned, as it is in v1 + # API + + event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state + + content = await self.handler.on_invite_request( + origin, event, room_version_id=room_version + ) + return 200, content + + +class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): + PATH = "/exchange_third_party_invite/(?P[^/]*)" + + async def on_PUT(self, origin, content, query, room_id): + content = await self.handler.on_exchange_third_party_invite_request( + room_id, content + ) + return 200, content + + +class FederationClientKeysQueryServlet(BaseFederationServlet): + PATH = "/user/keys/query" + + async def on_POST(self, origin, content, query): + return await self.handler.on_query_client_keys(origin, content) + + +class FederationUserDevicesQueryServlet(BaseFederationServlet): + PATH = "/user/devices/(?P[^/]*)" + + async def on_GET(self, origin, content, query, user_id): + return await self.handler.on_query_user_devices(origin, user_id) + + +class FederationClientKeysClaimServlet(BaseFederationServlet): + PATH = "/user/keys/claim" + + async def on_POST(self, origin, content, query): + response = await self.handler.on_claim_client_keys(origin, content) + return 200, response + + +class FederationGetMissingEventsServlet(BaseFederationServlet): + # TODO(paul): Why does this path alone end with "/?" optional? + PATH = "/get_missing_events/(?P[^/]*)/?" + + async def on_POST(self, origin, content, query, room_id): + limit = int(content.get("limit", 10)) + earliest_events = content.get("earliest_events", []) + latest_events = content.get("latest_events", []) + + content = await self.handler.on_get_missing_events( + origin, + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + limit=limit, + ) + + return 200, content + + +class On3pidBindServlet(BaseFederationServlet): + PATH = "/3pid/onbind" + + REQUIRE_AUTH = False + + async def on_POST(self, origin, content, query): + if "invites" in content: + last_exception = None + for invite in content["invites"]: + try: + if "signed" not in invite or "token" not in invite["signed"]: + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) + logger.info(message) + raise SynapseError(400, message) + await self.handler.exchange_third_party_invite( + invite["sender"], + invite["mxid"], + invite["room_id"], + invite["signed"], + ) + except Exception as e: + last_exception = e + if last_exception: + raise last_exception + return 200, {} + + +class OpenIdUserInfo(BaseFederationServlet): + """ + Exchange a bearer token for information about a user. + + The response format should be compatible with: + http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse + + GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "@userpart:example.org", + } + """ + + PATH = "/openid/userinfo" + + REQUIRE_AUTH = False + + async def on_GET(self, origin, content, query): + token = query.get(b"access_token", [None])[0] + if token is None: + return ( + 401, + {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, + ) + + user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) + + if user_id is None: + return ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + + return 200, {"sub": user_id} + + +class PublicRoomList(BaseFederationServlet): + """ + Fetch the public room list for this server. + + This API returns information in the same format as /publicRooms on the + client API, but will only ever include local public rooms and hence is + intended for consumption by other homeservers. + + GET /publicRooms HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "chunk": [ + { + "aliases": [ + "#test:localhost" + ], + "guest_can_join": false, + "name": "test room", + "num_joined_members": 3, + "room_id": "!whkydVegtvatLfXmPN:localhost", + "world_readable": false + } + ], + "end": "END", + "start": "START" + } + """ + + PATH = "/publicRooms" + + def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): + super(PublicRoomList, self).__init__( + handler, authenticator, ratelimiter, server_name + ) + self.allow_access = allow_access + + async def on_GET(self, origin, content, query): + if not self.allow_access: + raise FederationDeniedError(origin) + + limit = parse_integer_from_args(query, "limit", 0) + since_token = parse_string_from_args(query, "since", None) + include_all_networks = parse_boolean_from_args( + query, "include_all_networks", False + ) + third_party_instance_id = parse_string_from_args( + query, "third_party_instance_id", None + ) + + if include_all_networks: + network_tuple = None + elif third_party_instance_id: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + else: + network_tuple = ThirdPartyInstanceID(None, None) + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + data = await self.handler.get_local_public_room_list( + limit, since_token, network_tuple=network_tuple, from_federation=True + ) + return 200, data + + async def on_POST(self, origin, content, query): + # This implements MSC2197 (Search Filtering over Federation) + if not self.allow_access: + raise FederationDeniedError(origin) + + limit = int(content.get("limit", 100)) # type: Optional[int] + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + include_all_networks = content.get("include_all_networks", False) + third_party_instance_id = content.get("third_party_instance_id", None) + + if include_all_networks: + network_tuple = None + if third_party_instance_id is not None: + raise SynapseError( + 400, "Can't use include_all_networks with an explicit network" + ) + elif third_party_instance_id is None: + network_tuple = ThirdPartyInstanceID(None, None) + else: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + + if search_filter is None: + logger.warning("Nonefilter") + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + data = await self.handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + network_tuple=network_tuple, + from_federation=True, + ) + + return 200, data + + +class FederationVersionServlet(BaseFederationServlet): + PATH = "/version" + + REQUIRE_AUTH = False + + async def on_GET(self, origin, content, query): + return ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + + +class FederationGroupsProfileServlet(BaseFederationServlet): + """Get/set the basic profile of a group on behalf of a user + """ + + PATH = "/groups/(?P[^/]*)/profile" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_profile(group_id, requester_user_id) + + return 200, new_content + + async def on_POST(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryServlet(BaseFederationServlet): + PATH = "/groups/(?P[^/]*)/summary" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_summary(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsRoomsServlet(BaseFederationServlet): + """Get the rooms in a group on behalf of a user + """ + + PATH = "/groups/(?P[^/]*)/rooms" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsAddRoomsServlet(BaseFederationServlet): + """Add/remove room from group + """ + + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" + + async def on_POST(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, new_content + + async def on_DELETE(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_room_from_group( + group_id, requester_user_id, room_id + ) + + return 200, new_content + + +class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): + """Update room config in group + """ + + PATH = ( + "/groups/(?P[^/]*)/room/(?P[^/]*)" + "/config/(?P[^/]*)" + ) + + async def on_POST(self, origin, content, query, group_id, room_id, config_key): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + result = await self.handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content + ) + + return 200, result + + +class FederationGroupsUsersServlet(BaseFederationServlet): + """Get the users in a group on behalf of a user + """ + + PATH = "/groups/(?P[^/]*)/users" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_users_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsInvitedUsersServlet(BaseFederationServlet): + """Get the users that have been invited to a group + """ + + PATH = "/groups/(?P[^/]*)/invited_users" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + return 200, new_content + + +class FederationGroupsInviteServlet(BaseFederationServlet): + """Ask a group server to invite someone to the group + """ + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" + + async def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.invite_to_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsAcceptInviteServlet(BaseFederationServlet): + """Accept an invitation from the group server + """ + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" + + async def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.accept_invite(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsJoinServlet(BaseFederationServlet): + """Attempt to join a group + """ + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" + + async def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.join_group(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveUserServlet(BaseFederationServlet): + """Leave or kick a user from the group + """ + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" + + async def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsLocalInviteServlet(BaseFederationServlet): + """A group server has invited a local user + """ + + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" + + async def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + new_content = await self.handler.on_invite(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): + """A group server has removed a local user + """ + + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" + + async def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.user_removed_from_group( + group_id, user_id, content + ) + + return 200, new_content + + +class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): + """A group or user's server renews their attestation + """ + + PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" + + async def on_POST(self, origin, content, query, group_id, user_id): + # We don't need to check auth here as we check the attestation signatures + + new_content = await self.handler.on_renew_attestation( + group_id, user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + + PATH = ( + "/groups/(?P[^/]*)/summary" + "(/categories/(?P[^/]+))?" + "/rooms/(?P[^/]*)" + ) + + async def on_POST(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.update_group_summary_room( + group_id, + requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + return 200, resp + + async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_summary_room( + group_id, requester_user_id, room_id=room_id, category_id=category_id + ) + + return 200, resp + + +class FederationGroupsCategoriesServlet(BaseFederationServlet): + """Get all categories for a group + """ + + PATH = "/groups/(?P[^/]*)/categories/?" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_categories(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsCategoryServlet(BaseFederationServlet): + """Add/remove/get a category in a group + """ + + PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" + + async def on_GET(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + async def on_POST(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content + ) + + return 200, resp + + async def on_DELETE(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + +class FederationGroupsRolesServlet(BaseFederationServlet): + """Get roles in a group + """ + + PATH = "/groups/(?P[^/]*)/roles/?" + + async def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_roles(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsRoleServlet(BaseFederationServlet): + """Add/remove/get a role in a group + """ + + PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" + + async def on_GET(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) + + return 200, resp + + async def on_POST(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.update_group_role( + group_id, requester_user_id, role_id, content + ) + + return 200, resp + + async def on_DELETE(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_role( + group_id, requester_user_id, role_id + ) + + return 200, resp + + +class FederationGroupsSummaryUsersServlet(BaseFederationServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + + PATH = ( + "/groups/(?P[^/]*)/summary" + "(/roles/(?P[^/]+))?" + "/users/(?P[^/]*)" + ) + + async def on_POST(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.update_group_summary_user( + group_id, + requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + return 200, resp + + async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): + """Get roles in a group + """ + + PATH = "/get_groups_publicised" + + async def on_POST(self, origin, content, query): + resp = await self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False + ) + + return 200, resp + + +class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): + """Sets whether a group is joinable without an invite or knock + """ + + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" + + async def on_PUT(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, new_content + + +class RoomComplexityServlet(BaseFederationServlet): + """ + Indicates to other servers how complex (and therefore likely + resource-intensive) a public room this server knows about is. + """ + + PATH = "/rooms/(?P[^/]*)/complexity" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + async def on_GET(self, origin, content, query, room_id): + + store = self.handler.hs.get_datastore() + + is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) + + if not is_public: + raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) + + complexity = await store.get_room_complexity(room_id) + return 200, complexity + + +FEDERATION_SERVLET_CLASSES = ( + FederationSendServlet, + FederationEventServlet, + FederationStateV1Servlet, + FederationStateIdsServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationMakeLeaveServlet, + FederationEventServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, + FederationGetMissingEventsServlet, + FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationUserDevicesQueryServlet, + FederationClientKeysClaimServlet, + FederationThirdPartyInviteExchangeServlet, + On3pidBindServlet, + FederationVersionServlet, + RoomComplexityServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] + +OPENID_SERVLET_CLASSES = ( + OpenIdUserInfo, +) # type: Tuple[Type[BaseFederationServlet], ...] + +ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] + +GROUP_SERVER_SERVLET_CLASSES = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] + + +GROUP_LOCAL_SERVLET_CLASSES = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] + + +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] + +DEFAULT_SERVLET_GROUPS = ( + "federation", + "room_list", + "group_server", + "group_local", + "group_attestation", + "openid", +) + + +def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None): + """Initialize and register servlet classes. + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs (synapse.server.HomeServer): homeserver + resource (TransportLayerServer): resource class to register to + authenticator (Authenticator): authenticator to use + ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use + servlet_groups (list[str], optional): List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ + if not servlet_groups: + servlet_groups = DEFAULT_SERVLET_GROUPS + + if "federation" in servlet_groups: + for servletclass in FEDERATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_federation_server(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "openid" in servlet_groups: + for servletclass in OPENID_SERVLET_CLASSES: + servletclass( + handler=hs.get_federation_server(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "room_list" in servlet_groups: + for servletclass in ROOM_LIST_CLASSES: + servletclass( + handler=hs.get_room_list_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + allow_access=hs.config.allow_public_rooms_over_federation, + ).register(resource) + + if "group_server" in servlet_groups: + for servletclass in GROUP_SERVER_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_server_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "group_local" in servlet_groups: + for servletclass in GROUP_LOCAL_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_local_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "group_attestation" in servlet_groups: + for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_attestation_renewer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb78..bbfccba3837 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -55,6 +55,7 @@ from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.handlers._base import BaseHandler +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -2694,7 +2695,7 @@ async def exchange_third_party_invite( ) async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: JsonDict + self, event_dict: JsonDict ) -> None: """Handle an exchange_third_party_invite request from a remote server @@ -2702,12 +2703,11 @@ async def on_exchange_third_party_invite_request( into a normal m.room.member invite. Args: - room_id: The ID of the room. - - event_dict (dict[str, Any]): Dictionary containing the event body. + event_dict: Dictionary containing the event body. """ - room_version = await self.store.get_room_version_id(room_id) + assert_params_in_dict(event_dict, ["room_id"]) + room_version = await self.store.get_room_version_id(event_dict["room_id"]) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. diff --git a/synapse/handlers/federation.py.orig b/synapse/handlers/federation.py.orig new file mode 100644 index 00000000000..593932adb78 --- /dev/null +++ b/synapse/handlers/federation.py.orig @@ -0,0 +1,3017 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains handlers for federation events.""" + +import itertools +import logging +from collections.abc import Container +from http import HTTPStatus +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import attr +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import ( + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + FederationDeniedError, + FederationError, + HttpResponseException, + NotFoundError, + RequestSendFailed, + SynapseError, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions +from synapse.crypto.event_signing import compute_event_signature +from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + preserve_fn, + run_in_background, +) +from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.federation import ( + ReplicationCleanRoomRestServlet, + ReplicationFederationSendEventsRestServlet, + ReplicationStoreRoomOnInviteRestServlet, +) +from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet +from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.distributor import user_joined_room +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr +from synapse.visibility import filter_events_for_server + +logger = logging.getLogger(__name__) + + +@attr.s +class _NewEventInfo: + """Holds information about a received event, ready for passing to _handle_new_events + + Attributes: + event: the received event + + state: the state at that event + + auth_events: the auth_event map for that event + """ + + event = attr.ib(type=EventBase) + state = attr.ib(type=Optional[Sequence[EventBase]], default=None) + auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + + +class FederationHandler(BaseHandler): + """Handles events that originated from federation. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the homeserver (including auth and state conflict resoultion) + b) converting events that were produced by local clients that may need + to be sent to remote homeservers. + c) doing the necessary dances to invite remote users and join remote + rooms. + """ + + def __init__(self, hs): + super(FederationHandler, self).__init__(hs) + + self.hs = hs + + self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state + self.federation_client = hs.get_federation_client() + self.state_handler = hs.get_state_handler() + self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.action_generator = hs.get_action_generator() + self.is_mine_id = hs.is_mine_id + self.pusher_pool = hs.get_pusherpool() + self.spam_checker = hs.get_spam_checker() + self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() + self._server_notices_mxid = hs.config.server_notices_mxid + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self._instance_name = hs.get_instance_name() + self._replication = hs.get_replication_data_handler() + + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) + self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( + hs + ) + self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( + hs + ) + + if hs.config.worker_app: + self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) + self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + hs + ) + else: + self._device_list_updater = hs.get_device_handler().device_list_updater + self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite + + # When joining a room we need to queue any events for that room up + self.room_queues = {} + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + + self.third_party_event_rules = hs.get_third_party_event_rules() + + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: + """ Process a PDU received via a federation /send/ transaction, or + via backfill of missing prev_events + + Args: + origin (str): server which initiated the /send/ transaction. Will + be used to fetch missing events or state. + pdu (FrozenEvent): received PDU + sent_to_us_directly (bool): True if this event was pushed to us; False if + we pulled it as the result of a missing prev_event. + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + logger.info("handling received PDU: %s", pdu) + + # We reprocess pdus when we have seen them only as outliers + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + + already_seen = existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() + ) + if already_seen: + logger.debug("[%s %s]: Already seen pdu", room_id, event_id) + return + + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. + try: + self._sanity_check_event(pdu) + except SynapseError as err: + logger.warning( + "[%s %s] Received event failed sanity checks", room_id, event_id + ) + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) + + # If we are currently in the process of joining this room, then we + # queue up events for later processing. + if room_id in self.room_queues: + logger.info( + "[%s %s] Queuing PDU from %s for now: join in progress", + room_id, + event_id, + origin, + ) + self.room_queues[room_id].append((pdu, origin)) + return + + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + # + # Note that if we were never in the room then we would have already + # dropped the event, since we wouldn't know the room version. + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + if not is_in_room: + logger.info( + "[%s %s] Ignoring PDU from %s as we're not in the room", + room_id, + event_id, + origin, + ) + return None + + state = None + + # Get missing pdus if necessary. + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + + logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) + + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + + if min_depth is not None and pdu.depth < min_depth: + # This is so that we don't notify the user about this + # message, to work around the fact that some events will + # reference really really old events we really don't want to + # send to the clients. + pdu.internal_metadata.outlier = True + elif min_depth is not None and pdu.depth > min_depth: + missing_prevs = prevs - seen + if sent_to_us_directly and missing_prevs: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "[%s %s] Acquired room lock to fetch %d missing prev_events", + room_id, + event_id, + len(missing_prevs), + ) + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + + if not prevs - seen: + logger.info( + "[%s %s] Found all missing prev_events", + room_id, + event_id, + ) + + if prevs - seen: + # We've still not been able to get all of the prev_events for this event. + # + # In this case, we need to fall back to asking another server in the + # federation for the state at this event. That's ok provided we then + # resolve the state against other bits of the DAG before using it (which + # will ensure that you can't just take over a room by sending an event, + # withholding its prev_events, and declaring yourself to be an admin in + # the subsequent state request). + # + # Now, if we're pulling this event as a missing prev_event, then clearly + # this event is not going to become the only forward-extremity and we are + # guaranteed to resolve its state against our existing forward + # extremities, so that should be fine. + # + # On the other hand, if this event was pushed to us, it is possible for + # it to become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very bad. + # Further, if the event was pushed to us, there is no excuse for us not to + # have all the prev_events. We therefore reject any such events. + # + # XXX this really feels like it could/should be merged with the above, + # but there is an interaction with min_depth that I'm not really + # following. + + if sent_to_us_directly: + logger.warning( + "[%s %s] Rejecting: failed to fetch %d prev events: %s", + room_id, + event_id, + len(prevs - seen), + shortstr(prevs - seen), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: pdu} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps = list(ours.values()) # type: List[StateMap[str]] + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in prevs - seen: + logger.info( + "Requesting state at missing prev_event %s", event_id, + ) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + (remote_state, _,) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await resolve_events_with_store( + self.clock, + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "[%s %s] Error attempting to resolve state at missing " + "prev_events", + room_id, + event_id, + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + + await self._process_received_pdu(origin, pdu, state=state) + + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + """ + Args: + origin (str): Origin of the pdu. Will be called to get the missing events + pdu: received pdu + prevs (set(str)): List of event ids which we are missing + min_depth (int): Minimum depth of events to return. + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + seen = await self.store.have_events_in_timeline(prevs) + + if not prevs - seen: + return + + latest = await self.store.get_latest_event_ids_in_room(room_id) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(latest) + latest |= seen + + logger.info( + "[%s %s]: Requesting missing events between %s and %s", + room_id, + event_id, + shortstr(latest), + event_id, + ) + + # XXX: we set timeout to 10s to help workaround + # https://github.com/matrix-org/synapse/issues/1733. + # The reason is to avoid holding the linearizer lock + # whilst processing inbound /send transactions, causing + # FDs to stack up and block other inbound transactions + # which empirically can currently take up to 30 minutes. + # + # N.B. this explicitly disables retry attempts. + # + # N.B. this also increases our chances of falling back to + # fetching fresh state for the room if the missing event + # can't be found, which slightly reduces our security. + # it may also increase our DAG extremity count for the room, + # causing additional state resolution? See #1760. + # However, fetching state doesn't hold the linearizer lock + # apparently. + # + # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out agressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the agressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timout to 60s and see what happens. + + try: + missing_events = await self.federation_client.get_missing_events( + origin, + room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + timeout=60000, + ) + except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: + # We failed to get the missing events, but since we need to handle + # the case of `get_missing_events` not returning the necessary + # events anyway, it is safe to simply log the error and continue. + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) + return + + logger.info( + "[%s %s]: Got %d prev_events: %s", + room_id, + event_id, + len(missing_events), + shortstr(missing_events), + ) + + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + + for ev in missing_events: + logger.info( + "[%s %s] Handling received prev_event %s", + room_id, + event_id, + ev.event_id, + ) + with nested_logging_context(ev.event_id): + try: + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + except FederationError as e: + if e.code == 403: + logger.warning( + "[%s %s] Received prev_event %s failed history check.", + room_id, + event_id, + ev.event_id, + ) + else: + raise + + async def _get_state_for_room( + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. + + Returns: + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + + event_map = await self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + auth_chain.sort(key=lambda e: e.depth) + + return remote_state, auth_chain + + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: + """Fetch events from a remote destination, checking if we already have them. + + Persists any events we don't already have as outliers. + + If we fail to fetch any of the events, a warning will be logged, and the event + will be omitted from the result. Likewise, any events which turn out not to + be in the given room. + + This function *does not* automatically get missing auth events of the + newly fetched events. Callers must include the full auth chain of + of the missing events in the `event_ids` argument, to ensure that any + missing auth events are correctly fetched. + + Returns: + map from event_id to event + """ + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if missing_events: + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + room_id, + ) + + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + return fetched_events + + async def _process_received_pdu( + self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + ): + """ Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event + """ + room_id = event.room_id + event_id = event.event_id + + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) + + try: + context = await self._handle_new_event(origin, event, state=state) + except AuthError as e: + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + + prev_state_ids = await context.get_prev_state_ids() + + prev_state_id = prev_state_ids.get((event.type, event.state_key)) + if prev_state_id: + prev_state = await self.store.get_event( + prev_state_id, allow_none=True + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: + user = UserID.from_string(event.state_key) + await self.user_joined_room(user, room_id) + + # For encrypted messages we check that we know about the sending device, + # if we don't then we mark the device cache for that user as stale. + if event.type == EventTypes.Encrypted: + device_id = event.content.get("device_id") + sender_key = event.content.get("sender_key") + + cached_devices = await self.store.get_cached_devices_for_user(event.sender) + + resync = False # Whether we should resync device lists. + + device = None + if device_id is not None: + device = cached_devices.get(device_id) + if device is None: + logger.info( + "Received event from remote device not in our cache: %s %s", + event.sender, + device_id, + ) + resync = True + + # We also check if the `sender_key` matches what we expect. + if sender_key is not None: + # Figure out what sender key we're expecting. If we know the + # device and recognize the algorithm then we can work out the + # exact key to expect. Otherwise check it matches any key we + # have for that device. + + current_keys = [] # type: Container[str] + + if device: + keys = device.get("keys", {}).get("keys", {}) + + if ( + event.content.get("algorithm") + == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 + ): + # For this algorithm we expect a curve25519 key. + key_name = "curve25519:%s" % (device_id,) + current_keys = [keys.get(key_name)] + else: + # We don't know understand the algorithm, so we just + # check it matches a key for the device. + current_keys = keys.values() + elif device_id: + # We don't have any keys for the device ID. + pass + else: + # The event didn't include a device ID, so we just look for + # keys across all devices. + current_keys = [ + key + for device in cached_devices + for key in device.get("keys", {}).get("keys", {}).values() + ] + + # We now check that the sender key matches (one of) the expected + # keys. + if sender_key not in current_keys: + logger.info( + "Received event from remote device with unexpected sender key: %s %s: %s", + event.sender, + device_id or "", + sender_key, + ) + resync = True + + if resync: + run_as_background_process( + "resync_device_due_to_pdu", self._resync_device, event.sender + ) + + async def _resync_device(self, sender: str) -> None: + """We have detected that the device list for the given user may be out + of sync, so we try and resync them. + """ + + try: + await self.store.mark_remote_user_device_cache_as_stale(sender) + + # Immediately attempt a resync in the background + if self.config.worker_app: + await self._user_device_resync(user_id=sender) + else: + await self._device_list_updater.user_device_resync(sender) + except Exception: + logger.exception("Failed to resync device for %s", sender) + + @log_function + async def backfill(self, dest, room_id, limit, extremities): + """ Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) + """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + + events = await self.federation_client.backfill( + dest, room_id, limit=limit, extremities=extremities + ) + + # ideally we'd sanity check the events here for excess prev_events etc, + # but it's hard to reject events at this point without completely + # breaking backfill in the same way that it is currently broken by + # events whose signature we cannot verify (#3121). + # + # So for now we accept the events anyway. #3124 tracks this. + # + # for ev in events: + # self._sanity_check_event(ev) + + # Don't bother processing events we already have. + seen_events = await self.store.have_events_in_timeline( + {e.event_id for e in events} + ) + + events = [e for e in events if e.event_id not in seen_events] + + if not events: + return [] + + event_map = {e.event_id: e for e in events} + + event_ids = {e.event_id for e in events} + + # build a list of events whose prev_events weren't in the batch. + # (XXX: this will include events whose prev_events we already have; that doesn't + # sound right?) + edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] + + logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) + + # For each edge get the current state. + + auth_events = {} + state_events = {} + events_to_state = {} + for e_id in edges: + state, auth = await self._get_state_for_room( + destination=dest, + room_id=room_id, + event_id=e_id, + include_event_in_state=False, + ) + auth_events.update({a.event_id: a for a in auth}) + auth_events.update({s.event_id: s for s in state}) + state_events.update({s.event_id: s for s in state}) + events_to_state[e_id] = state + + required_auth = { + a_id + for event in events + + list(state_events.values()) + + list(auth_events.values()) + for a_id in event.auth_event_ids() + } + auth_events.update( + {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} + ) + + ev_infos = [] + + # Step 1: persist the events in the chunk we fetched state for (i.e. + # the backwards extremities), with custom auth events and state + for e_id in events_to_state: + # For paranoia we ensure that these events are marked as + # non-outliers + ev = event_map[e_id] + assert not ev.internal_metadata.is_outlier() + + ev_infos.append( + _NewEventInfo( + event=ev, + state=events_to_state[e_id], + auth_events={ + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in ev.auth_event_ids() + if a_id in auth_events + }, + ) + ) + + await self._handle_new_events(dest, ev_infos, backfilled=True) + + # Step 2: Persist the rest of the events in the chunk one by one + events.sort(key=lambda e: e.depth) + + for event in events: + if event in events_to_state: + continue + + # For paranoia we ensure that these events are marked as + # non-outliers + assert not event.internal_metadata.is_outlier() + + # We store these one at a time since each event depends on the + # previous to work out the state. + # TODO: We can probably do something more clever here. + await self._handle_new_event(dest, event, backfilled=True) + + return events + + async def maybe_backfill(self, room_id, current_depth): + """Checks the database to see if we should backfill before paginating, + and if so do. + """ + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + + if not extremities: + logger.debug("Not backfilling as no extremeties found.") + return + + # We only want to paginate if we can actually see the events we'll get, + # as otherwise we'll just spend a lot of resources to get redacted + # events. + # + # We do this by filtering all the backwards extremities and seeing if + # any remain. Given we don't have the extremity events themselves, we + # need to actually check the events that reference them. + # + # *Note*: the spec wants us to keep backfilling until we reach the start + # of the room in case we are allowed to see some of the history. However + # in practice that causes more issues than its worth, as a) its + # relatively rare for there to be any visible history and b) even when + # there is its often sufficiently long ago that clients would stop + # attempting to paginate before backfill reached the visible history. + # + # TODO: If we do do a backfill then we should filter the backwards + # extremities to only include those that point to visible portions of + # history. + # + # TODO: Correctly handle the case where we are allowed to see the + # forward event but not the backward extremity, e.g. in the case of + # initial join of the server where we are allowed to see the join + # event but not anything before it. This would require looking at the + # state *before* the event, ignoring the special casing certain event + # types have. + + forward_events = await self.store.get_successor_events(list(extremities)) + + extremities_events = await self.store.get_events( + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, + ) + + # We set `check_history_visibility_only` as we might otherwise get false + # positives from users having been erased. + filtered_extremities = await filter_events_for_server( + self.storage, + self.server_name, + list(extremities_events.values()), + redact=False, + check_history_visibility_only=True, + ) + + if not filtered_extremities: + return False + + # Check if we reached a point where we should start backfilling. + sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) + max_depth = sorted_extremeties_tuple[0][1] + + # We don't want to specify too many extremities as it causes the backfill + # request URI to be too long. + extremities = dict(sorted_extremeties_tuple[:5]) + + if current_depth > max_depth: + logger.debug( + "Not backfilling as we don't need to. %d < %d", max_depth, current_depth + ) + return + + # Now we need to decide which hosts to hit first. + + # First we try hosts that are already in the room + # TODO: HEURISTIC ALERT. + + curr_state = await self.state_handler.get_current_state(room_id) + + def get_domains_from_state(state): + """Get joined domains from state + + Args: + state (dict[tuple, FrozenEvent]): State map from type/state + key to event. + + Returns: + list[tuple[str, int]]: Returns a list of servers with the + lowest depth of their joins. Sorted by lowest depth first. + """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains = {} # type: Dict[str, int] + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + curr_domains = get_domains_from_state(curr_state) + + likely_domains = [ + domain for domain, depth in curr_domains if domain != self.server_name + ] + + async def try_backfill(domains): + # TODO: Should we try multiple of these at a time? + for dom in domains: + try: + await self.backfill( + dom, room_id, limit=100, extremities=extremities + ) + # If this succeeded then we probably already have the + # appropriate stuff. + # TODO: We can probably do something more intelligent here. + return True + except SynapseError as e: + logger.info("Failed to backfill from %s because %s", dom, e) + continue + except HttpResponseException as e: + if 400 <= e.code < 500: + raise e.to_synapse_error() + + logger.info("Failed to backfill from %s because %s", dom, e) + continue + except CodeMessageException as e: + if 400 <= e.code < 500: + raise + + logger.info("Failed to backfill from %s because %s", dom, e) + continue + except NotRetryingDestination as e: + logger.info(str(e)) + continue + except RequestSendFailed as e: + logger.info("Falied to get backfill from %s because %s", dom, e) + continue + except FederationDeniedError as e: + logger.info(e) + continue + except Exception as e: + logger.exception("Failed to backfill from %s because %s", dom, e) + continue + + return False + + success = await try_backfill(likely_domains) + if success: + return True + + # Huh, well *those* domains didn't work out. Lets try some domains + # from the time. + + tried_domains = set(likely_domains) + tried_domains.add(self.server_name) + + event_ids = list(extremities.keys()) + + logger.debug("calling resolve_state_groups in _maybe_backfill") + resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) + states = await make_deferred_yieldable( + defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], consumeErrors=True + ) + ) + + # dict[str, dict[tuple, str]], a map from event_id to state map of + # event_ids. + states = dict(zip(event_ids, [s.state for s in states])) + + state_map = await self.store.get_events( + [e_id for ids in states.values() for e_id in ids.values()], + get_prev_content=False, + ) + states = { + key: { + k: state_map[e_id] + for k, e_id in state_dict.items() + if e_id in state_map + } + for key, state_dict in states.items() + } + + for e_id, _ in sorted_extremeties_tuple: + likely_domains = get_domains_from_state(states[e_id]) + + success = await try_backfill( + [dom for dom, _ in likely_domains if dom not in tried_domains] + ) + if success: + return True + + tried_domains.update(dom for dom, _ in likely_domains) + + return False + + async def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ): + """Fetch the given events from a server, and persist them as outliers. + + This function *does not* recursively get missing auth events of the + newly fetched events. Callers must include in the `events` argument + any missing events from the auth chain. + + Logs a warning if we can't find the given event. + """ + + room_version = await self.store.get_room_version(room_id) + + event_map = {} # type: Dict[str, EventBase] + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], event_id, room_version, outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", destination, event_id, + ) + return + + event_map[event.event_id] = event + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + await concurrently_execute(get_event, events, 5) + + # Make a map of auth events for each event. We do this after fetching + # all the events as some of the events' auth events will be in the list + # of requested events. + + auth_events = [ + aid + for event in event_map.values() + for aid in event.auth_event_ids() + if aid not in event_map + ] + persisted_events = await self.store.get_events( + auth_events, allow_rejected=True, + ) + + event_infos = [] + for event in event_map.values(): + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + event_infos.append(_NewEventInfo(event, None, auth)) + + await self._handle_new_events( + destination, event_infos, + ) + + def _sanity_check_event(self, ev): + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev (synapse.events.EventBase): event to be checked + + Returns: None + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_event_ids()) > 20: + logger.warning( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") + + if len(ev.auth_event_ids()) > 10: + logger.warning( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + + async def send_invite(self, target_host, event): + """ Sends the invite to the remote server for signing. + + Invites must be signed by the invitee's server before distribution. + """ + pdu = await self.federation_client.send_invite( + destination=target_host, + room_id=event.room_id, + event_id=event.event_id, + pdu=event, + ) + + return pdu + + async def on_event_auth(self, event_id: str) -> List[EventBase]: + event = await self.store.get_event(event_id) + auth = await self.store.get_auth_chain( + list(event.auth_event_ids()), include_given=True + ) + return list(auth) + + async def do_invite_join( + self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict + ) -> Tuple[str, int]: + """ Attempts to join the `joinee` to the room `room_id` via the + servers contained in `target_hosts`. + + This first triggers a /make_join/ request that returns a partial + event that we can fill out and sign. This is then sent to the + remote server via /send_join/ which responds with the state at that + event and the auth_chains. + + We suspend processing of any received events from this room until we + have finished processing the join. + + Args: + target_hosts: List of servers to attempt to join the room with. + + room_id: The ID of the room to join. + + joinee: The User ID of the joining user. + + content: The event content to use for the join event. + """ + # TODO: We should be able to call this on workers, but the upgrading of + # room stuff after join currently doesn't work on workers. + assert self.config.worker.worker_app is None + + logger.debug("Joining %s to %s", joinee, room_id) + + origin, event, room_version_obj = await self._make_and_verify_event( + target_hosts, + room_id, + joinee, + "join", + content, + params={"ver": KNOWN_ROOM_VERSIONS}, + ) + + # This shouldn't happen, because the RoomMemberHandler has a + # linearizer lock which only allows one operation per user per room + # at a time - so this is just paranoia. + assert room_id not in self.room_queues + + self.room_queues[room_id] = [] + + await self._clean_room_for_join(room_id) + + handled_events = set() + + try: + # Try the host we successfully got a response to /make_join/ + # request first. + host_list = list(target_hosts) + try: + host_list.remove(origin) + host_list.insert(0, origin) + except ValueError: + pass + + ret = await self.federation_client.send_join( + host_list, event, room_version_obj + ) + + origin = ret["origin"] + state = ret["state"] + auth_chain = ret["auth_chain"] + auth_chain.sort(key=lambda e: e.depth) + + handled_events.update([s.event_id for s in state]) + handled_events.update([a.event_id for a in auth_chain]) + handled_events.add(event.event_id) + + logger.debug("do_invite_join auth_chain: %s", auth_chain) + logger.debug("do_invite_join state: %s", state) + + logger.debug("do_invite_join event: %s", event) + + # if this is the first time we've joined this room, it's time to add + # a row to `rooms` with the correct room version. If there's already a + # row there, we should override it, since it may have been populated + # based on an invite request which lied about the room version. + # + # federation_client.send_join has already checked that the room + # version in the received create event is the same as room_version_obj, + # so we can rely on it now. + # + await self.store.upsert_room_on_join( + room_id=room_id, room_version=room_version_obj, + ) + + max_stream_id = await self._persist_auth_tree( + origin, auth_chain, state, event, room_version_obj + ) + + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.config.worker.writers.events, "events", max_stream_id + ) + + # Check whether this room is the result of an upgrade of a room we already know + # about. If so, migrate over user information + predecessor = await self.store.get_room_predecessor(room_id) + if not predecessor or not isinstance(predecessor.get("room_id"), str): + return event.event_id, max_stream_id + old_room_id = predecessor["room_id"] + logger.debug( + "Found predecessor for %s during remote join: %s", room_id, old_room_id + ) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.transfer_room_state_on_room_upgrade( + old_room_id, room_id + ) + + logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id + finally: + room_queue = self.room_queues[room_id] + del self.room_queues[room_id] + + # we don't need to wait for the queued events to be processed - + # it's just a best-effort thing at this point. We do want to do + # them roughly in order, though, otherwise we'll end up making + # lots of requests for missing prev_events which we do actually + # have. Hence we fire off the background task, but don't wait for it. + + run_in_background(self._handle_queued_pdus, room_queue) + + async def _handle_queued_pdus(self, room_queue): + """Process PDUs which got queued up while we were busy send_joining. + + Args: + room_queue (list[FrozenEvent, str]): list of PDUs to be processed + and the servers that sent them + """ + for p, origin in room_queue: + try: + logger.info( + "Processing queued PDU %s which was received " + "while we were joining %s", + p.event_id, + p.room_id, + ) + with nested_logging_context(p.event_id): + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) + except Exception as e: + logger.warning( + "Error handling queued PDU %s from %s: %s", p.event_id, origin, e + ) + + async def on_make_join_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: + """ We've received a /make_join/ request, so we create a partial + join event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: Room to create join event in + user_id: The user to create the join for + """ + if get_domain_from_id(user_id) != origin: + logger.info( + "Got /make_join request for user %r from different origin %s, ignoring", + user_id, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + # checking the room version will check that we've actually heard of the room + # (and return a 404 otherwise) + room_version = await self.store.get_room_version_id(room_id) + + # now check that we are *still* in the room + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + if not is_in_room: + logger.info( + "Got /make_join request for room %s we are no longer in", room_id, + ) + raise NotFoundError("Not an active room on this server") + + event_content = {"membership": Membership.JOIN} + + builder = self.event_builder_factory.new( + room_version, + { + "type": EventTypes.Member, + "content": event_content, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }, + ) + + try: + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + except AuthError as e: + logger.warning("Failed to create join to %s because %s", room_id, e) + raise e + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Creation of join %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + await self.auth.check_from_context( + room_version, event, context, do_sig_check=False + ) + + return event + + async def on_send_join_request(self, origin, pdu): + """ We have received a join event for a room. Fully process it and + respond with the current state and auth chains. + """ + event = pdu + + logger.debug( + "on_send_join_request from %s: Got event: %s, signatures: %s", + origin, + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_join request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + event.internal_metadata.outlier = False + # Send this event on behalf of the origin server. + # + # The reasons we have the destination server rather than the origin + # server send it are slightly mysterious: the origin server should have + # all the neccessary state once it gets the response to the send_join, + # so it could send the event itself if it wanted to. It may be that + # doing it this way reduces failure modes, or avoids certain attacks + # where a new server selectively tells a subset of the federation that + # it has joined. + # + # The fact is that, as of the current writing, Synapse doesn't send out + # the join event over federation after joining, and changing it now + # would introduce the danger of backwards-compatibility problems. + event.internal_metadata.send_on_behalf_of = origin + + context = await self._handle_new_event(origin, event) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of join %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + logger.debug( + "on_send_join_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + user = UserID.from_string(event.state_key) + await self.user_joined_room(user, event.room_id) + + prev_state_ids = await context.get_prev_state_ids() + + state_ids = list(prev_state_ids.values()) + auth_chain = await self.store.get_auth_chain(state_ids) + + state = await self.store.get_events(list(prev_state_ids.values())) + + return {"state": list(state.values()), "auth_chain": auth_chain} + + async def on_invite_request( + self, origin: str, event: EventBase, room_version: RoomVersion + ): + """ We've got an invite event. Process and persist it. Sign it. + + Respond with the now signed event. + """ + if event.state_key is None: + raise SynapseError(400, "The invite event did not have a state key") + + is_blocked = await self.store.is_room_blocked(event.room_id) + if is_blocked: + raise SynapseError(403, "This room has been blocked on this server") + + if self.hs.config.block_non_admin_invites: + raise SynapseError(403, "This server does not accept room invites") + + if not self.spam_checker.user_may_invite( + event.sender, event.state_key, event.room_id + ): + raise SynapseError( + 403, "This user is not permitted to send invites to this server/user" + ) + + membership = event.content.get("membership") + if event.type != EventTypes.Member or membership != Membership.INVITE: + raise SynapseError(400, "The event was not an m.room.member invite event") + + sender_domain = get_domain_from_id(event.sender) + if sender_domain != origin: + raise SynapseError( + 400, "The invite event was not from the server sending it" + ) + + if not self.is_mine_id(event.state_key): + raise SynapseError(400, "The invite event must be for this server") + + # block any attempts to invite the server notices mxid + if event.state_key == self._server_notices_mxid: + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + + # keep a record of the room version, if we don't yet know it. + # (this may get overwritten if we later get a different room version in a + # join dance). + await self._maybe_store_room_on_invite( + room_id=event.room_id, room_version=room_version + ) + + event.internal_metadata.outlier = True + event.internal_metadata.out_of_band_membership = True + + event.signatures.update( + compute_event_signature( + room_version, + event.get_pdu_json(), + self.hs.hostname, + self.hs.signing_key, + ) + ) + + context = await self.state_handler.compute_event_context(event) + await self.persist_events_and_notify([(event, context)]) + + return event + + async def do_remotely_reject_invite( + self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict + ) -> Tuple[EventBase, int]: + origin, event, room_version = await self._make_and_verify_event( + target_hosts, room_id, user_id, "leave", content=content + ) + # Mark as outlier as we don't have any state for this event; we're not + # even in the room. + event.internal_metadata.outlier = True + event.internal_metadata.out_of_band_membership = True + + # Try the host that we succesfully called /make_leave/ on first for + # the /send_leave/ request. + host_list = list(target_hosts) + try: + host_list.remove(origin) + host_list.insert(0, origin) + except ValueError: + pass + + await self.federation_client.send_leave(host_list, event) + + context = await self.state_handler.compute_event_context(event) + stream_id = await self.persist_events_and_notify([(event, context)]) + + return event, stream_id + + async def _make_and_verify_event( + self, + target_hosts: Iterable[str], + room_id: str, + user_id: str, + membership: str, + content: JsonDict = {}, + params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, + ) -> Tuple[str, EventBase, RoomVersion]: + ( + origin, + event, + room_version, + ) = await self.federation_client.make_membership_event( + target_hosts, room_id, user_id, membership, content, params=params + ) + + logger.debug("Got response to make_%s: %s", membership, event) + + # We should assert some things. + # FIXME: Do this in a nicer way + assert event.type == EventTypes.Member + assert event.user_id == user_id + assert event.state_key == user_id + assert event.room_id == room_id + return origin, event, room_version + + async def on_make_leave_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: + """ We've received a /make_leave/ request, so we create a partial + leave event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: Room to create leave event in + user_id: The user to create the leave for + """ + if get_domain_from_id(user_id) != origin: + logger.info( + "Got /make_leave request for user %r from different origin %s, ignoring", + user_id, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + room_version = await self.store.get_room_version_id(room_id) + builder = self.event_builder_factory.new( + room_version, + { + "type": EventTypes.Member, + "content": {"membership": Membership.LEAVE}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }, + ) + + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.warning("Creation of leave %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + await self.auth.check_from_context( + room_version, event, context, do_sig_check=False + ) + except AuthError as e: + logger.warning("Failed to create new leave %r because %s", event, e) + raise e + + return event + + async def on_send_leave_request(self, origin, pdu): + """ We have received a leave event for a room. Fully process it.""" + event = pdu + + logger.debug( + "on_send_leave_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_leave request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + event.internal_metadata.outlier = False + + context = await self._handle_new_event(origin, event) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of leave %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + logger.debug( + "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + return None + + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: + """Returns the state at the event. i.e. not including said event. + """ + + event = await self.store.get_event( + event_id, allow_none=False, check_room_id=room_id + ) + + state_groups = await self.state_store.get_state_groups(room_id, [event_id]) + + if state_groups: + _, state = list(state_groups.items()).pop() + results = {(e.type, e.state_key): e for e in state} + + if event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + prev_event = await self.store.get_event(prev_id) + results[(event.type, event.state_key)] = prev_event + else: + del results[(event.type, event.state_key)] + + res = list(results.values()) + return res + else: + return [] + + async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: + """Returns the state at the event. i.e. not including said event. + """ + event = await self.store.get_event( + event_id, allow_none=False, check_room_id=room_id + ) + + state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) + + if state_groups: + _, state = list(state_groups.items()).pop() + results = state + + if event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + results[(event.type, event.state_key)] = prev_id + else: + results.pop((event.type, event.state_key), None) + + return list(results.values()) + else: + return [] + + @log_function + async def on_backfill_request( + self, origin: str, room_id: str, pdu_list: List[str], limit: int + ) -> List[EventBase]: + in_room = await self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + # Synapse asks for 100 events per backfill request. Do not allow more. + limit = min(limit, 100) + + events = await self.store.get_backfill_events(room_id, pdu_list, limit) + + events = await filter_events_for_server(self.storage, origin, events) + + return events + + @log_function + async def get_persisted_pdu( + self, origin: str, event_id: str + ) -> Optional[EventBase]: + """Get an event from the database for the given server. + + Args: + origin: hostname of server which is requesting the event; we + will check that the server is allowed to see it. + event_id: id of the event being requested + + Returns: + None if we know nothing about the event; otherwise the (possibly-redacted) event. + + Raises: + AuthError if the server is not currently in the room + """ + event = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + + if event: + in_room = await self.auth.check_host_in_room(event.room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + events = await filter_events_for_server(self.storage, origin, [event]) + event = events[0] + return event + else: + return None + + def get_min_depth_for_context(self, context): + return self.store.get_min_depth(context) + + async def _handle_new_event( + self, origin, event, state=None, auth_events=None, backfilled=False + ): + context = await self._prep_event( + origin, event, state=state, auth_events=auth_events, backfilled=backfilled + ) + + try: + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + ): + await self.action_generator.handle_push_actions_for_event( + event, context + ) + + await self.persist_events_and_notify( + [(event, context)], backfilled=backfilled + ) + except Exception: + run_in_background( + self.store.remove_push_actions_from_staging, event.event_id + ) + raise + + return context + + async def _handle_new_events( + self, + origin: str, + event_infos: Iterable[_NewEventInfo], + backfilled: bool = False, + ) -> None: + """Creates the appropriate contexts and persists events. The events + should not depend on one another, e.g. this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + + Notifies about the events where appropriate. + """ + + async def prep(ev_info: _NewEventInfo): + event = ev_info.event + with nested_logging_context(suffix=event.event_id): + res = await self._prep_event( + origin, + event, + state=ev_info.state, + auth_events=ev_info.auth_events, + backfilled=backfilled, + ) + return res + + contexts = await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(prep, ev_info) for ev_info in event_infos], + consumeErrors=True, + ) + ) + + await self.persist_events_and_notify( + [ + (ev_info.event, context) + for ev_info, context in zip(event_infos, contexts) + ], + backfilled=backfilled, + ) + + async def _persist_auth_tree( + self, + origin: str, + auth_events: List[EventBase], + state: List[EventBase], + event: EventBase, + room_version: RoomVersion, + ) -> int: + """Checks the auth chain is valid (and passes auth checks) for the + state and event. Then persists the auth chain and state atomically. + Persists the event separately. Notifies about the persisted events + where appropriate. + + Will attempt to fetch missing auth events. + + Args: + origin: Where the events came from + auth_events + state + event + room_version: The room version we expect this room to have, and + will raise if it doesn't match the version in the create event. + """ + events_to_context = {} + for e in itertools.chain(auth_events, state): + e.internal_metadata.outlier = True + ctx = await self.state_handler.compute_event_context(e) + events_to_context[e.event_id] = ctx + + event_map = { + e.event_id: e for e in itertools.chain(auth_events, state, [event]) + } + + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise SynapseError(400, "No create event in state") + + room_version_id = create_event.content.get( + "room_version", RoomVersions.V1.identifier + ) + + if room_version.identifier != room_version_id: + raise SynapseError(400, "Room version mismatch") + + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id in e.auth_event_ids(): + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = await self.federation_client.get_pdu( + [origin], e_id, room_version=room_version, outlier=True, timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + + for e in itertools.chain(auth_events, state, [event]): + auth_for_e = { + (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] + for e_id in e.auth_event_ids() + if e_id in event_map + } + if create_event: + auth_for_e[(EventTypes.Create, "")] = create_event + + try: + event_auth.check(room_version, e, auth_events=auth_for_e) + except SynapseError as err: + # we may get SynapseErrors here as well as AuthErrors. For + # instance, there are a couple of (ancient) events in some + # rooms whose senders do not have the correct sigil; these + # cause SynapseErrors in auth.check. We don't want to give up + # the attempt to federate altogether in such cases. + + logger.warning("Rejecting %s because %s", e.event_id, err.msg) + + if e == event: + raise + events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + + await self.persist_events_and_notify( + [ + (e, events_to_context[e.event_id]) + for e in itertools.chain(auth_events, state) + ] + ) + + new_event_context = await self.state_handler.compute_event_context( + event, old_state=state + ) + + return await self.persist_events_and_notify([(event, new_event_context)]) + + async def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[StateMap[EventBase]], + backfilled: bool, + ) -> EventContext: + context = await self.state_handler.compute_event_context(event, old_state=state) + + if not auth_events: + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.auth.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. + if event.type == EventTypes.Member and not event.auth_event_ids(): + if len(event.prev_event_ids()) == 1 and event.depth < 5: + c = await self.store.get_event( + event.prev_event_ids()[0], allow_none=True + ) + if c and c.type == EventTypes.Create: + auth_events[(c.type, c.state_key)] = c + + context = await self.do_auth(origin, event, context, auth_events=auth_events) + + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + return context + + async def _check_for_soft_fail( + self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + ) -> None: + """Checks if we should soft fail the event; if so, marks the event as + such. + + Args: + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill + """ + # For new (non-backfilled and non-outlier) events we check if the event + # passes auth based on the current state. If it doesn't then we + # "soft-fail" the event. + if backfilled or event.internal_metadata.is_outlier(): + return + + extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids) + prev_event_ids = set(event.prev_event_ids()) + + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets = list(state_sets.values()) + state_sets.append(state) + current_state_ids = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids + ) + + logger.debug( + "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + ) + + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(event) + current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] + + current_auth_events = await self.store.get_events(current_state_ids) + current_auth_events = { + (e.type, e.state_key): e for e in current_auth_events.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning("Soft-failing %r because %s", event, e) + event.internal_metadata.soft_failed = True + + async def on_query_auth( + self, origin, event_id, room_id, remote_auth_chain, rejects, missing + ): + in_room = await self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + event = await self.store.get_event( + event_id, allow_none=False, check_room_id=room_id + ) + + # Just go through and process each event in `remote_auth_chain`. We + # don't want to fall into the trap of `missing` being wrong. + for e in remote_auth_chain: + try: + await self._handle_new_event(origin, e) + except AuthError: + pass + + # Now get the current auth_chain for the event. + local_auth_chain = await self.store.get_auth_chain( + list(event.auth_event_ids()), include_given=True + ) + + # TODO: Check if we would now reject event_id. If so we need to tell + # everyone. + + ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain) + + logger.debug("on_query_auth returning: %s", ret) + + return ret + + async def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): + in_room = await self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + # Only allow up to 20 events to be retrieved per request. + limit = min(limit, 20) + + missing_events = await self.store.get_missing_events( + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + limit=limit, + ) + + missing_events = await filter_events_for_server( + self.storage, origin, missing_events + ) + + return missing_events + + async def do_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + auth_events: StateMap[EventBase], + ) -> EventContext: + """ + + Args: + origin: + event: + context: + auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Also NB that this function adds entries to it. + Returns: + updated context object + """ + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + try: + context = await self._update_auth_events_and_context_for_auth( + origin, event, context, auth_events + ) + except Exception: + # We don't really mind if the above fails, so lets not fail + # processing if it does. However, it really shouldn't fail so + # let's still log as an exception since we'll still want to fix + # any bugs. + logger.exception( + "Failed to double check auth events for %s with remote. " + "Ignoring failure and continuing processing of event.", + event.event_id, + ) + + try: + event_auth.check(room_version_obj, event, auth_events=auth_events) + except AuthError as e: + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + + return context + + async def _update_auth_events_and_context_for_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + auth_events: StateMap[EventBase], + ) -> EventContext: + """Helper for do_auth. See there for docs. + + Checks whether a given event has the expected auth events. If it + doesn't then we talk to the remote server to compare state to see if + we can come to a consensus (e.g. if one server missed some valid + state). + + This attempts to resolve any potential divergence of state between + servers, but is not essential and so failures should not block further + processing of the event. + + Args: + origin: + event: + context: + + auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Also NB that this function adds entries to it. + + Returns: + updated context + """ + event_auth_events = set(event.auth_event_ids()) + + # missing_auth is the set of the event's auth_events which we don't yet have + # in auth_events. + missing_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + # if we have missing events, we need to fetch those events from somewhere. + # + # we start by checking if they are in the store, and then try calling /event_auth/. + if missing_auth: + have_events = await self.store.have_seen_events(missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) + + if missing_auth: + # If we don't have all the auth events, we need to get them. + logger.info("auth_events contains unknown events: %s", missing_auth) + try: + try: + remote_auth_chain = await self.federation_client.get_event_auth( + origin, event.room_id, event.event_id + ) + except RequestSendFailed as e1: + # The other side isn't around or doesn't implement the + # endpoint, so lets just bail out. + logger.info("Failed to get event auth from remote: %s", e1) + return context + + seen_remotes = await self.store.have_seen_events( + [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes: + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = e.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in remote_auth_chain + if e.event_id in auth_ids or e.type == EventTypes.Create + } + e.internal_metadata.outlier = True + + logger.debug( + "do_auth %s missing_auth: %s", event.event_id, e.event_id + ) + await self._handle_new_event(origin, e, auth_events=auth) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + except Exception: + logger.exception("Failed to get auth chain") + + if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? + logger.info("Skipping auth_event fetch for outlier") + return context + + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if not different_auth: + return context + + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) + + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = await self.store.get_events_as_list(different_auth) + + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) + + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context + + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. + + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = await self.store.get_room_version_id(event.room_id) + new_state = await self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) + + auth_events.update(new_state) + + context = await self._update_context_for_auth_events( + event, context, auth_events + ) + + return context + + async def _update_context_for_auth_events( + self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] + ) -> EventContext: + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. + + Args: + event: The event we're handling the context for + + context: initial event context + + auth_events: Events to update in the event context. + + Returns: + new event context + """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] + else: + event_key = None + state_updates = { + k: a.event_id for k, a in auth_events.items() if k != event_key + } + + current_state_ids = await context.get_current_state_ids() + current_state_ids = dict(current_state_ids) # type: ignore + + current_state_ids.update(state_updates) + + prev_state_ids = await context.get_prev_state_ids() + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) + + # create a new state group as a delta from the existing one. + prev_group = context.state_group + state_group = await self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=state_updates, + current_state_ids=current_state_ids, + ) + + return EventContext.with_state( + state_group=state_group, + state_group_before_event=context.state_group_before_event, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=state_updates, + ) + + async def construct_auth_difference( + self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] + ) -> Dict: + """ Given a local and remote auth chain, find the differences. This + assumes that we have already processed all events in remote_auth + + Params: + local_auth (list) + remote_auth (list) + + Returns: + dict + """ + + logger.debug("construct_auth_difference Start!") + + # TODO: Make sure we are OK with local_auth or remote_auth having more + # auth events in them than strictly necessary. + + def sort_fun(ev): + return ev.depth, ev.event_id + + logger.debug("construct_auth_difference after sort_fun!") + + # We find the differences by starting at the "bottom" of each list + # and iterating up on both lists. The lists are ordered by depth and + # then event_id, we iterate up both lists until we find the event ids + # don't match. Then we look at depth/event_id to see which side is + # missing that event, and iterate only up that list. Repeat. + + remote_list = list(remote_auth) + remote_list.sort(key=sort_fun) + + local_list = list(local_auth) + local_list.sort(key=sort_fun) + + local_iter = iter(local_list) + remote_iter = iter(remote_list) + + logger.debug("construct_auth_difference before get_next!") + + def get_next(it, opt=None): + try: + return next(it) + except Exception: + return opt + + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + + logger.debug("construct_auth_difference before while") + + missing_remotes = [] + missing_locals = [] + while current_local or current_remote: + if current_remote is None: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local is None: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + if current_local.event_id == current_remote.event_id: + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + continue + + if current_local.depth < current_remote.depth: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local.depth > current_remote.depth: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + # They have the same depth, so we fall back to the event_id order + if current_local.event_id < current_remote.event_id: + missing_locals.append(current_local) + current_local = get_next(local_iter) + + if current_local.event_id > current_remote.event_id: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + logger.debug("construct_auth_difference after while") + + # missing locals should be sent to the server + # We should find why we are missing remotes, as they will have been + # rejected. + + # Remove events from missing_remotes if they are referencing a missing + # remote. We only care about the "root" rejected ones. + missing_remote_ids = [e.event_id for e in missing_remotes] + base_remote_rejected = list(missing_remotes) + for e in missing_remotes: + for e_id in e.auth_event_ids(): + if e_id in missing_remote_ids: + try: + base_remote_rejected.remove(e) + except ValueError: + pass + + reason_map = {} + + for e in base_remote_rejected: + reason = await self.store.get_rejection_reason(e.event_id) + if reason is None: + # TODO: e is not in the current state, so we should + # construct some proof of that. + continue + + reason_map[e.event_id] = reason + + logger.debug("construct_auth_difference returning") + + return { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } + + @log_function + async def exchange_third_party_invite( + self, sender_user_id, target_user_id, room_id, signed + ): + third_party_invite = {"signed": signed} + + event_dict = { + "type": EventTypes.Member, + "content": { + "membership": Membership.INVITE, + "third_party_invite": third_party_invite, + }, + "room_id": room_id, + "sender": sender_user_id, + "state_key": target_user_id, + } + + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) + builder = self.event_builder_factory.new(room_version, event_dict) + + EventValidator().validate_builder(builder) + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info( + "Creation of threepid invite %s forbidden by third-party rules", + event, + ) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + event, context = await self.add_display_name_to_third_party_invite( + room_version, event_dict, event, context + ) + + EventValidator().validate_new(event, self.config) + + # We need to tell the transaction queue to send this out, even + # though the sender isn't a local user. + event.internal_metadata.send_on_behalf_of = self.hs.hostname + + try: + await self.auth.check_from_context(room_version, event, context) + except AuthError as e: + logger.warning("Denying new third party invite %r because %s", event, e) + raise e + + await self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.send_membership_event(None, event, context) + else: + destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} + await self.federation_client.forward_third_party_invite( + destinations, room_id, event_dict + ) + + async def on_exchange_third_party_invite_request( + self, room_id: str, event_dict: JsonDict + ) -> None: + """Handle an exchange_third_party_invite request from a remote server + + The remote server will call this when it wants to turn a 3pid invite + into a normal m.room.member invite. + + Args: + room_id: The ID of the room. + + event_dict (dict[str, Any]): Dictionary containing the event body. + + """ + room_version = await self.store.get_room_version_id(room_id) + + # NB: event_dict has a particular specced format we might need to fudge + # if we change event formats too much. + builder = self.event_builder_factory.new(room_version, event_dict) + + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.warning( + "Exchange of threepid invite %s forbidden by third-party rules", event + ) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + event, context = await self.add_display_name_to_third_party_invite( + room_version, event_dict, event, context + ) + + try: + await self.auth.check_from_context(room_version, event, context) + except AuthError as e: + logger.warning("Denying third party invite %r because %s", event, e) + raise e + await self._check_signature(event, context) + + # We need to tell the transaction queue to send this out, even + # though the sender isn't a local user. + event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.send_membership_event(None, event, context) + + async def add_display_name_to_third_party_invite( + self, room_version, event_dict, event, context + ): + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"], + ) + original_invite = None + prev_state_ids = await context.get_prev_state_ids() + original_invite_id = prev_state_ids.get(key) + if original_invite_id: + original_invite = await self.store.get_event( + original_invite_id, allow_none=True + ) + if original_invite: + # If the m.room.third_party_invite event's content is empty, it means the + # invite has been revoked. In this case, we don't have to raise an error here + # because the auth check will fail on the invite (because it's not able to + # fetch public keys from the m.room.third_party_invite event's content, which + # is empty). + display_name = original_invite.content.get("display_name") + event_dict["content"]["third_party_invite"]["display_name"] = display_name + else: + logger.info( + "Could not find invite event for third_party_invite: %r", event_dict + ) + # We don't discard here as this is not the appropriate place to do + # auth checks. If we need the invite and don't have it then the + # auth check code will explode appropriately. + + builder = self.event_builder_factory.new(room_version, event_dict) + EventValidator().validate_builder(builder) + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + EventValidator().validate_new(event, self.config) + return (event, context) + + async def _check_signature(self, event, context): + """ + Checks that the signature in the event is consistent with its invite. + + Args: + event (Event): The m.room.member event to check + context (EventContext): + + Raises: + AuthError: if signature didn't match any keys, or key has been + revoked, + SynapseError: if a transient error meant a key couldn't be checked + for revocation. + """ + signed = event.content["third_party_invite"]["signed"] + token = signed["token"] + + prev_state_ids = await context.get_prev_state_ids() + invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) + + invite_event = None + if invite_event_id: + invite_event = await self.store.get_event(invite_event_id, allow_none=True) + + if not invite_event: + raise AuthError(403, "Could not find invite") + + logger.debug("Checking auth on event %r", event.content) + + last_exception = None # type: Optional[Exception] + + # for each public key in the 3pid invite event + for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + try: + # for each sig on the third_party_invite block of the actual invite + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + + logger.debug( + "Attempting to verify sig with key %s from %r " + "against pubkey %r", + key_name, + server, + public_key_object, + ) + + try: + public_key = public_key_object["public_key"] + verify_key = decode_verify_key_bytes( + key_name, decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + logger.debug( + "Successfully verified sig with key %s from %r " + "against pubkey %r", + key_name, + server, + public_key_object, + ) + except Exception: + logger.info( + "Failed to verify sig with key %s from %r " + "against pubkey %r", + key_name, + server, + public_key_object, + ) + raise + try: + if "key_validity_url" in public_key_object: + await self._check_key_revocation( + public_key, public_key_object["key_validity_url"] + ) + except Exception: + logger.info( + "Failed to query key_validity_url %s", + public_key_object["key_validity_url"], + ) + raise + return + except Exception as e: + last_exception = e + + if last_exception is None: + # we can only get here if get_public_keys() returned an empty list + # TODO: make this better + raise RuntimeError("no public key in invite event") + + raise last_exception + + async def _check_key_revocation(self, public_key, url): + """ + Checks whether public_key has been revoked. + + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. + + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked + for revocation. + """ + try: + response = await self.http_client.get_json(url, {"public_key": public_key}) + except Exception: + raise SynapseError(502, "Third party certificate could not be checked") + if "valid" not in response or not response["valid"]: + raise AuthError(403, "Third party certificate was invalid") + + async def persist_events_and_notify( + self, + event_and_contexts: Sequence[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> int: + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + event_and_contexts: + backfilled: Whether these events are a result of + backfilling or not + """ + if self.config.worker.writers.events != self._instance_name: + result = await self._send_events( + instance_name=self.config.worker.writers.events, + store=self.store, + event_and_contexts=event_and_contexts, + backfilled=backfilled, + ) + return result["max_stream_id"] + else: + max_stream_id = await self.storage.persistence.persist_events( + event_and_contexts, backfilled=backfilled + ) + + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + + if not backfilled: # Never notify for backfilled events + for event, _ in event_and_contexts: + await self._notify_persisted_event(event, max_stream_id) + + return max_stream_id + + async def _notify_persisted_event( + self, event: EventBase, max_stream_id: int + ) -> None: + """Checks to see if notifier/pushers should be notified about the + event or not. + + Args: + event: + max_stream_id: The max_stream_id returned by persist_events + """ + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + + # We notify for memberships if its an invite for one of our + # users + if event.internal_metadata.is_outlier(): + if event.membership != Membership.INVITE: + if not self.is_mine_id(target_user_id): + return + + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + elif event.internal_metadata.is_outlier(): + return + + event_stream_id = event.internal_metadata.stream_ordering + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) + + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + + async def _clean_room_for_join(self, room_id: str) -> None: + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Args: + room_id + """ + if self.config.worker_app: + await self._clean_room_for_join_client(room_id) + else: + await self.store.clean_room_for_join(room_id) + + async def user_joined_room(self, user: UserID, room_id: str) -> None: + """Called when a new user has joined the room + """ + if self.config.worker_app: + await self._notify_user_membership_change( + room_id=room_id, user_id=user.to_string(), change="joined" + ) + else: + user_joined_room(self.distributor, user, room_id) + + async def get_room_complexity( + self, remote_room_hosts: List[str], room_id: str + ) -> Optional[dict]: + """ + Fetch the complexity of a remote room over federation. + + Args: + remote_room_hosts (list[str]): The remote servers to ask. + room_id (str): The room ID to ask about. + + Returns: + Dict contains the complexity + metric versions, while None means we could not fetch the complexity. + """ + + for host in remote_room_hosts: + res = await self.federation_client.get_room_complexity(host, room_id) + + # We got a result, return it. + if res: + return res + + # We fell off the bottom, couldn't get the complexity from anyone. Oh + # well. + return None diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 96fea586730..30e71b6c153 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -59,7 +59,6 @@ def test_exchange_revoked_invite(self): ) d = self.handler.on_exchange_third_party_invite_request( - room_id=room_id, event_dict={ "type": EventTypes.Member, "room_id": room_id, From d1b3d8a426e446ccbbb267b98d45e4923d4d8ef9 Mon Sep 17 00:00:00 2001 From: turly221 Date: Fri, 6 Dec 2024 08:37:30 +0000 Subject: [PATCH 2/2] commit patch 22916874 --- synapse/groups/groups_server.py | 18 +- synapse/groups/groups_server.py.orig | 946 ++++++++++++++++++++++ tests/rest/client/v2_alpha/test_groups.py | 43 + 3 files changed, 1005 insertions(+), 2 deletions(-) create mode 100644 synapse/groups/groups_server.py.orig create mode 100644 tests/rest/client/v2_alpha/test_groups.py diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8cb922ddc73..d340dd5e70a 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -309,6 +309,13 @@ async def get_rooms_in_group(self, group_id, requester_user_id): requester_user_id, group_id ) + # Note! room_results["is_public"] is about whether the room is considered + # public from the group's point of view. (i.e. whether non-group members + # should be able to see the room is in the group). + # This is not the same as whether the room itself is public (in the sense + # of being visible in the room directory). + # As such, room_results["is_public"] itself is not sufficient to determine + # whether any given user is permitted to see the room's metadata. room_results = await self.store.get_rooms_in_group( group_id, include_private=is_user_in_group ) @@ -318,8 +325,15 @@ async def get_rooms_in_group(self, group_id, requester_user_id): room_id = room_result["room_id"] joined_users = await self.store.get_users_in_room(room_id) + + # check the user is actually allowed to see the room before showing it to them + allow_private = requester_user_id in joined_users + entry = await self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True + room_id, + len(joined_users), + with_alias=False, + allow_private=allow_private, ) if not entry: @@ -331,7 +345,7 @@ async def get_rooms_in_group(self, group_id, requester_user_id): chunk.sort(key=lambda e: -e["num_joined_members"]) - return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + return {"chunk": chunk, "total_room_count_estimate": len(chunk)} class GroupsServerHandler(GroupsServerWorkerHandler): diff --git a/synapse/groups/groups_server.py.orig b/synapse/groups/groups_server.py.orig new file mode 100644 index 00000000000..8cb922ddc73 --- /dev/null +++ b/synapse/groups/groups_server.py.orig @@ -0,0 +1,946 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.types import GroupID, RoomID, UserID, get_domain_from_id +from synapse.util.async_helpers import concurrently_execute + +logger = logging.getLogger(__name__) + + +# TODO: Allow users to "knock" or simply join depending on rules +# TODO: Federation admin APIs +# TODO: is_privileged flag to users and is_public to users and rooms +# TODO: Audit log for admins (profile updates, membership changes, users who tried +# to join but were rejected, etc) +# TODO: Flairs + + +class GroupsServerWorkerHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.signing_key + self.server_name = hs.hostname + self.attestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.profile_handler = hs.get_profile_handler() + + async def check_group_is_ours( + self, group_id, requester_user_id, and_exists=False, and_is_admin=None + ): + """Check that the group is ours, and optionally if it exists. + + If group does exist then return group. + + Args: + group_id (str) + and_exists (bool): whether to also check if group exists + and_is_admin (str): whether to also check if given str is a user_id + that is an admin + """ + if not self.is_mine_id(group_id): + raise SynapseError(400, "Group not on this server") + + group = await self.store.get_group(group_id) + if and_exists and not group: + raise SynapseError(404, "Unknown group") + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + + if and_is_admin: + is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + return group + + async def get_group_summary(self, group_id, requester_user_id): + """Get the summary for a group as seen by requester_user_id. + + The group summary consists of the profile of the room, and a curated + list of users and rooms. These list *may* be organised by role/category. + The roles/categories are ordered, and so are the users/rooms within them. + + A user/room may appear in multiple roles/categories. + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + profile = await self.get_group_profile(group_id, requester_user_id) + + users, roles = await self.store.get_users_for_summary_by_role( + group_id, include_private=is_user_in_group + ) + + # TODO: Add profiles to users + + rooms, categories = await self.store.get_rooms_for_summary_by_category( + group_id, include_private=is_user_in_group + ) + + for room_entry in rooms: + room_id = room_entry["room_id"] + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( + room_id, len(joined_users), with_alias=False, allow_private=True + ) + entry = dict(entry) # so we don't change whats cached + entry.pop("room_id", None) + + room_entry["profile"] = entry + + rooms.sort(key=lambda e: e.get("order", 0)) + + for entry in users: + user_id = entry["user_id"] + + if not self.is_mine_id(requester_user_id): + attestation = await self.store.get_remote_attestation(group_id, user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, user_id + ) + + user_profile = await self.profile_handler.get_profile_from_cache(user_id) + entry.update(user_profile) + + users.sort(key=lambda e: e.get("order", 0)) + + membership_info = await self.store.get_users_membership_info_in_group( + group_id, requester_user_id + ) + + return { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } + + async def get_group_categories(self, group_id, requester_user_id): + """Get all categories in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = await self.store.get_group_categories(group_id=group_id) + return {"categories": categories} + + async def get_group_category(self, group_id, requester_user_id, category_id): + """Get a specific category in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = await self.store.get_group_category( + group_id=group_id, category_id=category_id + ) + + logger.info("group %s", res) + + return res + + async def get_group_roles(self, group_id, requester_user_id): + """Get all roles in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = await self.store.get_group_roles(group_id=group_id) + return {"roles": roles} + + async def get_group_role(self, group_id, requester_user_id, role_id): + """Get a specific role in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = await self.store.get_group_role(group_id=group_id, role_id=role_id) + return res + + async def get_group_profile(self, group_id, requester_user_id): + """Get the group profile as seen by requester_user_id + """ + + await self.check_group_is_ours(group_id, requester_user_id) + + group = await self.store.get_group(group_id) + + if group: + cols = [ + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" + + return group_description + else: + raise SynapseError(404, "Unknown group") + + async def get_users_in_group(self, group_id, requester_user_id): + """Get the users in group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = await self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = await self.store.get_remote_attestation( + group_id, g_user_id + ) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + return {"chunk": chunk, "total_user_count_estimate": len(user_results)} + + async def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = await self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = {"user_id": user_id} + try: + profile = await self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warning("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + + async def get_rooms_in_group(self, group_id, requester_user_id): + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + room_results = await self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( + room_id, len(joined_users), with_alias=False, allow_private=True + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + + +class GroupsServerHandler(GroupsServerWorkerHandler): + def __init__(self, hs): + super(GroupsServerHandler, self).__init__(hs) + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + async def update_group_summary_room( + self, group_id, requester_user_id, room_id, category_id, content + ): + """Add/update a room to the group summary + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + RoomID.from_string(room_id) # Ensure valid room id + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_room_to_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + order=order, + is_public=is_public, + ) + + return {} + + async def delete_group_summary_room( + self, group_id, requester_user_id, room_id, category_id + ): + """Remove a room from the summary + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_room_from_summary( + group_id=group_id, room_id=room_id, category_id=category_id + ) + + return {} + + async def set_group_join_policy(self, group_id, requester_user_id, content): + """Sets the group join policy. + + Currently supported policies are: + - "invite": an invite must be received and accepted in order to join. + - "open": anyone can join. + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + join_policy = _parse_join_policy_from_contents(content) + if join_policy is None: + raise SynapseError(400, "No value specified for 'm.join_policy'") + + await self.store.set_group_join_policy(group_id, join_policy=join_policy) + + return {} + + async def update_group_category( + self, group_id, requester_user_id, category_id, content + ): + """Add/Update a group category + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + profile = content.get("profile") + + await self.store.upsert_group_category( + group_id=group_id, + category_id=category_id, + is_public=is_public, + profile=profile, + ) + + return {} + + async def delete_group_category(self, group_id, requester_user_id, category_id): + """Delete a group category + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_group_category( + group_id=group_id, category_id=category_id + ) + + return {} + + async def update_group_role(self, group_id, requester_user_id, role_id, content): + """Add/update a role in a group + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + profile = content.get("profile") + + await self.store.upsert_group_role( + group_id=group_id, role_id=role_id, is_public=is_public, profile=profile + ) + + return {} + + async def delete_group_role(self, group_id, requester_user_id, role_id): + """Remove role from group + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_group_role(group_id=group_id, role_id=role_id) + + return {} + + async def update_group_summary_user( + self, group_id, requester_user_id, user_id, role_id, content + ): + """Add/update a users entry in the group summary + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_user_to_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + order=order, + is_public=is_public, + ) + + return {} + + async def delete_group_summary_user( + self, group_id, requester_user_id, user_id, role_id + ): + """Remove a user from the group summary + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_user_from_summary( + group_id=group_id, user_id=user_id, role_id=role_id + ) + + return {} + + async def update_group_profile(self, group_id, requester_user_id, content): + """Update the group profile + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + profile = {} + for keyname in ("name", "avatar_url", "short_description", "long_description"): + if keyname in content: + value = content[keyname] + if not isinstance(value, str): + raise SynapseError(400, "%r value is not a string" % (keyname,)) + profile[keyname] = value + + await self.store.update_group_profile(group_id, profile) + + async def add_room_to_group(self, group_id, requester_user_id, room_id, content): + """Add room to group + """ + RoomID.from_string(room_id) # Ensure valid room id + + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_room_to_group(group_id, room_id, is_public=is_public) + + return {} + + async def update_room_in_group( + self, group_id, requester_user_id, room_id, config_key, content + ): + """Update room in group + """ + RoomID.from_string(room_id) # Ensure valid room id + + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + if config_key == "m.visibility": + is_public = _parse_visibility_dict(content) + + await self.store.update_room_in_group_visibility( + group_id, room_id, is_public=is_public + ) + else: + raise SynapseError(400, "Uknown config option") + + return {} + + async def remove_room_from_group(self, group_id, requester_user_id, room_id): + """Remove room from group + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_room_from_group(group_id, room_id) + + return {} + + async def invite_to_group(self, group_id, user_id, requester_user_id, content): + """Invite user to group + """ + + group = await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + # TODO: Check if user knocked + + invited_users = await self.store.get_invited_users_in_group(group_id) + if user_id in invited_users: + raise SynapseError( + 400, "User already invited to group", errcode=Codes.BAD_STATE + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=True + ) + if user_id in (user_result["user_id"] for user_result in user_results): + raise SynapseError(400, "User already in group") + + content = { + "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, + "inviter": requester_user_id, + } + + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + res = await groups_local.on_invite(group_id, user_id, content) + local_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content.update({"attestation": local_attestation}) + + res = await self.transport_client.invite_to_group_notification( + get_domain_from_id(user_id), group_id, user_id, content + ) + + user_profile = res.get("user_profile", {}) + await self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + if res["state"] == "join": + if not self.hs.is_mine_id(user_id): + remote_attestation = res["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=user_id, group_id=group_id + ) + else: + remote_attestation = None + + await self.store.add_user_to_group( + group_id, + user_id, + is_admin=False, + is_public=False, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + elif res["state"] == "invite": + await self.store.add_group_invite(group_id, user_id) + return {"state": "invite"} + elif res["state"] == "reject": + return {"state": "reject"} + else: + raise SynapseError(502, "Unknown state returned by HS") + + async def _add_user(self, group_id, user_id, content): + """Add a user to a group based on a content dict. + + See accept_invite, join_group. + """ + if not self.hs.is_mine_id(user_id): + local_attestation = self.attestations.create_attestation(group_id, user_id) + + remote_attestation = content["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=user_id, group_id=group_id + ) + else: + local_attestation = None + remote_attestation = None + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_user_to_group( + group_id, + user_id, + is_admin=False, + is_public=is_public, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + return local_attestation + + async def accept_invite(self, group_id, requester_user_id, content): + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = await self.store.is_user_invited_to_local_group( + group_id, requester_user_id + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + local_attestation = await self._add_user(group_id, requester_user_id, content) + + return {"state": "join", "attestation": local_attestation} + + async def join_group(self, group_id, requester_user_id, content): + """User tries to join the group. + + This will error if the group requires an invite/knock to join + """ + + group_info = await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True + ) + if group_info["join_policy"] != "open": + raise SynapseError(403, "Group is not publicly joinable") + + local_attestation = await self._add_user(group_id, requester_user_id, content) + + return {"state": "join", "attestation": local_attestation} + + async def knock(self, group_id, requester_user_id, content): + """A user requests becoming a member of the group + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + async def accept_knock(self, group_id, requester_user_id, content): + """Accept a users knock to the room. + + Errors if the user hasn't knocked, rather than inviting them. + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): + """Remove a user from the group; either a user is leaving or an admin + kicked them. + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_kick = False + if requester_user_id != user_id: + is_admin = await self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + is_kick = True + + await self.store.remove_user_from_group(group_id, user_id) + + if is_kick: + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + await groups_local.user_removed_from_group(group_id, user_id, {}) + else: + await self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + + if not self.hs.is_mine_id(user_id): + await self.store.maybe_delete_remote_profile_cache(user_id) + + # Delete group if the last user has left + users = await self.store.get_users_in_group(group_id, include_private=True) + if not users: + await self.store.delete_group(group_id) + + return {} + + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) + + logger.info("Attempting to create group with ID: %r", group_id) + + # parsing the id into a GroupID validates it. + group_id_obj = GroupID.from_string(group_id) + + if group: + raise SynapseError(400, "Group already exists") + + is_admin = await self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) + if not is_admin: + if not self.hs.config.enable_group_creation: + raise SynapseError( + 403, "Only a server admin can create groups on this server" + ) + localpart = group_id_obj.localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" + % (self.hs.config.group_creation_prefix,), + ) + + profile = content.get("profile", {}) + name = profile.get("name") + avatar_url = profile.get("avatar_url") + short_description = profile.get("short_description") + long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) + + await self.store.create_group( + group_id, + requester_user_id, + name=name, + avatar_url=avatar_url, + short_description=short_description, + long_description=long_description, + ) + + if not self.hs.is_mine_id(requester_user_id): + remote_attestation = content["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=requester_user_id, group_id=group_id + ) + + local_attestation = self.attestations.create_attestation( + group_id, requester_user_id + ) + else: + local_attestation = None + remote_attestation = None + + await self.store.add_user_to_group( + group_id, + requester_user_id, + is_admin=True, + is_public=True, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + if not self.hs.is_mine_id(requester_user_id): + await self.store.add_remote_profile_cache( + requester_user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + return {"group_id": group_id} + + async def delete_group(self, group_id, requester_user_id): + """Deletes a group, kicking out all current members. + + Only group admins or server admins can call this request + + Args: + group_id (str) + request_user_id (str) + + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + # Only server admins or group admins can delete groups. + + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) + + if not is_admin: + is_admin = await self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) + + if not is_admin: + raise SynapseError(403, "User is not an admin") + + # Before deleting the group lets kick everyone out of it + users = await self.store.get_users_in_group(group_id, include_private=True) + + async def _kick_user_from_group(user_id): + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + await groups_local.user_removed_from_group(group_id, user_id, {}) + else: + await self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + await self.store.maybe_delete_remote_profile_cache(user_id) + + # We kick users out in the order of: + # 1. Non-admins + # 2. Other admins + # 3. The requester + # + # This is so that if the deletion fails for some reason other admins or + # the requester still has auth to retry. + non_admins = [] + admins = [] + for u in users: + if u["user_id"] == requester_user_id: + continue + if u["is_admin"]: + admins.append(u["user_id"]) + else: + non_admins.append(u["user_id"]) + + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) + + await self.store.delete_group(group_id) + + +def _parse_join_policy_from_contents(content): + """Given a content for a request, return the specified join policy or None + """ + + join_policy_dict = content.get("m.join_policy") + if join_policy_dict: + return _parse_join_policy_dict(join_policy_dict) + else: + return None + + +def _parse_join_policy_dict(join_policy_dict): + """Given a dict for the "m.join_policy" config return the join policy specified + """ + join_policy_type = join_policy_dict.get("type") + if not join_policy_type: + return "invite" + + if join_policy_type not in ("invite", "open"): + raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") + return join_policy_type + + +def _parse_visibility_from_contents(content): + """Given a content for a request parse out whether the entity should be + public or not + """ + + visibility = content.get("m.visibility") + if visibility: + return _parse_visibility_dict(visibility) + else: + is_public = True + + return is_public + + +def _parse_visibility_dict(visibility): + """Given a dict for the "m.visibility" config return if the entity should + be public or not + """ + vis_type = visibility.get("type") + if not vis_type: + return True + + if vis_type not in ("public", "private"): + raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") + return vis_type == "public" diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/v2_alpha/test_groups.py new file mode 100644 index 00000000000..bfa9336baa7 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_groups.py @@ -0,0 +1,43 @@ +from synapse.rest.client.v1 import room +from synapse.rest.client.v2_alpha import groups + +from tests import unittest +from tests.unittest import override_config + + +class GroupsTestCase(unittest.HomeserverTestCase): + user_id = "@alice:test" + room_creator_user_id = "@bob:test" + + servlets = [room.register_servlets, groups.register_servlets] + + @override_config({"enable_group_creation": True}) + def test_rooms_limited_by_visibility(self): + group_id = "+spqr:test" + + # Alice creates a group + channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {"group_id": group_id}) + + # Bob creates a private room + room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) + self.helper.auth_user_id = self.room_creator_user_id + self.helper.send_state( + room_id, "m.room.name", {"name": "bob's secret room"}, tok=None + ) + self.helper.auth_user_id = self.user_id + + # Alice adds the room to her group. + channel = self.make_request( + "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} + ) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {}) + + # Alice now tries to retrieve the room list of the space. + channel = self.make_request("GET", f"/groups/{group_id}/rooms") + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals( + channel.json_body, {"chunk": [], "total_room_count_estimate": 0} + )