Skip to content

Commit

Permalink
fix(inventory): improve performance (#402)
Browse files Browse the repository at this point in the history
##### SUMMARY

Improve the performance of the inventory plugin by:
- Cache client requests
- Move servers `status` filtering to query params.
  • Loading branch information
jooola authored Nov 24, 2023
1 parent fb40a00 commit f85d8f4
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 84 deletions.
7 changes: 7 additions & 0 deletions examples/inventory.hcloud.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# You can list the hosts using:
# ansible-inventory --list -i examples/inventory.hcloud.yml --extra-vars=network_name=my-network

plugin: hetzner.hcloud.hcloud

network: "{{ network_name }}"
status: [running]
167 changes: 84 additions & 83 deletions plugins/inventory/hcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,14 @@
from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, Constructable
from ansible.utils.display import Display

from ..module_utils.client import HAS_DATEUTIL, HAS_REQUESTS
from ..module_utils.vendor import hcloud
from ..module_utils.client import (
Client,
ClientException,
client_check_required_lib,
client_get_by_name_or_id,
)
from ..module_utils.vendor.hcloud import APIException
from ..module_utils.vendor.hcloud.networks import Network
from ..module_utils.vendor.hcloud.servers import Server
from ..module_utils.version import version

Expand Down Expand Up @@ -196,13 +202,24 @@ class InventoryServer(TypedDict):
InventoryServer = dict


def first_ipv6_address(network: str) -> str:
"""
Return the first address for a ipv6 network.
:param network: IPv6 Network.
"""
return next(IPv6Network(network).hosts())


class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
NAME = "hetzner.hcloud.hcloud"

inventory: InventoryData
display: Display

client: hcloud.Client
client: Client

network: Network | None

def _configure_hcloud_client(self):
# If api_token_env is not the default, print a deprecation warning and load the
Expand Down Expand Up @@ -232,7 +249,7 @@ def _configure_hcloud_client(self):
# Resolve template string
api_token = self.templar.template(api_token)

self.client = hcloud.Client(
self.client = Client(
token=api_token,
api_endpoint=api_endpoint,
application_name="ansible-inventory",
Expand All @@ -242,61 +259,47 @@ def _configure_hcloud_client(self):
try:
# Ensure the api token is valid
self.client.locations.get_list()
except hcloud.APIException as exception:
except APIException as exception:
raise AnsibleError("Invalid Hetzner Cloud API Token.") from exception

def _get_servers(self):
if len(self.get_option("label_selector")) > 0:
self.servers = self.client.servers.get_all(label_selector=self.get_option("label_selector"))
else:
self.servers = self.client.servers.get_all()

def _filter_servers(self):
def _validate_options(self) -> None:
if self.get_option("network"):
network = self.templar.template(self.get_option("network"), fail_on_undefined=False) or self.get_option(
"network"
)
network_param: str = self.get_option("network")
network_param = self.templar.template(network_param)

try:
self.network = self.client.networks.get_by_name(network)
if self.network is None:
self.network = self.client.networks.get_by_id(network)
except hcloud.APIException:
raise AnsibleError("The given network is not found.")

tmp = []
for server in self.servers:
for server_private_network in server.private_net:
if server_private_network.network.id == self.network.id:
tmp.append(server)
self.servers = tmp
self.network = client_get_by_name_or_id(self.client, "networks", network_param)
except (ClientException, APIException) as exception:
raise AnsibleError(to_native(exception)) from exception

def _fetch_servers(self) -> list[Server]:
self._validate_options()

get_servers_params = {}
if self.get_option("label_selector"):
get_servers_params["label_selector"] = self.get_option("label_selector")

if self.get_option("status"):
get_servers_params["status"] = self.get_option("status")

servers = self.client.servers.get_all(**get_servers_params)

if self.get_option("network"):
servers = [s for s in servers if self.network.id in [p.network.id for p in s.private_net]]

if self.get_option("locations"):
tmp = []
for server in self.servers:
if server.datacenter.location.name in self.get_option("locations"):
tmp.append(server)
self.servers = tmp
locations: list[str] = self.get_option("locations")
servers = [s for s in servers if s.datacenter.location.name in locations]

if self.get_option("types"):
tmp = []
for server in self.servers:
if server.server_type.name in self.get_option("types"):
tmp.append(server)
self.servers = tmp
server_types: list[str] = self.get_option("types")
servers = [s for s in servers if s.server_type.name in server_types]

if self.get_option("images"):
tmp = []
for server in self.servers:
if server.image is not None and server.image.os_flavor in self.get_option("images"):
tmp.append(server)
self.servers = tmp
images: list[str] = self.get_option("images")
servers = [s for s in servers if s.image is not None and s.image.os_flavor in images]

if self.get_option("status"):
tmp = []
for server in self.servers:
if server.status in self.get_option("status"):
tmp.append(server)
self.servers = tmp
return servers

def _build_inventory_server(self, server: Server) -> InventoryServer:
server_dict: InventoryServer = {}
Expand All @@ -311,7 +314,7 @@ def _build_inventory_server(self, server: Server) -> InventoryServer:
server_dict["ipv4"] = to_native(server.public_net.ipv4.ip)

if server.public_net.ipv6:
server_dict["ipv6"] = to_native(self._first_ipv6_address(server.public_net.ipv6.ip))
server_dict["ipv6"] = to_native(first_ipv6_address(server.public_net.ipv6.ip))
server_dict["ipv6_network"] = to_native(server.public_net.ipv6.network)
server_dict["ipv6_network_mask"] = to_native(server.public_net.ipv6.network_mask)

Expand All @@ -320,10 +323,11 @@ def _build_inventory_server(self, server: Server) -> InventoryServer:
]

if self.get_option("network"):
for server_private_network in server.private_net:
for private_net in server.private_net:
# Set private_ipv4 if user filtered for one network
if server_private_network.network.id == self.network.id:
server_dict["private_ipv4"] = to_native(server_private_network.ip)
if private_net.network.id == self.network.id:
server_dict["private_ipv4"] = to_native(private_net.ip)
break

# Server Type
if server.server_type is not None:
Expand Down Expand Up @@ -353,60 +357,54 @@ def _build_inventory_server(self, server: Server) -> InventoryServer:

return server_dict

def _get_server_ansible_host(self, server):
def _get_server_ansible_host(self, server: Server):
if self.get_option("connect_with") == "public_ipv4":
if server.public_net.ipv4:
return to_native(server.public_net.ipv4.ip)
else:
raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified")
raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified")

if self.get_option("connect_with") == "public_ipv6":
if server.public_net.ipv6:
return to_native(self._first_ipv6_address(server.public_net.ipv6.ip))
else:
raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified")
return to_native(first_ipv6_address(server.public_net.ipv6.ip))
raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified")

elif self.get_option("connect_with") == "hostname":
if self.get_option("connect_with") == "hostname":
# every server has a name, no need to guard this
return to_native(server.name)

elif self.get_option("connect_with") == "ipv4_dns_ptr":
if self.get_option("connect_with") == "ipv4_dns_ptr":
if server.public_net.ipv4:
return to_native(server.public_net.ipv4.dns_ptr)
else:
raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified")
raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified")

elif self.get_option("connect_with") == "private_ipv4":
if self.get_option("connect_with") == "private_ipv4":
if self.get_option("network"):
for server_private_network in server.private_net:
if server_private_network.network.id == self.network.id:
return to_native(server_private_network.ip)
for private_net in server.private_net:
if private_net.network.id == self.network.id:
return to_native(private_net.ip)

else:
raise AnsibleError("You can only connect via private IPv4 if you specify a network")

def _first_ipv6_address(self, network):
return next(IPv6Network(network).hosts())

def verify_file(self, path):
"""Return the possibly of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith(("hcloud.yaml", "hcloud.yml"))

def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer | None], bool]:
def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer], bool]:
# false when refresh_cache or --flush-cache is used
if not cache:
return None, False
return [], False

# get the user-specified directive
if not self.get_option("cache"):
return None, False
return [], False

cache_key = self.get_cache_key(path)
try:
cached_result = self._cache[cache_key]
except KeyError:
# if cache expires or cache file doesn"t exist
return None, False
return [], False

return cached_result, True

Expand All @@ -426,24 +424,27 @@ def _update_cached_result(self, path, cache, result: list[InventoryServer]):
def parse(self, inventory, loader, path, cache=True):
super().parse(inventory, loader, path, cache)

if not HAS_REQUESTS:
raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires requests.")
if not HAS_DATEUTIL:
raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires python-dateutil.")
try:
client_check_required_lib()
except ClientException as exception:
raise AnsibleError(to_native(exception)) from exception

# Allow using extra variables arguments as template variables (e.g.
# '--extra-vars my_var=my_value')
self.templar.available_variables = self._vars

self._read_config_data(path)
self._configure_hcloud_client()

self.servers, cached = self._get_cached_result(path, cache)
servers, cached = self._get_cached_result(path, cache)
if not cached:
self._get_servers()
self._filter_servers()
self.servers = [self._build_inventory_server(server) for server in self.servers]
with self.client.cached_session():
servers = [self._build_inventory_server(s) for s in self._fetch_servers()]

# Add a top group
self.inventory.add_group(group=self.get_option("group"))

for server in self.servers:
for server in servers:
self.inventory.add_host(server["name"], group=self.get_option("group"))
for key, value in server.items():
self.inventory.set_variable(server["name"], key, value)
Expand Down Expand Up @@ -475,4 +476,4 @@ def parse(self, inventory, loader, path, cache=True):
strict=strict,
)

self._update_cached_result(path, cache, self.servers)
self._update_cached_result(path, cache, servers)
41 changes: 40 additions & 1 deletion plugins/module_utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from __future__ import annotations

from contextlib import contextmanager

from ansible.module_utils.basic import missing_required_lib

from .vendor.hcloud import APIException, Client
from .vendor.hcloud import APIException, Client as ClientBase

HAS_REQUESTS = True
HAS_DATEUTIL = True
Expand Down Expand Up @@ -61,3 +63,40 @@ def client_get_by_name_or_id(client: Client, resource: str, param: str | int):
if exception.code == "not_found":
raise _client_resource_not_found(resource, param) from exception
raise exception


if HAS_REQUESTS:

class CachedSession(requests.Session):
cache: dict[str, requests.Response] = {}

def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response: # type: ignore[no-untyped-def]
"""
Send a given PreparedRequest.
"""
if request.method != "GET" or request.url is None:
return super().send(request, **kwargs)

if request.url in self.cache:
return self.cache[request.url]

response = super().send(request, **kwargs)
if response.ok:
self.cache[request.url] = response

return response


class Client(ClientBase):
@contextmanager
def cached_session(self) -> None:
"""
Swap the client session during the scope of the context. The session will cache
all GET requests.
Cached response will not expire, therefore the cached client must not be used
for long living scopes.
"""
self._requests_session = CachedSession()
yield
self._requests_session = requests.Session()

0 comments on commit f85d8f4

Please sign in to comment.