Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor version number checks #1738

Merged
merged 14 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions esrally/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .context import RequestContextHolder, RequestContextManager
from .factory import (
EsClientFactory,
cluster_distribution_version,
create_api_key,
delete_api_keys,
wait_for_rest_layer,
Expand Down
35 changes: 31 additions & 4 deletions esrally/client/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,13 @@ def __init__(self, *args, **kwargs):
# The async client is used in the hot code path and we use customized overrides (such as that we don't
# parse response bodies in some cases for performance reasons, e.g. when using the bulk API).
self._verified_elasticsearch = True
self._serverless = False

# this isn't always available because any call to self.options() doesn't pass any custom args
# to the constructor
# https://github.com/elastic/rally/issues/1673
if distribution_version:
self.distribution_version = versions.Version.from_string(distribution_version)
self.distribution_version = distribution_version
else:
self.distribution_version = None
pquentin marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -300,6 +305,10 @@ def __init__(self, *args, **kwargs):
# runners that aren't using the new kwargs
self.ilm = RallyIlmClient(self)

@property
def is_serverless(self):
return self._serverless

async def perform_request(
self,
method: str,
Expand Down Expand Up @@ -328,9 +337,13 @@ async def perform_request(
# Converts all parts of a Accept/Content-Type headers
# from application/X -> application/vnd.elasticsearch+X
# see https://github.com/elastic/elasticsearch/issues/51816
if self.distribution_version is not None and self.distribution_version >= versions.Version.from_string("8.0.0"):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)
# Not applicable to serverless
if not self.is_serverless:
if self.distribution_version is not None and (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we assume that distribution_version will be defined now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can for the async client, just because we currently only create them in one place where we already know the distribution flavor/version. I mainly did this to be consistent between the two, but I think a better approach is to just check whether the value is a valid identifier as I've done in 5f2db54

versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0") and not self.is_serverless
pquentin marked this conversation as resolved.
Show resolved Hide resolved
):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)

if params:
target = f"{path}?{_quote_query(params)}"
Expand Down Expand Up @@ -399,3 +412,17 @@ async def perform_request(
response = ApiResponse(body=resp_body, meta=meta) # type: ignore[assignment]

return response


class RallyAsyncElasticsearchServerless(RallyAsyncElasticsearch):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# we set this as an instance attribute because we can't afford to make any calls to external APIs to verify
# whether we're talking to a serverless cluster or not once we're executing the benchmark.
#
# the reason for this is because the client can reinstantiate itself (e.g. with a call to .options()) during
# the execution of the benchmark, which means external API calls add unnecessary latency, and it reinstantiates
# itself without the ability to pass any custom arguments (i.e. distribution_version) to the constructor
#
# see https://github.com/elastic/rally/issues/1673
self._serverless = True
99 changes: 73 additions & 26 deletions esrally/client/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class EsClientFactory:
compatibility guarantees that are broader than the library's defaults.
"""

def __init__(self, hosts, client_options, distribution_version=None):
def __init__(self, hosts, client_options, distribution_version=None, distribution_flavor=None):
pquentin marked this conversation as resolved.
Show resolved Hide resolved
def host_string(host):
# protocol can be set at either host or client opts level
protocol = "https" if client_options.get("use_ssl") or host.get("use_ssl") else "http"
Expand All @@ -43,6 +43,7 @@ def host_string(host):
# This attribute is necessary for the backwards-compatibility logic contained in
# RallySyncElasticsearch.perform_request() and RallyAsyncElasticsearch.perform_request().
self.distribution_version = distribution_version
self.distribution_flavor = distribution_flavor
self.logger = logging.getLogger(__name__)

masked_client_options = dict(client_options)
Expand Down Expand Up @@ -178,12 +179,20 @@ def _is_set(self, client_opts, k):

def create(self):
# pylint: disable=import-outside-toplevel
from esrally.client.synchronous import RallySyncElasticsearch

return RallySyncElasticsearch(
distribution_version=self.distribution_version, hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options
from esrally.client.synchronous import (
RallySyncElasticsearch,
RallySyncElasticsearchServerless,
)

if versions.is_serverless(self.distribution_flavor):
return RallySyncElasticsearchServerless(
distribution_version=self.distribution_version, hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options
)
else:
return RallySyncElasticsearch(
distribution_version=self.distribution_version, hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options
)

def create_async(self, api_key=None, client_id=None):
# pylint: disable=import-outside-toplevel
import io
Expand All @@ -193,6 +202,7 @@ def create_async(self, api_key=None, client_id=None):

from esrally.client.asynchronous import (
RallyAsyncElasticsearch,
RallyAsyncElasticsearchServerless,
RallyAsyncTransport,
)

Expand All @@ -204,35 +214,45 @@ def loads(self, data):
else:
return super().loads(data)

# override the builtin JSON serializer
self.client_options["serializer"] = LazyJSONSerializer()

if api_key is not None:
self.client_options.pop("http_auth", None)
self.client_options.pop("basic_auth", None)
self.client_options["api_key"] = api_key

if versions.is_serverless(self.distribution_flavor):
async_client = RallyAsyncElasticsearchServerless(
distribution_version=self.distribution_version,
hosts=self.hosts,
transport_class=RallyAsyncTransport,
ssl_context=self.ssl_context,
maxsize=self.max_connections,
**self.client_options,
)
else:
async_client = RallyAsyncElasticsearch(
distribution_version=self.distribution_version,
hosts=self.hosts,
transport_class=RallyAsyncTransport,
ssl_context=self.ssl_context,
maxsize=self.max_connections,
**self.client_options,
)

async def on_request_start(session, trace_config_ctx, params):
RallyAsyncElasticsearch.on_request_start()
async_client.on_request_start()

async def on_request_end(session, trace_config_ctx, params):
RallyAsyncElasticsearch.on_request_end()
async_client.on_request_end()

trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_request_start)
trace_config.on_request_end.append(on_request_end)
# ensure that we also stop the timer when a request "ends" with an exception (e.g. a timeout)
trace_config.on_request_exception.append(on_request_end)

# override the builtin JSON serializer
self.client_options["serializer"] = LazyJSONSerializer()

if api_key is not None:
self.client_options.pop("http_auth", None)
self.client_options.pop("basic_auth", None)
self.client_options["api_key"] = api_key

async_client = RallyAsyncElasticsearch(
distribution_version=self.distribution_version,
hosts=self.hosts,
transport_class=RallyAsyncTransport,
ssl_context=self.ssl_context,
maxsize=self.max_connections,
**self.client_options,
)

# the AsyncElasticsearch constructor automatically creates the corresponding NodeConfig objects, so we set
# their instance attributes after they've been instantiated
for node_connection in async_client.transport.node_pool.all():
Expand Down Expand Up @@ -316,6 +336,32 @@ def wait_for_rest_layer(es, max_attempts=40):
return False


def cluster_distribution_version(hosts, client_options, client_factory=EsClientFactory):
"""
Attempt to get the target cluster's distribution version, build flavor, and build hash by creating and using
a 'sync' Elasticsearch client.

:param hosts: The host(s) to connect to.
:param client_options: The client options to customize the Elasticsearch client.
:param client_factory: Factory class that creates the Elasticsearch client.
:return: The cluster's build flavor, version number, and build hash. For Serverless Elasticsearch these may all be
the build flavor value.
"""
# no way for us to know whether we're talking to a serverless elasticsearch or not, so we default to the sync client
es = client_factory(hosts, client_options).create()
# unconditionally wait for the REST layer - if it's not up by then, we'll intentionally raise the original error
wait_for_rest_layer(es)
version = es.info()["version"]

version_build_flavor = version.get("build_flavor", "oss")
# build hash will only be available for serverless if the client has operator privs
version_build_hash = version.get("build_hash", version_build_flavor)
# version number does not exist for serverless
version_number = version.get("number", version_build_flavor)

return version_build_flavor, version_number, version_build_hash


def create_api_key(es, client_id, max_attempts=5):
"""
Creates an API key for the provided ``client_id``.
Expand Down Expand Up @@ -366,7 +412,8 @@ def raise_exception(failed_ids, cause=None):

# Before ES 7.10, deleting API keys by ID had to be done individually.
# After ES 7.10, a list of API key IDs can be deleted in one request.
current_version = versions.Version.from_string(es.info()["version"]["number"])
version = es.info()["version"]
current_version = versions.Version.from_string(version.get("number", "7.10.0"))
minimum_version = versions.Version.from_string("7.10.0")

deleted = []
Expand All @@ -377,7 +424,7 @@ def raise_exception(failed_ids, cause=None):
import elasticsearch

try:
if current_version >= minimum_version:
if current_version >= minimum_version or es.is_serverless:
resp = es.security.invalidate_api_key(ids=remaining)
deleted += resp["invalidated_api_keys"]
remaining = [i for i in ids if i not in deleted]
Expand Down
67 changes: 46 additions & 21 deletions esrally/client/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import re
import warnings
from typing import Any, Iterable, Mapping, Optional

Expand Down Expand Up @@ -76,44 +75,45 @@ def raise_error(cls, state, meta, body):
@classmethod
def check_product(cls, headers, response):
# type: (dict[str, str], dict[str, str]) -> int
"""Verifies that the server we're talking to is Elasticsearch.
"""
Verifies that the server we're talking to is Elasticsearch.
Does this by checking HTTP headers and the deserialized
response to the 'info' API. Returns one of the states above.
"""

version = response.get("version", {})
try:
version = response.get("version", {})
version_number = tuple(
int(x) if x is not None else 999 for x in re.search(r"^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?", version["number"]).groups()
)
except (KeyError, TypeError, ValueError, AttributeError):
# No valid 'version.number' field, effectively 0.0.0
version = {}
version_number = (0, 0, 0)
version_number = versions.Version.from_string(version.get("number", None))
except TypeError:
# No valid 'version.number' field, either Serverless Elasticsearch, or not Elasticsearch at all
version_number = versions.Version.from_string("0.0.0")

build_flavor = version.get("build_flavor", None)

# Check all of the fields and headers for missing/valid values.
try:
bad_tagline = response.get("tagline", None) != "You Know, for Search"
bad_build_flavor = version.get("build_flavor", None) != "default"
bad_build_flavor = build_flavor not in ("default", "serverless")
bad_product_header = headers.get("x-elastic-product", None) != "Elasticsearch"
except (AttributeError, TypeError):
bad_tagline = True
bad_build_flavor = True
bad_product_header = True

# 7.0-7.13 and there's a bad 'tagline' or unsupported 'build_flavor'
if (7, 0, 0) <= version_number < (7, 14, 0):
if versions.Version.from_string("7.0.0") <= version_number < versions.Version.from_string("7.14.0"):
if bad_tagline:
return cls.UNSUPPORTED_PRODUCT
elif bad_build_flavor:
return cls.UNSUPPORTED_DISTRIBUTION

elif (
# No version or version less than 6.x
version_number < (6, 0, 0)
# 6.x and there's a bad 'tagline'
or ((6, 0, 0) <= version_number < (7, 0, 0) and bad_tagline)
# No version or version less than 6.8.0, and we're not talking to a serverless elasticsearch
(version_number < versions.Version.from_string("6.8.0") and not versions.is_serverless(build_flavor))
# 6.8.0 and there's a bad 'tagline'
or (versions.Version.from_string("6.8.0") <= version_number < versions.Version.from_string("7.0.0") and bad_tagline)
# 7.14+ and there's a bad 'X-Elastic-Product' HTTP header
or ((7, 14, 0) <= version_number and bad_product_header)
or (versions.Version.from_string("7.14.0") <= version_number and bad_product_header)
):
return cls.UNSUPPORTED_PRODUCT

Expand All @@ -125,12 +125,20 @@ def __init__(self, *args, **kwargs):
distribution_version = kwargs.pop("distribution_version", None)
super().__init__(*args, **kwargs)
self._verified_elasticsearch = None
self._serverless = False

# this isn't always available because any call to self.options() doesn't pass any custom args
# to the constructor
# https://github.com/elastic/rally/issues/1673
if distribution_version:
self.distribution_version = versions.Version.from_string(distribution_version)
self.distribution_version = distribution_version
else:
self.distribution_version = None

@property
def is_serverless(self):
return self._serverless

def perform_request(
self,
method: str,
Expand Down Expand Up @@ -172,9 +180,12 @@ def perform_request(
# Converts all parts of a Accept/Content-Type headers
# from application/X -> application/vnd.elasticsearch+X
# see https://github.com/elastic/elasticsearch/issues/51816
if self.distribution_version is not None and self.distribution_version >= versions.Version.from_string("8.0.0"):
_mimetype_header_to_compat("Accept", request_headers)
_mimetype_header_to_compat("Content-Type", request_headers)
if not self.is_serverless:
if self.distribution_version is not None and (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 5f2db54

versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0")
):
_mimetype_header_to_compat("Accept", headers)
_mimetype_header_to_compat("Content-Type", headers)

if params:
target = f"{path}?{_quote_query(params)}"
Expand Down Expand Up @@ -243,3 +254,17 @@ def perform_request(
response = ApiResponse(body=resp_body, meta=meta) # type: ignore[assignment]

return response


class RallySyncElasticsearchServerless(RallySyncElasticsearch):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# we set this as an instance attribute because we can't afford to make any calls to external APIs to verify
# whether we're talking to a serverless cluster or not once we're executing the benchmark.
#
# the reason for this is because the client can reinstantiate itself (e.g. with a call to .options()) during
# the execution of the benchmark, which means external API calls add unnecessary latency, and it reinstantiates
# itself without the ability to pass any custom arguments (i.e. distribution_version) to the constructor
#
# see https://github.com/elastic/rally/issues/1673
self._serverless = True
Loading