diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ad37b93c025..1566457b6a1 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -18,7 +18,7 @@ import warnings from datetime import datetime from hashlib import sha256 -from typing import List, Optional +from typing import List, Optional, Pattern from unpaddedbase64 import encode_base64 @@ -125,7 +125,7 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): fed_whitelist_entries = [] # Support globs (*) in whitelist values - self.federation_certificate_verification_whitelist = [] # type: List[str] + self.federation_certificate_verification_whitelist = [] # type: List[Pattern] for entry in fed_whitelist_entries: try: entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) diff --git a/synapse/config/tls.py.orig b/synapse/config/tls.py.orig new file mode 100644 index 00000000000..ad37b93c025 --- /dev/null +++ b/synapse/config/tls.py.orig @@ -0,0 +1,516 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import warnings +from datetime import datetime +from hashlib import sha256 +from typing import List, Optional + +from unpaddedbase64 import encode_base64 + +from OpenSSL import SSL, crypto +from twisted.internet._sslverify import Certificate, trustRootFromCertificates + +from synapse.config._base import Config, ConfigError +from synapse.util import glob_to_regex + +logger = logging.getLogger(__name__) + +ACME_SUPPORT_ENABLED_WARN = """\ +This server uses Synapse's built-in ACME support. Note that ACME v1 has been +deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2, +which means that this feature will not work with Synapse installs set up after +November 2019, and that it may stop working on June 2020 for installs set up +before that date. + +For more info and alternative solutions, see +https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 +--------------------------------------------------------------------------------""" + + +class TlsConfig(Config): + section = "tls" + + def read_config(self, config: dict, config_dir_path: str, **kwargs): + + acme_config = config.get("acme", None) + if acme_config is None: + acme_config = {} + + self.acme_enabled = acme_config.get("enabled", False) + + if self.acme_enabled: + logger.warning(ACME_SUPPORT_ENABLED_WARN) + + # hyperlink complains on py2 if this is not a Unicode + self.acme_url = str( + acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") + ) + self.acme_port = acme_config.get("port", 80) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"]) + self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) + self.acme_domain = acme_config.get("domain", config.get("server_name")) + + self.acme_account_key_file = self.abspath( + acme_config.get("account_key_file", config_dir_path + "/client.key") + ) + + self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) + self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) + + if self.root.server.has_tls_listener(): + if not self.tls_certificate_file: + raise ConfigError( + "tls_certificate_path must be specified if TLS-enabled listeners are " + "configured." + ) + if not self.tls_private_key_file: + raise ConfigError( + "tls_private_key_path must be specified if TLS-enabled listeners are " + "configured." + ) + + self._original_tls_fingerprints = config.get("tls_fingerprints", []) + + if self._original_tls_fingerprints is None: + self._original_tls_fingerprints = [] + + self.tls_fingerprints = list(self._original_tls_fingerprints) + + # Whether to verify certificates on outbound federation traffic + self.federation_verify_certificates = config.get( + "federation_verify_certificates", True + ) + + # Minimum TLS version to use for outbound federation traffic + self.federation_client_minimum_tls_version = str( + config.get("federation_client_minimum_tls_version", 1) + ) + + if self.federation_client_minimum_tls_version not in ["1", "1.1", "1.2", "1.3"]: + raise ConfigError( + "federation_client_minimum_tls_version must be one of: 1, 1.1, 1.2, 1.3" + ) + + # Prevent people shooting themselves in the foot here by setting it to + # the biggest number blindly + if self.federation_client_minimum_tls_version == "1.3": + if getattr(SSL, "OP_NO_TLSv1_3", None) is None: + raise ConfigError( + ( + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" + ) + ) + + # Whitelist of domains to not verify certificates for + fed_whitelist_entries = config.get( + "federation_certificate_verification_whitelist", [] + ) + if fed_whitelist_entries is None: + fed_whitelist_entries = [] + + # Support globs (*) in whitelist values + self.federation_certificate_verification_whitelist = [] # type: List[str] + for entry in fed_whitelist_entries: + try: + entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) + except UnicodeEncodeError: + raise ConfigError( + "IDNA domain names are not allowed in the " + "federation_certificate_verification_whitelist: %s" % (entry,) + ) + + # Convert globs to regex + self.federation_certificate_verification_whitelist.append(entry_regex) + + # List of custom certificate authorities for federation traffic validation + custom_ca_list = config.get("federation_custom_ca_list", None) + + # Read in and parse custom CA certificates + self.federation_ca_trust_root = None + if custom_ca_list is not None: + if len(custom_ca_list) == 0: + # A trustroot cannot be generated without any CA certificates. + # Raise an error if this option has been specified without any + # corresponding certificates. + raise ConfigError( + "federation_custom_ca_list specified without " + "any certificate files" + ) + + certs = [] + for ca_file in custom_ca_list: + logger.debug("Reading custom CA certificate file: %s", ca_file) + content = self.read_file(ca_file, "federation_custom_ca_list") + + # Parse the CA certificates + try: + cert_base = Certificate.loadPEM(content) + certs.append(cert_base) + except Exception as e: + raise ConfigError( + "Error parsing custom CA certificate file %s: %s" % (ca_file, e) + ) + + self.federation_ca_trust_root = trustRootFromCertificates(certs) + + # This config option applies to non-federation HTTP clients + # (e.g. for talking to recaptcha, identity servers, and such) + # It should never be used in production, and is intended for + # use only when running tests. + self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( + "use_insecure_ssl_client_just_for_testing_do_not_use" + ) + + self.tls_certificate = None # type: Optional[crypto.X509] + self.tls_private_key = None # type: Optional[crypto.PKey] + + def is_disk_cert_valid(self, allow_self_signed=True): + """ + Is the certificate we have on disk valid, and if so, for how long? + + Args: + allow_self_signed (bool): Should we allow the certificate we + read to be self signed? + + Returns: + int: Days remaining of certificate validity. + None: No certificate exists. + """ + if not os.path.exists(self.tls_certificate_file): + return None + + try: + with open(self.tls_certificate_file, "rb") as f: + cert_pem = f.read() + except Exception as e: + raise ConfigError( + "Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) + + try: + tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + except Exception as e: + raise ConfigError( + "Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) + + if not allow_self_signed: + if tls_certificate.get_subject() == tls_certificate.get_issuer(): + raise ValueError( + "TLS Certificate is self signed, and this is not permitted" + ) + + # YYYYMMDDhhmmssZ -- in UTC + expires_on = datetime.strptime( + tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ" + ) + now = datetime.utcnow() + days_remaining = (expires_on - now).days + return days_remaining + + def read_certificate_from_disk(self, require_cert_and_key: bool): + """ + Read the certificates and private key from disk. + + Args: + require_cert_and_key: set to True to throw an error if the certificate + and key file are not given + """ + if require_cert_and_key: + self.tls_private_key = self.read_tls_private_key() + self.tls_certificate = self.read_tls_certificate() + elif self.tls_certificate_file: + # we only need the certificate for the tls_fingerprints. Reload it if we + # can, but it's not a fatal error if we can't. + try: + self.tls_certificate = self.read_tls_certificate() + except Exception as e: + logger.info( + "Unable to read TLS certificate (%s). Ignoring as no " + "tls listeners enabled.", + e, + ) + + self.tls_fingerprints = list(self._original_tls_fingerprints) + + if self.tls_certificate: + # Check that our own certificate is included in the list of fingerprints + # and include it if it is not. + x509_certificate_bytes = crypto.dump_certificate( + crypto.FILETYPE_ASN1, self.tls_certificate + ) + sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) + sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints} + if sha256_fingerprint not in sha256_fingerprints: + self.tls_fingerprints.append({"sha256": sha256_fingerprint}) + + def generate_config_section( + self, + config_dir_path, + server_name, + data_dir_path, + tls_certificate_path, + tls_private_key_path, + acme_domain, + **kwargs + ): + """If the acme_domain is specified acme will be enabled. + If the TLS paths are not specified the default will be certs in the + config directory""" + + base_key_name = os.path.join(config_dir_path, server_name) + + if bool(tls_certificate_path) != bool(tls_private_key_path): + raise ConfigError( + "Please specify both a cert path and a key path or neither." + ) + + tls_enabled = ( + "" if tls_certificate_path and tls_private_key_path or acme_domain else "#" + ) + + if not tls_certificate_path: + tls_certificate_path = base_key_name + ".tls.crt" + if not tls_private_key_path: + tls_private_key_path = base_key_name + ".tls.key" + + acme_enabled = bool(acme_domain) + acme_domain = "matrix.example.com" + + default_acme_account_file = os.path.join(data_dir_path, "acme_account.key") + + # this is to avoid the max line length. Sorrynotsorry + proxypassline = ( + "ProxyPass /.well-known/acme-challenge " + "http://localhost:8009/.well-known/acme-challenge" + ) + + # flake8 doesn't recognise that variables are used in the below string + _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file + + return ( + """\ + ## TLS ## + + # PEM-encoded X509 certificate for TLS. + # This certificate, as of Synapse 1.0, will need to be a valid and verifiable + # certificate, signed by a recognised Certificate Authority. + # + # See 'ACME support' below to enable auto-provisioning this certificate via + # Let's Encrypt. + # + # If supplying your own, be sure to use a `.pem` file that includes the + # full certificate chain including any intermediate certificates (for + # instance, if using certbot, use `fullchain.pem` as your certificate, + # not `cert.pem`). + # + %(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s" + + # PEM-encoded private key for TLS + # + %(tls_enabled)stls_private_key_path: "%(tls_private_key_path)s" + + # Whether to verify TLS server certificates for outbound federation requests. + # + # Defaults to `true`. To disable certificate verification, uncomment the + # following line. + # + #federation_verify_certificates: false + + # The minimum TLS version that will be used for outbound federation requests. + # + # Defaults to `1`. Configurable to `1`, `1.1`, `1.2`, or `1.3`. Note + # that setting this value higher than `1.2` will prevent federation to most + # of the public Matrix network: only configure it to `1.3` if you have an + # entirely private federation setup and you can ensure TLS 1.3 support. + # + #federation_client_minimum_tls_version: 1.2 + + # Skip federation certificate verification on the following whitelist + # of domains. + # + # This setting should only be used in very specific cases, such as + # federation over Tor hidden services and similar. For private networks + # of homeservers, you likely want to use a private CA instead. + # + # Only effective if federation_verify_certicates is `true`. + # + #federation_certificate_verification_whitelist: + # - lon.example.com + # - *.domain.com + # - *.onion + + # List of custom certificate authorities for federation traffic. + # + # This setting should only normally be used within a private network of + # homeservers. + # + # Note that this list will replace those that are provided by your + # operating environment. Certificates must be in PEM format. + # + #federation_custom_ca_list: + # - myCA1.pem + # - myCA2.pem + # - myCA3.pem + + # ACME support: This will configure Synapse to request a valid TLS certificate + # for your configured `server_name` via Let's Encrypt. + # + # Note that ACME v1 is now deprecated, and Synapse currently doesn't support + # ACME v2. This means that this feature currently won't work with installs set + # up after November 2019. For more info, and alternative solutions, see + # https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 + # + # Note that provisioning a certificate in this way requires port 80 to be + # routed to Synapse so that it can complete the http-01 ACME challenge. + # By default, if you enable ACME support, Synapse will attempt to listen on + # port 80 for incoming http-01 challenges - however, this will likely fail + # with 'Permission denied' or a similar error. + # + # There are a couple of potential solutions to this: + # + # * If you already have an Apache, Nginx, or similar listening on port 80, + # you can configure Synapse to use an alternate port, and have your web + # server forward the requests. For example, assuming you set 'port: 8009' + # below, on Apache, you would write: + # + # %(proxypassline)s + # + # * Alternatively, you can use something like `authbind` to give Synapse + # permission to listen on port 80. + # + acme: + # ACME support is disabled by default. Set this to `true` and uncomment + # tls_certificate_path and tls_private_key_path above to enable it. + # + enabled: %(acme_enabled)s + + # Endpoint to use to request certificates. If you only want to test, + # use Let's Encrypt's staging url: + # https://acme-staging.api.letsencrypt.org/directory + # + #url: https://acme-v01.api.letsencrypt.org/directory + + # Port number to listen on for the HTTP-01 challenge. Change this if + # you are forwarding connections through Apache/Nginx/etc. + # + port: 80 + + # Local addresses to listen on for incoming connections. + # Again, you may want to change this if you are forwarding connections + # through Apache/Nginx/etc. + # + bind_addresses: ['::', '0.0.0.0'] + + # How many days remaining on a certificate before it is renewed. + # + reprovision_threshold: 30 + + # The domain that the certificate should be for. Normally this + # should be the same as your Matrix domain (i.e., 'server_name'), but, + # by putting a file at 'https:///.well-known/matrix/server', + # you can delegate incoming traffic to another server. If you do that, + # you should give the target of the delegation here. + # + # For example: if your 'server_name' is 'example.com', but + # 'https://example.com/.well-known/matrix/server' delegates to + # 'matrix.example.com', you should put 'matrix.example.com' here. + # + # If not set, defaults to your 'server_name'. + # + domain: %(acme_domain)s + + # file to use for the account key. This will be generated if it doesn't + # exist. + # + # If unspecified, we will use CONFDIR/client.key. + # + account_key_file: %(default_acme_account_file)s + + # List of allowed TLS fingerprints for this server to publish along + # with the signing keys for this server. Other matrix servers that + # make HTTPS requests to this server will check that the TLS + # certificates returned by this server match one of the fingerprints. + # + # Synapse automatically adds the fingerprint of its own certificate + # to the list. So if federation traffic is handled directly by synapse + # then no modification to the list is required. + # + # If synapse is run behind a load balancer that handles the TLS then it + # will be necessary to add the fingerprints of the certificates used by + # the loadbalancers to this list if they are different to the one + # synapse is using. + # + # Homeservers are permitted to cache the list of TLS fingerprints + # returned in the key responses up to the "valid_until_ts" returned in + # key. It may be necessary to publish the fingerprints of a new + # certificate and wait until the "valid_until_ts" of the previous key + # responses have passed before deploying it. + # + # You can calculate a fingerprint from a given TLS listener via: + # openssl s_client -connect $host:$port < /dev/null 2> /dev/null | + # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '=' + # or by checking matrix.org/federationtester/api/report?server_name=$host + # + #tls_fingerprints: [{"sha256": ""}] + """ + # Lowercase the string representation of boolean values + % { + x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1] + for x in locals().items() + } + ) + + def read_tls_certificate(self) -> crypto.X509: + """Reads the TLS certificate from the configured file, and returns it + + Also checks if it is self-signed, and warns if so + + Returns: + The certificate + """ + cert_path = self.tls_certificate_file + logger.info("Loading TLS certificate from %s", cert_path) + cert_pem = self.read_file(cert_path, "tls_certificate_path") + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + + # Check if it is self-signed, and issue a warning if so. + if cert.get_issuer() == cert.get_subject(): + warnings.warn( + ( + "Self-signed TLS certificates will not be accepted by Synapse 1.0. " + "Please either provide a valid certificate, or use Synapse's ACME " + "support to provision one." + ) + ) + + return cert + + def read_tls_private_key(self) -> crypto.PKey: + """Reads the TLS private key from the configured file, and returns it + + Returns: + The private key + """ + private_key_path = self.tls_private_key_file + logger.info("Loading TLS key from %s", private_key_path) + private_key_pem = self.read_file(private_key_path, "tls_private_key_path") + return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index f9a0f402216..3ce24de7d59 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -333,6 +333,13 @@ async def get_rooms_in_group( requester_user_id, group_id ) + # Note! room_results["is_public"] is about whether the room is considered + # public from the group's point of view. (i.e. whether non-group members + # should be able to see the room is in the group). + # This is not the same as whether the room itself is public (in the sense + # of being visible in the room directory). + # As such, room_results["is_public"] itself is not sufficient to determine + # whether any given user is permitted to see the room's metadata. room_results = await self.store.get_rooms_in_group( group_id, include_private=is_user_in_group ) @@ -342,8 +349,15 @@ async def get_rooms_in_group( room_id = room_result["room_id"] joined_users = await self.store.get_users_in_room(room_id) + + # check the user is actually allowed to see the room before showing it to them + allow_private = requester_user_id in joined_users + entry = await self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True + room_id, + len(joined_users), + with_alias=False, + allow_private=allow_private, ) if not entry: @@ -355,7 +369,7 @@ async def get_rooms_in_group( chunk.sort(key=lambda e: -e["num_joined_members"]) - return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + return {"chunk": chunk, "total_room_count_estimate": len(chunk)} class GroupsServerHandler(GroupsServerWorkerHandler): diff --git a/synapse/groups/groups_server.py.orig b/synapse/groups/groups_server.py.orig new file mode 100644 index 00000000000..f9a0f402216 --- /dev/null +++ b/synapse/groups/groups_server.py.orig @@ -0,0 +1,1006 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Optional + +from synapse.api.errors import Codes, SynapseError +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN +from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id +from synapse.util.async_helpers import concurrently_execute + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + +logger = logging.getLogger(__name__) + + +# TODO: Allow users to "knock" or simply join depending on rules +# TODO: Federation admin APIs +# TODO: is_privileged flag to users and is_public to users and rooms +# TODO: Audit log for admins (profile updates, membership changes, users who tried +# to join but were rejected, etc) +# TODO: Flairs + + +# Note that the maximum lengths are somewhat arbitrary. +MAX_SHORT_DESC_LEN = 1000 +MAX_LONG_DESC_LEN = 10000 + + +class GroupsServerWorkerHandler: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.signing_key + self.server_name = hs.hostname + self.attestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.profile_handler = hs.get_profile_handler() + + async def check_group_is_ours( + self, + group_id: str, + requester_user_id: str, + and_exists: bool = False, + and_is_admin: Optional[str] = None, + ) -> Optional[dict]: + """Check that the group is ours, and optionally if it exists. + + If group does exist then return group. + + Args: + group_id: The group ID to check. + requester_user_id: The user ID of the requester. + and_exists: whether to also check if group exists + and_is_admin: whether to also check if given str is a user_id + that is an admin + """ + if not self.is_mine_id(group_id): + raise SynapseError(400, "Group not on this server") + + group = await self.store.get_group(group_id) + if and_exists and not group: + raise SynapseError(404, "Unknown group") + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + + if and_is_admin: + is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + return group + + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the summary for a group as seen by requester_user_id. + + The group summary consists of the profile of the room, and a curated + list of users and rooms. These list *may* be organised by role/category. + The roles/categories are ordered, and so are the users/rooms within them. + + A user/room may appear in multiple roles/categories. + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + profile = await self.get_group_profile(group_id, requester_user_id) + + users, roles = await self.store.get_users_for_summary_by_role( + group_id, include_private=is_user_in_group + ) + + # TODO: Add profiles to users + + rooms, categories = await self.store.get_rooms_for_summary_by_category( + group_id, include_private=is_user_in_group + ) + + for room_entry in rooms: + room_id = room_entry["room_id"] + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( + room_id, len(joined_users), with_alias=False, allow_private=True + ) + if entry is None: + continue + entry = dict(entry) # so we don't change what's cached + entry.pop("room_id", None) + + room_entry["profile"] = entry + + rooms.sort(key=lambda e: e.get("order", 0)) + + for user in users: + user_id = user["user_id"] + + if not self.is_mine_id(requester_user_id): + attestation = await self.store.get_remote_attestation(group_id, user_id) + if not attestation: + continue + + user["attestation"] = attestation + else: + user["attestation"] = self.attestations.create_attestation( + group_id, user_id + ) + + user_profile = await self.profile_handler.get_profile_from_cache(user_id) + user.update(user_profile) + + users.sort(key=lambda e: e.get("order", 0)) + + membership_info = await self.store.get_users_membership_info_in_group( + group_id, requester_user_id + ) + + return { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } + + async def get_group_categories( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get all categories in a group (as seen by user)""" + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = await self.store.get_group_categories(group_id=group_id) + return {"categories": categories} + + async def get_group_category( + self, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: + """Get a specific category in a group (as seen by user)""" + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + return await self.store.get_group_category( + group_id=group_id, category_id=category_id + ) + + async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict: + """Get all roles in a group (as seen by user)""" + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = await self.store.get_group_roles(group_id=group_id) + return {"roles": roles} + + async def get_group_role( + self, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: + """Get a specific role in a group (as seen by user)""" + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + return await self.store.get_group_role(group_id=group_id, role_id=role_id) + + async def get_group_profile( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the group profile as seen by requester_user_id""" + + await self.check_group_is_ours(group_id, requester_user_id) + + group = await self.store.get_group(group_id) + + if group: + cols = [ + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" + + return group_description + else: + raise SynapseError(404, "Unknown group") + + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the users in group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = await self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = await self.store.get_remote_attestation( + group_id, g_user_id + ) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + return {"chunk": chunk, "total_user_count_estimate": len(user_results)} + + async def get_invited_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the users that have been invited to a group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = await self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = {"user_id": user_id} + try: + profile = await self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warning("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + + async def get_rooms_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + room_results = await self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( + room_id, len(joined_users), with_alias=False, allow_private=True + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + + +class GroupsServerHandler(GroupsServerWorkerHandler): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + async def update_group_summary_room( + self, + group_id: str, + requester_user_id: str, + room_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: + """Add/update a room to the group summary""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + RoomID.from_string(room_id) # Ensure valid room id + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_room_to_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + order=order, + is_public=is_public, + ) + + return {} + + async def delete_group_summary_room( + self, group_id: str, requester_user_id: str, room_id: str, category_id: str + ) -> JsonDict: + """Remove a room from the summary""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_room_from_summary( + group_id=group_id, room_id=room_id, category_id=category_id + ) + + return {} + + async def set_group_join_policy( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """Sets the group join policy. + + Currently supported policies are: + - "invite": an invite must be received and accepted in order to join. + - "open": anyone can join. + """ + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + join_policy = _parse_join_policy_from_contents(content) + if join_policy is None: + raise SynapseError(400, "No value specified for 'm.join_policy'") + + await self.store.set_group_join_policy(group_id, join_policy=join_policy) + + return {} + + async def update_group_category( + self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict + ) -> JsonDict: + """Add/Update a group category""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + profile = content.get("profile") + + await self.store.upsert_group_category( + group_id=group_id, + category_id=category_id, + is_public=is_public, + profile=profile, + ) + + return {} + + async def delete_group_category( + self, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: + """Delete a group category""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_group_category( + group_id=group_id, category_id=category_id + ) + + return {} + + async def update_group_role( + self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict + ) -> JsonDict: + """Add/update a role in a group""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + profile = content.get("profile") + + await self.store.upsert_group_role( + group_id=group_id, role_id=role_id, is_public=is_public, profile=profile + ) + + return {} + + async def delete_group_role( + self, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: + """Remove role from group""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_group_role(group_id=group_id, role_id=role_id) + + return {} + + async def update_group_summary_user( + self, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: + """Add/update a users entry in the group summary""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_user_to_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + order=order, + is_public=is_public, + ) + + return {} + + async def delete_group_summary_user( + self, group_id: str, requester_user_id: str, user_id: str, role_id: str + ) -> JsonDict: + """Remove a user from the group summary""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_user_from_summary( + group_id=group_id, user_id=user_id, role_id=role_id + ) + + return {} + + async def update_group_profile( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> None: + """Update the group profile""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + profile = {} + for keyname, max_length in ( + ("name", MAX_DISPLAYNAME_LEN), + ("avatar_url", MAX_AVATAR_URL_LEN), + ("short_description", MAX_SHORT_DESC_LEN), + ("long_description", MAX_LONG_DESC_LEN), + ): + if keyname in content: + value = content[keyname] + if not isinstance(value, str): + raise SynapseError( + 400, + "%r value is not a string" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) + if len(value) > max_length: + raise SynapseError( + 400, + "Invalid %s parameter" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) + profile[keyname] = value + + await self.store.update_group_profile(group_id, profile) + + async def add_room_to_group( + self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict + ) -> JsonDict: + """Add room to group""" + RoomID.from_string(room_id) # Ensure valid room id + + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_room_to_group(group_id, room_id, is_public=is_public) + + return {} + + async def update_room_in_group( + self, + group_id: str, + requester_user_id: str, + room_id: str, + config_key: str, + content: JsonDict, + ) -> JsonDict: + """Update room in group""" + RoomID.from_string(room_id) # Ensure valid room id + + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + if config_key == "m.visibility": + is_public = _parse_visibility_dict(content) + + await self.store.update_room_in_group_visibility( + group_id, room_id, is_public=is_public + ) + else: + raise SynapseError(400, "Unknown config option") + + return {} + + async def remove_room_from_group( + self, group_id: str, requester_user_id: str, room_id: str + ) -> JsonDict: + """Remove room from group""" + await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + await self.store.remove_room_from_group(group_id, room_id) + + return {} + + async def invite_to_group( + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """Invite user to group""" + + group = await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + if not group: + raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE) + + # TODO: Check if user knocked + + invited_users = await self.store.get_invited_users_in_group(group_id) + if user_id in invited_users: + raise SynapseError( + 400, "User already invited to group", errcode=Codes.BAD_STATE + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=True + ) + if user_id in (user_result["user_id"] for user_result in user_results): + raise SynapseError(400, "User already in group") + + content = { + "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, + "inviter": requester_user_id, + } + + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot invites users to groups." + res = await groups_local.on_invite(group_id, user_id, content) + local_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content.update({"attestation": local_attestation}) + + res = await self.transport_client.invite_to_group_notification( + get_domain_from_id(user_id), group_id, user_id, content + ) + + user_profile = res.get("user_profile", {}) + await self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + if res["state"] == "join": + if not self.hs.is_mine_id(user_id): + remote_attestation = res["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=user_id, group_id=group_id + ) + else: + remote_attestation = None + + await self.store.add_user_to_group( + group_id, + user_id, + is_admin=False, + is_public=False, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + return {"state": "join"} + elif res["state"] == "invite": + await self.store.add_group_invite(group_id, user_id) + return {"state": "invite"} + elif res["state"] == "reject": + return {"state": "reject"} + else: + raise SynapseError(502, "Unknown state returned by HS") + + async def _add_user( + self, group_id: str, user_id: str, content: JsonDict + ) -> Optional[JsonDict]: + """Add a user to a group based on a content dict. + + See accept_invite, join_group. + """ + if not self.hs.is_mine_id(user_id): + local_attestation = self.attestations.create_attestation( + group_id, user_id + ) # type: Optional[JsonDict] + + remote_attestation = content["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=user_id, group_id=group_id + ) + else: + local_attestation = None + remote_attestation = None + + is_public = _parse_visibility_from_contents(content) + + await self.store.add_user_to_group( + group_id, + user_id, + is_admin=False, + is_public=is_public, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + return local_attestation + + async def accept_invite( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = await self.store.is_user_invited_to_local_group( + group_id, requester_user_id + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + local_attestation = await self._add_user(group_id, requester_user_id, content) + + return {"state": "join", "attestation": local_attestation} + + async def join_group( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """User tries to join the group. + + This will error if the group requires an invite/knock to join + """ + + group_info = await self.check_group_is_ours( + group_id, requester_user_id, and_exists=True + ) + if not group_info: + raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND) + if group_info["join_policy"] != "open": + raise SynapseError(403, "Group is not publicly joinable") + + local_attestation = await self._add_user(group_id, requester_user_id, content) + + return {"state": "join", "attestation": local_attestation} + + async def remove_user_from_group( + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """Remove a user from the group; either a user is leaving or an admin + kicked them. + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_kick = False + if requester_user_id != user_id: + is_admin = await self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + is_kick = True + + await self.store.remove_user_from_group(group_id, user_id) + + if is_kick: + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot remove users from groups." + await groups_local.user_removed_from_group(group_id, user_id, {}) + else: + await self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + + if not self.hs.is_mine_id(user_id): + await self.store.maybe_delete_remote_profile_cache(user_id) + + # Delete group if the last user has left + users = await self.store.get_users_in_group(group_id, include_private=True) + if not users: + await self.store.delete_group(group_id) + + return {} + + async def create_group( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + logger.info("Attempting to create group with ID: %r", group_id) + + # parsing the id into a GroupID validates it. + group_id_obj = GroupID.from_string(group_id) + + group = await self.check_group_is_ours(group_id, requester_user_id) + if group: + raise SynapseError(400, "Group already exists") + + is_admin = await self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) + if not is_admin: + if not self.hs.config.enable_group_creation: + raise SynapseError( + 403, "Only a server admin can create groups on this server" + ) + localpart = group_id_obj.localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" + % (self.hs.config.group_creation_prefix,), + ) + + profile = content.get("profile", {}) + name = profile.get("name") + avatar_url = profile.get("avatar_url") + short_description = profile.get("short_description") + long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) + + await self.store.create_group( + group_id, + requester_user_id, + name=name, + avatar_url=avatar_url, + short_description=short_description, + long_description=long_description, + ) + + if not self.hs.is_mine_id(requester_user_id): + remote_attestation = content["attestation"] + + await self.attestations.verify_attestation( + remote_attestation, user_id=requester_user_id, group_id=group_id + ) + + local_attestation = self.attestations.create_attestation( + group_id, requester_user_id + ) # type: Optional[JsonDict] + else: + local_attestation = None + remote_attestation = None + + await self.store.add_user_to_group( + group_id, + requester_user_id, + is_admin=True, + is_public=True, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + if not self.hs.is_mine_id(requester_user_id): + await self.store.add_remote_profile_cache( + requester_user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + return {"group_id": group_id} + + async def delete_group(self, group_id: str, requester_user_id: str) -> None: + """Deletes a group, kicking out all current members. + + Only group admins or server admins can call this request + + Args: + group_id: The group ID to delete. + requester_user_id: The user requesting to delete the group. + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + # Only server admins or group admins can delete groups. + + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) + + if not is_admin: + is_admin = await self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) + + if not is_admin: + raise SynapseError(403, "User is not an admin") + + # Before deleting the group lets kick everyone out of it + users = await self.store.get_users_in_group(group_id, include_private=True) + + async def _kick_user_from_group(user_id): + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot kick users from groups." + await groups_local.user_removed_from_group(group_id, user_id, {}) + else: + await self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + await self.store.maybe_delete_remote_profile_cache(user_id) + + # We kick users out in the order of: + # 1. Non-admins + # 2. Other admins + # 3. The requester + # + # This is so that if the deletion fails for some reason other admins or + # the requester still has auth to retry. + non_admins = [] + admins = [] + for u in users: + if u["user_id"] == requester_user_id: + continue + if u["is_admin"]: + admins.append(u["user_id"]) + else: + non_admins.append(u["user_id"]) + + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) + + await self.store.delete_group(group_id) + + +def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]: + """Given a content for a request, return the specified join policy or None""" + + join_policy_dict = content.get("m.join_policy") + if join_policy_dict: + return _parse_join_policy_dict(join_policy_dict) + else: + return None + + +def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str: + """Given a dict for the "m.join_policy" config return the join policy specified""" + join_policy_type = join_policy_dict.get("type") + if not join_policy_type: + return "invite" + + if join_policy_type not in ("invite", "open"): + raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") + return join_policy_type + + +def _parse_visibility_from_contents(content: JsonDict) -> bool: + """Given a content for a request parse out whether the entity should be + public or not + """ + + visibility = content.get("m.visibility") + if visibility: + return _parse_visibility_dict(visibility) + else: + is_public = True + + return is_public + + +def _parse_visibility_dict(visibility: JsonDict) -> bool: + """Given a dict for the "m.visibility" config return if the entity should + be public or not + """ + vis_type = visibility.get("type") + if not vis_type: + return True + + if vis_type not in ("public", "private"): + raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") + return vis_type == "public" diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index ba1877adcd9..34e66436396 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -20,6 +20,7 @@ from synapse.events import EventBase from synapse.types import UserID +from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ def _contains_display_name(self, display_name: str) -> bool: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = _re_word_boundary(r1) + r1 = re_word_boundary(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = _glob_to_re(glob, word_boundary) + r = glob_to_regex(glob, word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: @@ -221,56 +222,6 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: return False -def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: - """Generates regex for a given glob. - - Args: - glob - word_boundary: Whether to match against word boundaries or entire string. - """ - if IS_GLOB.search(glob): - r = re.escape(glob) - - r = r.replace(r"\*", ".*?") - r = r.replace(r"\?", ".") - - # handle [abc], [a-z] and [!a-z] style ranges. - r = GLOB_REGEX.sub( - lambda x: ( - "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-")) - ), - r, - ) - if word_boundary: - r = _re_word_boundary(r) - - return re.compile(r, flags=re.IGNORECASE) - else: - r = "^" + r + "$" - - return re.compile(r, flags=re.IGNORECASE) - elif word_boundary: - r = re.escape(glob) - r = _re_word_boundary(r) - - return re.compile(r, flags=re.IGNORECASE) - else: - r = "^" + re.escape(glob) + "$" - return re.compile(r, flags=re.IGNORECASE) - - -def _re_word_boundary(r: str) -> str: - """ - Adds word boundary characters to the start and end of an - expression to require that the match occur as a whole word, - but do so respecting the fact that strings starting or ending - with non-word characters will change word boundaries. - """ - # we can't use \b as it chokes on unicode. however \W seems to be okay - # as shorthand for [^0-9A-Za-z_]. - return r"(^|\W)%s(\W|$)" % (r,) - - def _flatten_dict( d: Union[EventBase, dict], prefix: Optional[List[str]] = None, diff --git a/synapse/push/push_rule_evaluator.py.orig b/synapse/push/push_rule_evaluator.py.orig new file mode 100644 index 00000000000..ba1877adcd9 --- /dev/null +++ b/synapse/push/push_rule_evaluator.py.orig @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union + +from synapse.events import EventBase +from synapse.types import UserID +from synapse.util.caches.lrucache import LruCache + +logger = logging.getLogger(__name__) + + +GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") +IS_GLOB = re.compile(r"[\?\*\[\]]") +INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") + + +def _room_member_count( + ev: EventBase, condition: Dict[str, Any], room_member_count: int +) -> bool: + return _test_ineq_condition(condition, room_member_count) + + +def _sender_notification_permission( + ev: EventBase, + condition: Dict[str, Any], + sender_power_level: int, + power_levels: Dict[str, Union[int, Dict[str, int]]], +) -> bool: + notif_level_key = condition.get("key") + if notif_level_key is None: + return False + + notif_levels = power_levels.get("notifications", {}) + assert isinstance(notif_levels, dict) + room_notif_level = notif_levels.get(notif_level_key, 50) + + return sender_power_level >= room_notif_level + + +def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool: + if "is" not in condition: + return False + m = INEQUALITY_EXPR.match(condition["is"]) + if not m: + return False + ineq = m.group(1) + rhs = m.group(2) + if not rhs.isdigit(): + return False + rhs_int = int(rhs) + + if ineq == "" or ineq == "==": + return number == rhs_int + elif ineq == "<": + return number < rhs_int + elif ineq == ">": + return number > rhs_int + elif ineq == ">=": + return number >= rhs_int + elif ineq == "<=": + return number <= rhs_int + else: + return False + + +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) + return tweaks + + +class PushRuleEvaluatorForEvent: + def __init__( + self, + event: EventBase, + room_member_count: int, + sender_power_level: int, + power_levels: Dict[str, Union[int, Dict[str, int]]], + ): + self._event = event + self._room_member_count = room_member_count + self._sender_power_level = sender_power_level + self._power_levels = power_levels + + # Maps strings of e.g. 'content.body' -> event["content"]["body"] + self._value_cache = _flatten_dict(event) + + def matches( + self, condition: Dict[str, Any], user_id: str, display_name: str + ) -> bool: + if condition["kind"] == "event_match": + return self._event_match(condition, user_id) + elif condition["kind"] == "contains_display_name": + return self._contains_display_name(display_name) + elif condition["kind"] == "room_member_count": + return _room_member_count(self._event, condition, self._room_member_count) + elif condition["kind"] == "sender_notification_permission": + return _sender_notification_permission( + self._event, condition, self._sender_power_level, self._power_levels + ) + else: + return True + + def _event_match(self, condition: dict, user_id: str) -> bool: + pattern = condition.get("pattern", None) + + if not pattern: + pattern_type = condition.get("pattern_type", None) + if pattern_type == "user_id": + pattern = user_id + elif pattern_type == "user_localpart": + pattern = UserID.from_string(user_id).localpart + + if not pattern: + logger.warning("event_match condition with no pattern") + return False + + # XXX: optimisation: cache our pattern regexps + if condition["key"] == "content.body": + body = self._event.content.get("body", None) + if not body or not isinstance(body, str): + return False + + return _glob_matches(pattern, body, word_boundary=True) + else: + haystack = self._get_value(condition["key"]) + if haystack is None: + return False + + return _glob_matches(pattern, haystack) + + def _contains_display_name(self, display_name: str) -> bool: + if not display_name: + return False + + body = self._event.content.get("body", None) + if not body or not isinstance(body, str): + return False + + # Similar to _glob_matches, but do not treat display_name as a glob. + r = regex_cache.get((display_name, False, True), None) + if not r: + r1 = re.escape(display_name) + r1 = _re_word_boundary(r1) + r = re.compile(r1, flags=re.IGNORECASE) + regex_cache[(display_name, False, True)] = r + + return bool(r.search(body)) + + def _get_value(self, dotted_key: str) -> Optional[str]: + return self._value_cache.get(dotted_key, None) + + +# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches +regex_cache = LruCache( + 50000, "regex_push_cache" +) # type: LruCache[Tuple[str, bool, bool], Pattern] + + +def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: + """Tests if value matches glob. + + Args: + glob + value: String to test against glob. + word_boundary: Whether to match against word boundaries or entire + string. Defaults to False. + """ + + try: + r = regex_cache.get((glob, True, word_boundary), None) + if not r: + r = _glob_to_re(glob, word_boundary) + regex_cache[(glob, True, word_boundary)] = r + return bool(r.search(value)) + except re.error: + logger.warning("Failed to parse glob to regex: %r", glob) + return False + + +def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: + """Generates regex for a given glob. + + Args: + glob + word_boundary: Whether to match against word boundaries or entire string. + """ + if IS_GLOB.search(glob): + r = re.escape(glob) + + r = r.replace(r"\*", ".*?") + r = r.replace(r"\?", ".") + + # handle [abc], [a-z] and [!a-z] style ranges. + r = GLOB_REGEX.sub( + lambda x: ( + "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-")) + ), + r, + ) + if word_boundary: + r = _re_word_boundary(r) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + r + "$" + + return re.compile(r, flags=re.IGNORECASE) + elif word_boundary: + r = re.escape(glob) + r = _re_word_boundary(r) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + re.escape(glob) + "$" + return re.compile(r, flags=re.IGNORECASE) + + +def _re_word_boundary(r: str) -> str: + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) + + +def _flatten_dict( + d: Union[EventBase, dict], + prefix: Optional[List[str]] = None, + result: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + if prefix is None: + prefix = [] + if result is None: + result = {} + for key, value in d.items(): + if isinstance(value, str): + result[".".join(prefix + [key])] = value.lower() + elif hasattr(value, "items"): + _flatten_dict(value, prefix=(prefix + [key]), result=result) + + return result diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 517686f0a67..1e58aeae436 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -16,6 +16,7 @@ import json import logging import re +from typing import Pattern import attr from frozendict import frozendict @@ -27,6 +28,9 @@ logger = logging.getLogger(__name__) +_WILDCARD_RUN = re.compile(r"([\?\*]+)") + + def _reject_invalid_json(val): """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -159,25 +163,54 @@ def log_failure(failure, msg, consumeErrors=True): return failure -def glob_to_regex(glob): +def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: """Converts a glob to a compiled regex object. - The regex is anchored at the beginning and end of the string. - Args: - glob (str) + glob: pattern to match + word_boundary: If True, the pattern will be allowed to match at word boundaries + anywhere in the string. Otherwise, the pattern is anchored at the start and + end of the string. Returns: - re.RegexObject + compiled regex pattern """ - res = "" - for c in glob: - if c == "*": - res = res + ".*" - elif c == "?": - res = res + "." + + # Patterns with wildcards must be simplified to avoid performance cliffs + # - The glob `?**?**?` is equivalent to the glob `???*` + # - The glob `???*` is equivalent to the regex `.{3,}` + chunks = [] + for chunk in _WILDCARD_RUN.split(glob): + # No wildcards? re.escape() + if not _WILDCARD_RUN.match(chunk): + chunks.append(re.escape(chunk)) + continue + + # Wildcards? Simplify. + qmarks = chunk.count("?") + if "*" in chunk: + chunks.append(".{%d,}" % qmarks) else: - res = res + re.escape(c) + chunks.append(".{%d}" % qmarks) + + res = "".join(chunks) - # \A anchors at start of string, \Z at end of string - return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) + if word_boundary: + res = re_word_boundary(res) + else: + # \A anchors at start of string, \Z at end of string + res = r"\A" + res + r"\Z" + + return re.compile(res, re.IGNORECASE) + + +def re_word_boundary(r: str) -> str: + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) diff --git a/synapse/util/__init__.py.orig b/synapse/util/__init__.py.orig new file mode 100644 index 00000000000..517686f0a67 --- /dev/null +++ b/synapse/util/__init__.py.orig @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import re + +import attr +from frozendict import frozendict + +from twisted.internet import defer, task + +from synapse.logging import context + +logger = logging.getLogger(__name__) + + +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise ValueError("Invalid JSON value: '%s'" % val) + + +def _handle_frozendict(obj): + """Helper for json_encoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) + + +# A custom JSON encoder which: +# * handles frozendicts +# * produces valid JSON (no NaNs etc) +# * reduces redundant whitespace +json_encoder = json.JSONEncoder( + allow_nan=False, separators=(",", ":"), default=_handle_frozendict +) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) + + +def unwrapFirstError(failure): + # defer.gatherResults and DeferredLists wrap failures. + failure.trap(defer.FirstError) + return failure.value.subFailure + + +@attr.s(slots=True) +class Clock: + """ + A Clock wraps a Twisted reactor and provides utilities on top of it. + + Args: + reactor: The Twisted reactor to use. + """ + + _reactor = attr.ib() + + @defer.inlineCallbacks + def sleep(self, seconds): + d = defer.Deferred() + with context.PreserveLoggingContext(): + self._reactor.callLater(seconds, d.callback, seconds) + res = yield d + return res + + def time(self): + """Returns the current system time in seconds since epoch.""" + return self._reactor.seconds() + + def time_msec(self): + """Returns the current system time in milliseconds since epoch.""" + return int(self.time() * 1000) + + def looping_call(self, f, msec, *args, **kwargs): + """Call a function repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. + + Args: + f(function): The function to call repeatedly. + msec(float): How long to wait between calls in milliseconds. + *args: Postional arguments to pass to function. + **kwargs: Key arguments to pass to function. + """ + call = task.LoopingCall(f, *args, **kwargs) + call.clock = self._reactor + d = call.start(msec / 1000.0, now=False) + d.addErrback(log_failure, "Looping call died", consumeErrors=False) + return call + + def call_later(self, delay, callback, *args, **kwargs): + """Call something later + + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. + + Args: + delay(float): How long to wait in seconds. + callback(function): Function to call + *args: Postional arguments to pass to function. + **kwargs: Key arguments to pass to function. + """ + + def wrapped_callback(*args, **kwargs): + with context.PreserveLoggingContext(): + callback(*args, **kwargs) + + with context.PreserveLoggingContext(): + return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) + + def cancel_call_later(self, timer, ignore_errs=False): + try: + timer.cancel() + except Exception: + if not ignore_errs: + raise + + +def log_failure(failure, msg, consumeErrors=True): + """Creates a function suitable for passing to `Deferred.addErrback` that + logs any failures that occur. + + Args: + msg (str): Message to log + consumeErrors (bool): If true consumes the failure, otherwise passes + on down the callback chain + + Returns: + func(Failure) + """ + + logger.error( + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) + ) + + if not consumeErrors: + return failure + + +def glob_to_regex(glob): + """Converts a glob to a compiled regex object. + + The regex is anchored at the beginning and end of the string. + + Args: + glob (str) + + Returns: + re.RegexObject + """ + res = "" + for c in glob: + if c == "*": + res = res + ".*" + elif c == "?": + res = res + "." + else: + res = res + re.escape(c) + + # \A anchors at start of string, \Z at end of string + return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index cfeccc05779..c2a18fd2453 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -75,6 +75,25 @@ def test_block_ip_literals(self): self.assertFalse(server_matches_acl_event("[1:2::]", e)) self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + def test_wildcard_matching(self): + e = _create_acl_event({"allow": ["good*.com"]}) + self.assertTrue( + server_matches_acl_event("good.com", e), + "* matches 0 characters", + ) + self.assertTrue( + server_matches_acl_event("GOOD.COM", e), + "pattern is case-insensitive", + ) + self.assertTrue( + server_matches_acl_event("good.aa.com", e), + "* matches several characters, including '.'", + ) + self.assertFalse( + server_matches_acl_event("ishgood.com", e), + "pattern does not allow prefixes", + ) + class StateQueryTests(unittest.FederatingHomeserverTestCase): diff --git a/tests/federation/test_federation_server.py.orig b/tests/federation/test_federation_server.py.orig new file mode 100644 index 00000000000..cfeccc05779 --- /dev/null +++ b/tests/federation/test_federation_server.py.orig @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from parameterized import parameterized + +from synapse.events import make_event_from_dict +from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class FederationServerTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)]) + def test_bad_request(self, query_content): + """ + Querying with bad data returns a reasonable error code. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + "/get_missing_events/(?P[^/]*)/?" + + channel = self.make_request( + "POST", + "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), + query_content, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") + + +class ServerACLsTestCase(unittest.TestCase): + def test_blacklisted_server(self): + e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("evil.com", e)) + self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + + self.assertTrue(server_matches_acl_event("evil.com.au", e)) + self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + + def test_block_ip_literals(self): + e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("1.2.3.4", e)) + self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) + self.assertFalse(server_matches_acl_event("[1:2::]", e)) + self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + + +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/ without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, {"@user:other.example.com", u1}) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/ requires the server + be in the room to provide data. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + +def _create_acl_event(content): + return make_event_from_dict( + { + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content, + } + ) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 4a841f5bb84..21f1e3e0aac 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator @@ -67,6 +69,170 @@ def test_display_name(self): # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + def _assert_matches( + self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + ) -> None: + evaluator = self._get_evaluator(content) + self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) + + def _assert_not_matches( + self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + ) -> None: + evaluator = self._get_evaluator(content) + self.assertFalse( + evaluator.matches(condition, "@user:test", "display_name"), msg + ) + + def test_event_match_body(self): + """Check that event_match conditions on content.body work as expected""" + + # if the key is `content.body`, the pattern matches substrings. + + # non-wildcards should match + condition = { + "kind": "event_match", + "key": "content.body", + "pattern": "foobaz", + } + self._assert_matches( + condition, + {"body": "aaa FoobaZ zzz"}, + "patterns should match and be case-insensitive", + ) + self._assert_not_matches( + condition, + {"body": "aa xFoobaZ yy"}, + "pattern should only match at word boundaries", + ) + self._assert_not_matches( + condition, + {"body": "aa foobazx yy"}, + "pattern should only match at word boundaries", + ) + + # wildcards should match + condition = { + "kind": "event_match", + "key": "content.body", + "pattern": "f?o*baz", + } + + self._assert_matches( + condition, + {"body": "aaa FoobarbaZ zzz"}, + "* should match string and pattern should be case-insensitive", + ) + self._assert_matches( + condition, {"body": "aa foobaz yy"}, "* should match 0 characters" + ) + self._assert_not_matches( + condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters" + ) + self._assert_not_matches( + condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters" + ) + self._assert_not_matches( + condition, + {"body": "aa xfooxbaz yy"}, + "pattern should only match at word boundaries", + ) + self._assert_not_matches( + condition, + {"body": "aa fooxbazx yy"}, + "pattern should only match at word boundaries", + ) + + # test backslashes + condition = { + "kind": "event_match", + "key": "content.body", + "pattern": r"f\oobaz", + } + self._assert_matches( + condition, + {"body": r"F\oobaz"}, + "backslash should match itself", + ) + condition = { + "kind": "event_match", + "key": "content.body", + "pattern": r"f\?obaz", + } + self._assert_matches( + condition, + {"body": r"F\oobaz"}, + r"? after \ should match any character", + ) + + def test_event_match_non_body(self): + """Check that event_match conditions on other keys work as expected""" + + # if the key is anything other than 'content.body', the pattern must match the + # whole value. + + # non-wildcards should match + condition = { + "kind": "event_match", + "key": "content.value", + "pattern": "foobaz", + } + self._assert_matches( + condition, + {"value": "FoobaZ"}, + "patterns should match and be case-insensitive", + ) + self._assert_not_matches( + condition, + {"value": "xFoobaZ"}, + "pattern should only match at the start/end of the value", + ) + self._assert_not_matches( + condition, + {"value": "FoobaZz"}, + "pattern should only match at the start/end of the value", + ) + + # wildcards should match + condition = { + "kind": "event_match", + "key": "content.value", + "pattern": "f?o*baz", + } + self._assert_matches( + condition, + {"value": "FoobarbaZ"}, + "* should match string and pattern should be case-insensitive", + ) + self._assert_matches( + condition, {"value": "foobaz"}, "* should match 0 characters" + ) + self._assert_not_matches( + condition, {"value": "fobbaz"}, "? should not match 0 characters" + ) + self._assert_not_matches( + condition, {"value": "fiiobaz"}, "? should not match 2 characters" + ) + self._assert_not_matches( + condition, + {"value": "xfooxbaz"}, + "pattern should only match at the start/end of the value", + ) + self._assert_not_matches( + condition, + {"value": "fooxbazx"}, + "pattern should only match at the start/end of the value", + ) + self._assert_not_matches( + condition, + {"value": "x\nfooxbaz"}, + "pattern should not match after a newline", + ) + self._assert_not_matches( + condition, + {"value": "fooxbaz\nx"}, + "pattern should not match before a newline", + ) + def test_no_body(self): """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) diff --git a/tests/push/test_push_rule_evaluator.py.orig b/tests/push/test_push_rule_evaluator.py.orig new file mode 100644 index 00000000000..4a841f5bb84 --- /dev/null +++ b/tests/push/test_push_rule_evaluator.py.orig @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push import push_rule_evaluator +from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent + +from tests import unittest + + +class PushRuleEvaluatorTestCase(unittest.TestCase): + def _get_evaluator(self, content): + event = FrozenEvent( + { + "event_id": "$event_id", + "type": "m.room.history_visibility", + "sender": "@user:test", + "state_key": "", + "room_id": "#room:test", + "content": content, + }, + RoomVersions.V1, + ) + room_member_count = 0 + sender_power_level = 0 + power_levels = {} + return PushRuleEvaluatorForEvent( + event, room_member_count, sender_power_level, power_levels + ) + + def test_display_name(self): + """Check for a matching display name in the body of the event.""" + evaluator = self._get_evaluator({"body": "foo bar baz"}) + + condition = { + "kind": "contains_display_name", + } + + # Blank names are skipped. + self.assertFalse(evaluator.matches(condition, "@user:test", "")) + + # Check a display name that doesn't match. + self.assertFalse(evaluator.matches(condition, "@user:test", "not found")) + + # Check a display name which matches. + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # A display name that matches, but not a full word does not result in a match. + self.assertFalse(evaluator.matches(condition, "@user:test", "ba")) + + # A display name should not be interpreted as a regular expression. + self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]")) + + # A display name with spaces should work fine. + self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + + def test_no_body(self): + """Not having a body shouldn't break the evaluator.""" + evaluator = self._get_evaluator({}) + + condition = { + "kind": "contains_display_name", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_invalid_body(self): + """A non-string body should not break the evaluator.""" + condition = { + "kind": "contains_display_name", + } + + for body in (1, True, {"foo": "bar"}): + evaluator = self._get_evaluator({"body": body}) + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_tweaks_for_actions(self): + """ + This tests the behaviour of tweaks_for_actions. + """ + + actions = [ + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + "notify", + ] + + self.assertEqual( + push_rule_evaluator.tweaks_for_actions(actions), + {"sound": "default", "highlight": True}, + ) diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/v2_alpha/test_groups.py new file mode 100644 index 00000000000..bfa9336baa7 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_groups.py @@ -0,0 +1,43 @@ +from synapse.rest.client.v1 import room +from synapse.rest.client.v2_alpha import groups + +from tests import unittest +from tests.unittest import override_config + + +class GroupsTestCase(unittest.HomeserverTestCase): + user_id = "@alice:test" + room_creator_user_id = "@bob:test" + + servlets = [room.register_servlets, groups.register_servlets] + + @override_config({"enable_group_creation": True}) + def test_rooms_limited_by_visibility(self): + group_id = "+spqr:test" + + # Alice creates a group + channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {"group_id": group_id}) + + # Bob creates a private room + room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) + self.helper.auth_user_id = self.room_creator_user_id + self.helper.send_state( + room_id, "m.room.name", {"name": "bob's secret room"}, tok=None + ) + self.helper.auth_user_id = self.user_id + + # Alice adds the room to her group. + channel = self.make_request( + "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} + ) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {}) + + # Alice now tries to retrieve the room list of the space. + channel = self.make_request("GET", f"/groups/{group_id}/rooms") + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals( + channel.json_body, {"chunk": [], "total_room_count_estimate": 0} + ) diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py new file mode 100644 index 00000000000..220accb92b6 --- /dev/null +++ b/tests/util/test_glob_to_regex.py @@ -0,0 +1,59 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.util import glob_to_regex + +from tests.unittest import TestCase + + +class GlobToRegexTestCase(TestCase): + def test_literal_match(self): + """patterns without wildcards should match""" + pat = glob_to_regex("foobaz") + self.assertTrue( + pat.match("FoobaZ"), "patterns should match and be case-insensitive" + ) + self.assertFalse( + pat.match("x foobaz"), "pattern should not match at word boundaries" + ) + + def test_wildcard_match(self): + pat = glob_to_regex("f?o*baz") + + self.assertTrue( + pat.match("FoobarbaZ"), + "* should match string and pattern should be case-insensitive", + ) + self.assertTrue(pat.match("foobaz"), "* should match 0 characters") + self.assertFalse(pat.match("fooxaz"), "the character after * must match") + self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters") + self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters") + + def test_multi_wildcard(self): + """patterns with multiple wildcards in a row should match""" + pat = glob_to_regex("**baz") + self.assertTrue(pat.match("agsgsbaz"), "** should match any string") + self.assertTrue(pat.match("baz"), "** should match the empty string") + self.assertEqual(pat.pattern, r"\A.{0,}baz\Z") + + pat = glob_to_regex("*?baz") + self.assertTrue(pat.match("agsgsbaz"), "*? should match any string") + self.assertTrue(pat.match("abaz"), "*? should match a single char") + self.assertFalse(pat.match("baz"), "*? should not match the empty string") + self.assertEqual(pat.pattern, r"\A.{1,}baz\Z") + + pat = glob_to_regex("a?*?*?baz") + self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars") + self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars") + self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars") + self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")