Skip to content

Commit

Permalink
feat: Improved readability and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
1101-1 committed Dec 4, 2023
1 parent 1b88c5e commit 11e22f8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 34 deletions.
14 changes: 5 additions & 9 deletions plugins/azure/resoto_plugin_azure/azure_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,12 @@ def _call(self, spec: AzureApiSpec, **kwargs: Any) -> List[Json]:
params = case_insensitive_dict(kwargs.pop("params", {}) or {})
params["api-version"] = _SERIALIZER.query("api_version", spec.version, "str") # type: ignore

# Define url map
format_map_paths = {"subscriptionId": self.subscription_id, "location": self.location, **params}
format_map_paths.update({param: kwargs.pop(param, "") for param in spec.path_parameters if param in kwargs})

# Construct url
path = spec.path.format_map(
{
"subscriptionId": self.subscription_id,
"location": self.location,
"resourceGroupName": kwargs.pop("resourceGroupName", ""),
"virtualNetworkName": kwargs.pop("virtualNetworkName", ""),
**params,
}
)
path = spec.path.format_map(format_map_paths)
url = self.client._client.format_url(path) # pylint: disable=protected-access

# Construct and send request
Expand Down
28 changes: 17 additions & 11 deletions plugins/azure/resoto_plugin_azure/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class AzureResource(BaseResource):
# Which API to call and what to expect in the result.
api_spec: ClassVar[Optional[AzureApiSpec]] = None

def resource_subscription_id(self) -> str:
return self.extract_part("subscriptionId")

def resource_resource_group_name(self) -> str:
return self.extract_part("resourceGroupName")

def extract_part(self, part: str) -> str:
"""
Extracts a specific part from a resource ID.
Expand All @@ -66,18 +72,18 @@ def extract_part(self, part: str) -> str:
id_parts = self.id.split("/")

if part == "subscriptionId":
if "subscriptions" in id_parts:
index = id_parts.index("subscriptions")
return id_parts[index + 1]
else:
if "subscriptions" not in id_parts:
raise ValueError(f"Id {self.id} does not have any subscriptionId info")
if index := id_parts.index("subscriptions"):
return id_parts[index + 1]
return ""

elif part == "resourceGroupName":
if "resourceGroups" in id_parts:
index = id_parts.index("resourceGroups")
return id_parts[index + 1]
else:
if "resourceGroups" not in id_parts:
raise ValueError(f"Id {self.id} does not have any resourceGroupName info")
if index := id_parts.index("resourceGroups"):
return id_parts[index + 1]
return ""

else:
raise ValueError(f"Value {part} does not have any cases to match")
Expand All @@ -89,7 +95,7 @@ def delete(self, graph: Graph) -> bool:
Returns:
bool: True if the resource was successfully deleted; False otherwise.
"""
subscription_id = self.extract_part("subscriptionId")
subscription_id = self.resource_subscription_id()
return get_client(subscription_id).delete(self.id)

def delete_tag(self, key: str) -> bool:
Expand All @@ -98,7 +104,7 @@ def delete_tag(self, key: str) -> bool:
This method removes a specific value from a tag associated with a subscription, while keeping the tag itself intact.
The tag remains on the account, but the specified value will be deleted.
"""
subscription_id = self.extract_part("subscriptionId")
subscription_id = self.resource_subscription_id()
return get_client(subscription_id).delete_resource_tag(tag_name=key, resource_id=self.id)

def update_tag(self, key: str, value: str) -> bool:
Expand All @@ -107,7 +113,7 @@ def update_tag(self, key: str, value: str) -> bool:
This method allows for the creation or update of a tag value associated with the specified tag name.
The tag name must already exist for the operation to be successful.
"""
subscription_id = self.extract_part("subscriptionId")
subscription_id = self.resource_subscription_id()
return get_client(subscription_id).update_resource_tag(tag_name=key, tag_value=value, resource_id=self.id)

def pre_process(self, graph_builder: GraphBuilder, source: Json) -> None:
Expand Down
28 changes: 14 additions & 14 deletions plugins/azure/resoto_plugin_azure/resource/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from resotolib.types import Json


def extract_vn_id(subnet_id: str) -> str:
def extract_virtual_network_id(subnet_id: str) -> str:
"""
Extracts {virtualNetworkName} value from a subnet ID and create virtual network ID
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureSubnet, id=subnet_id
)
vn_id = extract_vn_id(subnet_id)
vn_id = extract_virtual_network_id(subnet_id)
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id
)
Expand Down Expand Up @@ -1501,7 +1501,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if ip_confs := self.firewall_ip_configurations:
for ip_conf in ip_confs:
if subnet := ip_conf.subnet:
vn_id = extract_vn_id(subnet)
vn_id = extract_virtual_network_id(subnet)
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id
)
Expand Down Expand Up @@ -2893,7 +2893,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if ip_confs := network_interface.interface_ip_configurations:
for ip_conf in ip_confs:
if subnet := ip_conf._subnet_id:
vn_id = extract_vn_id(subnet)
vn_id = extract_virtual_network_id(subnet)
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id
)
Expand Down Expand Up @@ -4300,7 +4300,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if ip_configurations := container.ip_configurations:
for ip_configuration in ip_configurations:
if subnet := ip_configuration._subnet_id:
vn_id = extract_vn_id(subnet)
vn_id = extract_virtual_network_id(subnet)
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id
)
Expand Down Expand Up @@ -4541,14 +4541,14 @@ class AzureNetworkWatcher(AzureResource):

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
locations_and_ids_in_vn = self.fetch_resources(
builder,
"network",
"2023-05-01",
"/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
["subscriptionId"],
["api-version"],
lambda r: r["location"],
lambda r: r["id"],
builder=builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=lambda r: r["location"],
binding_property=lambda r: r["id"],
)

if (nw_location := self.location) and (vns_info := locations_and_ids_in_vn):
Expand Down Expand Up @@ -5192,7 +5192,7 @@ def collect_subnets() -> None:
access_path="value",
expect_array=True,
)
resource_group_name = self.extract_part("resourceGroupName")
resource_group_name = self.resource_resource_group_name()
virtual_network_name = self.name if self.name else ""

items = graph_builder.client.list(
Expand Down

0 comments on commit 11e22f8

Please sign in to comment.