Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor version number checks #1738

Merged
merged 14 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions esrally/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .context import RequestContextHolder, RequestContextManager
from .factory import (
EsClientFactory,
cluster_distribution_version,
create_api_key,
delete_api_keys,
wait_for_rest_layer,
Expand Down
27 changes: 20 additions & 7 deletions esrally/client/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,21 +285,30 @@ async def put_lifecycle(self, *args, **kwargs):
class RallyAsyncElasticsearch(AsyncElasticsearch, RequestContextHolder):
def __init__(self, *args, **kwargs):
distribution_version = kwargs.pop("distribution_version", None)
distribution_flavor = kwargs.pop("distribution_flavor", None)
super().__init__(*args, **kwargs)
# skip verification at this point; we've already verified this earlier with the synchronous client.
# The async client is used in the hot code path and we use customized overrides (such as that we don't
# parse response bodies in some cases for performance reasons, e.g. when using the bulk API).
self._verified_elasticsearch = True
if distribution_version:
self.distribution_version = versions.Version.from_string(distribution_version)
else:
self.distribution_version = None
self.distribution_version = distribution_version
self.distribution_flavor = distribution_flavor

# some ILM method signatures changed in 'elasticsearch-py' 8.x,
# so we override method(s) here to provide BWC for any custom
# runners that aren't using the new kwargs
self.ilm = RallyIlmClient(self)

@property
def is_serverless(self):
return versions.is_serverless(self.distribution_flavor)

def options(self, *args, **kwargs):
new_self = super().options(*args, **kwargs)
new_self.distribution_version = self.distribution_version
new_self.distribution_flavor = self.distribution_flavor
return new_self

async def perform_request(
self,
method: str,
Expand Down Expand Up @@ -328,9 +337,13 @@ async def perform_request(
# Converts all parts of a Accept/Content-Type headers
# from application/X -> application/vnd.elasticsearch+X
# see https://github.com/elastic/elasticsearch/issues/51816
if self.distribution_version is not None and self.distribution_version >= versions.Version.from_string("8.0.0"):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)
# Not applicable to serverless
if not self.is_serverless:
if versions.is_version_identifier(self.distribution_version) and (
versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0")
):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)

if params:
target = f"{path}?{_quote_query(params)}"
Expand Down
44 changes: 39 additions & 5 deletions esrally/client/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class EsClientFactory:
compatibility guarantees that are broader than the library's defaults.
"""

def __init__(self, hosts, client_options, distribution_version=None):
def __init__(self, hosts, client_options, distribution_version=None, distribution_flavor=None):
pquentin marked this conversation as resolved.
Show resolved Hide resolved
def host_string(host):
# protocol can be set at either host or client opts level
protocol = "https" if client_options.get("use_ssl") or host.get("use_ssl") else "http"
Expand All @@ -41,8 +41,10 @@ def host_string(host):
self.client_options = dict(client_options)
self.ssl_context = None
# This attribute is necessary for the backwards-compatibility logic contained in
# RallySyncElasticsearch.perform_request() and RallyAsyncElasticsearch.perform_request().
# RallySyncElasticsearch.perform_request() and RallyAsyncElasticsearch.perform_request(), and also for
# identification of whether or not a client is 'serverless'.
self.distribution_version = distribution_version
self.distribution_flavor = distribution_flavor
self.logger = logging.getLogger(__name__)

masked_client_options = dict(client_options)
Expand Down Expand Up @@ -181,7 +183,11 @@ def create(self):
from esrally.client.synchronous import RallySyncElasticsearch

return RallySyncElasticsearch(
distribution_version=self.distribution_version, hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options
distribution_version=self.distribution_version,
distribution_flavor=self.distribution_flavor,
hosts=self.hosts,
ssl_context=self.ssl_context,
**self.client_options,
)

def create_async(self, api_key=None, client_id=None):
Expand Down Expand Up @@ -226,6 +232,7 @@ async def on_request_end(session, trace_config_ctx, params):

async_client = RallyAsyncElasticsearch(
distribution_version=self.distribution_version,
distribution_flavor=self.distribution_flavor,
hosts=self.hosts,
transport_class=RallyAsyncTransport,
ssl_context=self.ssl_context,
Expand Down Expand Up @@ -316,6 +323,32 @@ def wait_for_rest_layer(es, max_attempts=40):
return False


def cluster_distribution_version(hosts, client_options, client_factory=EsClientFactory):
"""
Attempt to get the target cluster's distribution version, build flavor, and build hash by creating and using
a 'sync' Elasticsearch client.

:param hosts: The host(s) to connect to.
:param client_options: The client options to customize the Elasticsearch client.
:param client_factory: Factory class that creates the Elasticsearch client.
:return: The cluster's build flavor, version number, and build hash. For Serverless Elasticsearch these may all be
the build flavor value.
"""
# no way for us to know whether we're talking to a serverless elasticsearch or not, so we default to the sync client
es = client_factory(hosts, client_options).create()
# unconditionally wait for the REST layer - if it's not up by then, we'll intentionally raise the original error
wait_for_rest_layer(es)
version = es.info()["version"]

version_build_flavor = version.get("build_flavor", "oss")
# build hash will only be available for serverless if the client has operator privs
version_build_hash = version.get("build_hash", version_build_flavor)
# version number does not exist for serverless
version_number = version.get("number", version_build_flavor)

return version_build_flavor, version_number, version_build_hash


def create_api_key(es, client_id, max_attempts=5):
"""
Creates an API key for the provided ``client_id``.
Expand Down Expand Up @@ -366,7 +399,8 @@ def raise_exception(failed_ids, cause=None):

# Before ES 7.10, deleting API keys by ID had to be done individually.
# After ES 7.10, a list of API key IDs can be deleted in one request.
current_version = versions.Version.from_string(es.info()["version"]["number"])
version = es.info()["version"]
current_version = versions.Version.from_string(version.get("number", "7.10.0"))
minimum_version = versions.Version.from_string("7.10.0")

deleted = []
Expand All @@ -377,7 +411,7 @@ def raise_exception(failed_ids, cause=None):
import elasticsearch

try:
if current_version >= minimum_version:
if current_version >= minimum_version or es.is_serverless:
resp = es.security.invalidate_api_key(ids=remaining)
deleted += resp["invalidated_api_keys"]
remaining = [i for i in ids if i not in deleted]
Expand Down
59 changes: 35 additions & 24 deletions esrally/client/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import re
import warnings
from typing import Any, Iterable, Mapping, Optional

Expand Down Expand Up @@ -76,44 +75,45 @@ def raise_error(cls, state, meta, body):
@classmethod
def check_product(cls, headers, response):
# type: (dict[str, str], dict[str, str]) -> int
"""Verifies that the server we're talking to is Elasticsearch.
"""
Verifies that the server we're talking to is Elasticsearch.
Does this by checking HTTP headers and the deserialized
response to the 'info' API. Returns one of the states above.
"""

version = response.get("version", {})
try:
version = response.get("version", {})
version_number = tuple(
int(x) if x is not None else 999 for x in re.search(r"^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?", version["number"]).groups()
)
except (KeyError, TypeError, ValueError, AttributeError):
# No valid 'version.number' field, effectively 0.0.0
version = {}
version_number = (0, 0, 0)
version_number = versions.Version.from_string(version.get("number", None))
except TypeError:
# No valid 'version.number' field, either Serverless Elasticsearch, or not Elasticsearch at all
version_number = versions.Version.from_string("0.0.0")

build_flavor = version.get("build_flavor", None)

# Check all of the fields and headers for missing/valid values.
try:
bad_tagline = response.get("tagline", None) != "You Know, for Search"
bad_build_flavor = version.get("build_flavor", None) != "default"
bad_build_flavor = build_flavor not in ("default", "serverless")
bad_product_header = headers.get("x-elastic-product", None) != "Elasticsearch"
except (AttributeError, TypeError):
bad_tagline = True
bad_build_flavor = True
bad_product_header = True

# 7.0-7.13 and there's a bad 'tagline' or unsupported 'build_flavor'
if (7, 0, 0) <= version_number < (7, 14, 0):
if versions.Version.from_string("7.0.0") <= version_number < versions.Version.from_string("7.14.0"):
if bad_tagline:
return cls.UNSUPPORTED_PRODUCT
elif bad_build_flavor:
return cls.UNSUPPORTED_DISTRIBUTION

elif (
# No version or version less than 6.x
version_number < (6, 0, 0)
# 6.x and there's a bad 'tagline'
or ((6, 0, 0) <= version_number < (7, 0, 0) and bad_tagline)
# No version or version less than 6.8.0, and we're not talking to a serverless elasticsearch
(version_number < versions.Version.from_string("6.8.0") and not versions.is_serverless(build_flavor))
# 6.8.0 and there's a bad 'tagline'
or (versions.Version.from_string("6.8.0") <= version_number < versions.Version.from_string("7.0.0") and bad_tagline)
# 7.14+ and there's a bad 'X-Elastic-Product' HTTP header
or ((7, 14, 0) <= version_number and bad_product_header)
or (versions.Version.from_string("7.14.0") <= version_number and bad_product_header)
):
return cls.UNSUPPORTED_PRODUCT

Expand All @@ -123,13 +123,21 @@ def check_product(cls, headers, response):
class RallySyncElasticsearch(Elasticsearch):
def __init__(self, *args, **kwargs):
distribution_version = kwargs.pop("distribution_version", None)
distribution_flavor = kwargs.pop("distribution_flavor", None)
super().__init__(*args, **kwargs)
self._verified_elasticsearch = None
self.distribution_version = distribution_version
self.distribution_flavor = distribution_flavor

if distribution_version:
self.distribution_version = versions.Version.from_string(distribution_version)
else:
self.distribution_version = None
@property
def is_serverless(self):
return versions.is_serverless(self.distribution_flavor)

def options(self, *args, **kwargs):
new_self = super().options(*args, **kwargs)
new_self.distribution_version = self.distribution_version
new_self.distribution_flavor = self.distribution_flavor
return new_self

def perform_request(
self,
Expand Down Expand Up @@ -172,9 +180,12 @@ def perform_request(
# Converts all parts of a Accept/Content-Type headers
# from application/X -> application/vnd.elasticsearch+X
# see https://github.com/elastic/elasticsearch/issues/51816
if self.distribution_version is not None and self.distribution_version >= versions.Version.from_string("8.0.0"):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)
if not self.is_serverless:
if versions.is_version_identifier(self.distribution_version) and (
versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0")
):
_mimetype_header_to_compat("Accept", headers)
_mimetype_header_to_compat("Content-Type", headers)

if params:
target = f"{path}?{_quote_query(params)}"
Expand Down
25 changes: 17 additions & 8 deletions esrally/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,21 @@ def _create_track_preparator(self, host):

def _after_track_prepared(self):
cluster_version = self.cluster_details["version"] if self.cluster_details else {}
# manually compiled versions don't expose build_flavor but Rally expects a value in telemetry devices
# we should default to trial/basic, but let's default to oss for now to avoid breaking the charts
build_flavor = cluster_version.get("build_flavor", "oss")
build_version = cluster_version.get("number", build_flavor)
build_hash = cluster_version.get("build_hash", build_flavor)

for child in self.children:
self.send(child, thespian.actors.ActorExitRequest())
self.children = []
self.send(
self.start_sender,
PreparationComplete(
# manually compiled versions don't expose build_flavor but Rally expects a value in telemetry devices
# we should default to trial/basic, but let's default to oss for now to avoid breaking the charts
cluster_version.get("build_flavor", "oss"),
cluster_version.get("number"),
cluster_version.get("build_hash"),
build_flavor,
build_version,
build_hash,
),
)

Expand Down Expand Up @@ -599,14 +603,15 @@ def __init__(self, target, config, es_client_factory_class=client.EsClientFactor
def create_es_clients(self):
all_hosts = self.config.opts("client", "hosts").all_hosts
distribution_version = self.config.opts("mechanic", "distribution.version", mandatory=False)
distribution_flavor = self.config.opts("mechanic", "distribution.flavor", mandatory=False)
es = {}
for cluster_name, cluster_hosts in all_hosts.items():
all_client_options = self.config.opts("client", "options").all_client_options
cluster_client_options = dict(all_client_options[cluster_name])
# Use retries to avoid aborts on long living connections for telemetry devices
cluster_client_options["retry_on_timeout"] = True
es[cluster_name] = self.es_client_factory(
cluster_hosts, cluster_client_options, distribution_version=distribution_version
cluster_hosts, cluster_client_options, distribution_version=distribution_version, distribution_flavor=distribution_flavor
).create()
return es

Expand Down Expand Up @@ -1729,13 +1734,16 @@ def _logging_exception_handler(self, loop, context):
self.logger.error("Uncaught exception in event loop: %s", context)

async def run(self):
def es_clients(client_id, all_hosts, all_client_options, distribution_version):
def es_clients(client_id, all_hosts, all_client_options, distribution_version, distribution_flavor):
es = {}
context = self.client_contexts.get(client_id)
api_key = context.api_key
for cluster_name, cluster_hosts in all_hosts.items():
es[cluster_name] = client.EsClientFactory(
cluster_hosts, all_client_options[cluster_name], distribution_version=distribution_version
cluster_hosts,
all_client_options[cluster_name],
distribution_version=distribution_version,
distribution_flavor=distribution_flavor,
).create_async(api_key=api_key, client_id=client_id)
return es

Expand All @@ -1758,6 +1766,7 @@ def es_clients(client_id, all_hosts, all_client_options, distribution_version):
self.cfg.opts("client", "hosts").all_hosts,
self.cfg.opts("client", "options"),
self.cfg.opts("mechanic", "distribution.version", mandatory=False),
self.cfg.opts("mechanic", "distribution.flavor", mandatory=False),
)
clients.append(es)
async_executor = AsyncExecutor(
Expand Down
5 changes: 3 additions & 2 deletions esrally/driver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,12 +2094,13 @@ async def __call__(self, es, params):
repository = mandatory(params, "repository", repr(self))
wait_period = params.get("completion-recheck-wait-period", 1)
es_info = await es.info()
es_version = Version.from_string(es_info["version"]["number"])
es_version = es_info["version"].get("number", "8.3.0")

request_args = {"repository": repository, "snapshot": "_current", "verbose": False}

# significantly reduce response size when lots of snapshots have been taken
# only available since ES 8.3.0 (https://github.com/elastic/elasticsearch/pull/86269)
if (es_version.major, es_version.minor) >= (8, 3):
if (Version.from_string(es_version) >= Version.from_string("8.3.0")) or es.is_serverless:
request_args["index_names"] = False

while True:
Expand Down
1 change: 0 additions & 1 deletion esrally/mechanic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
StartEngine,
StopEngine,
build,
cluster_distribution_version,
download,
install,
start,
Expand Down
19 changes: 1 addition & 18 deletions esrally/mechanic/mechanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import thespian.actors

from esrally import PROGRAM_NAME, actor, client, config, exceptions, metrics, paths
from esrally import PROGRAM_NAME, actor, config, exceptions, metrics, paths
from esrally.mechanic import launcher, provisioner, supplier, team
from esrally.utils import console, net

Expand Down Expand Up @@ -271,23 +271,6 @@ class NodesStopped:
pass


def cluster_distribution_version(cfg, client_factory=client.EsClientFactory):
"""
Attempt to get the cluster's distribution version even before it is actually started (which makes only sense for externally
provisioned clusters).

:param cfg: The current config object.
:param client_factory: Factory class that creates the Elasticsearch client.
:return: The distribution version.
"""
hosts = cfg.opts("client", "hosts").default
client_options = cfg.opts("client", "options").default
es = client_factory(hosts, client_options).create()
# unconditionally wait for the REST layer - if it's not up by then, we'll intentionally raise the original error
client.wait_for_rest_layer(es)
return es.info()["version"]["number"]


def to_ip_port(hosts):
ip_port_pairs = []
for host in hosts:
Expand Down
Loading