Skip to content

Commit

Permalink
evaluate the $ref early during the loading
Browse files Browse the repository at this point in the history
Evaluate and get ride of all the $ref just after the loading of
the JSON definition. This way we don't have to deal with the
references and definitions later during the processing.
  • Loading branch information
goneri committed Dec 20, 2022
1 parent f0c0499 commit b954271
Show file tree
Hide file tree
Showing 2 changed files with 333 additions and 327 deletions.
179 changes: 76 additions & 103 deletions vmware_rest_code_generator/cmd/refresh_modules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import argparse
from typing import DefaultDict
from typing import DefaultDict, Dict, Any
import jinja2
import json
import os
Expand Down Expand Up @@ -387,35 +387,6 @@ def gen_arguments_py(parameters, list_index=None):
return result


def flatten_ref(tree, definitions):
if isinstance(tree, str):
if tree.startswith("#/definitions/"):
raise Exception("TODO")
return definitions.get(tree)
if isinstance(tree, list):
return [flatten_ref(i, definitions) for i in tree]
if tree is None:
return {}
for k in tree:
v = tree[k]
if k == "$ref":
dotted = v.split("/")[2]
if dotted in ["vapi.std.localization_param", "VapiStdLocalizationParam"]:
# to avoid an endless loop with
# vapi.std.nested_localizable_message
return {"go_to": "vapi.std.localization_param"}
definition = definitions.get(dotted)
data = flatten_ref(definition, definitions)
if "description" not in data and "description" in tree:
data["description"] = tree["description"]
return data
elif isinstance(v, dict):
tree[k] = flatten_ref(v, definitions)
else:
pass
return tree


class Resource:
def __init__(self, name):
self.name = name
Expand All @@ -424,9 +395,8 @@ def __init__(self, name):


class AnsibleModuleBase(UtilsBase):
def __init__(self, resource, definitions):
def __init__(self, resource):
self.resource = resource
self.definitions = definitions
self.name = resource.name
self.default_operationIds = None

Expand Down Expand Up @@ -471,41 +441,21 @@ def payload(self):
payload[operationId] = {"query": {}, "body": {}, "path": {}}
payload_info = {}
for parameter in AnsibleModule._property_to_parameter(
self.resource.operations[operationId][2], self.definitions, operationId
self.resource.operations[operationId][2], operationId
):
_in = parameter["in"] or "body"

payload_info = parameter["_loc_in_payload"]
payload[operationId][_in][parameter["name"]] = payload_info
return payload

def answer(self):
# This is arguably not super elegant. The list outputs just include a summary of the resources,
# with this little transformation, we get access to the full item
output_format = None
for i in ["list", "get"]:
def answer(self) -> Dict[str, Dict[Any, Any]]:
raw_answer: Dict = {}
for i in ["get", "list"]:
if i in self.resource.operations:
output_format = self.resource.operations[i][3]["200"]
if not output_format:
return
raw_answer = self.resource.operations[i][3]["200"]
break

if "items" in output_format["schema"]:
ref = (
output_format["schema"]["items"]
.get("$ref", "")
.replace("Summary", "Info")
)
elif "schema" in output_format:
ref = output_format["schema"].get("$ref")
else:
ref = output_format.get("$ref")

if not ref:
return
try:
raw_answer = flatten_ref({"$ref": ref}, self.definitions)
except KeyError:
return
if "properties" in raw_answer:
return raw_answer["properties"].keys()

Expand All @@ -522,7 +472,7 @@ def sort_operationsid(input):
continue

for parameter in AnsibleModule._property_to_parameter(
self.resource.operations[operationId][2], self.definitions, operationId
self.resource.operations[operationId][2], operationId
):
name = parameter["name"]
if name not in results:
Expand Down Expand Up @@ -652,13 +602,10 @@ def gen_required_if(self, parameters):
return entries

@staticmethod
def _property_to_parameter(prop_struct, definitions, operationId):
properties = flatten_ref(prop_struct, definitions)

def get_next(properties):
def _property_to_parameter(properties, operationId):
def get_next():
required_keys = []
for i, v in enumerate(properties):
required = v.get("required")
if "schema" in v:
if "properties" in v["schema"]:
properties[i] = v["schema"]["properties"]
Expand All @@ -682,8 +629,6 @@ def get_next(properties):
yield name, property, ["spec"], name in required_keys

elif isinstance(v, dict):
if not isinstance(v, dict):
continue
# {'type': 'string', 'required': True, 'in': 'path', 'name': 'datacenter', 'description': 'Identifier of the datacenter.'}
if "name" in v and "in" in v and v.get("in") in ["path", "query"]:
yield v["name"], v, [], v.get("required")
Expand All @@ -698,7 +643,7 @@ def get_next(properties):

parameters = []

for name, v, parent, required in get_next(properties):
for name, v, parent, required in get_next():
if name == "request_body":
raise ValueError()
parameter = {
Expand Down Expand Up @@ -795,17 +740,17 @@ def renderer(self, target_dir, next_version):
class AnsibleModule(AnsibleModuleBase):
template_file = "default_module.j2"

def __init__(self, resource, definitions):
super().__init__(resource, definitions)
def __init__(self, resource):
super().__init__(resource)
# TODO: We can probably do better
self.default_operationIds = set(list(self.resource.operations.keys())) - set(
["get", "list"]
)


class AnsibleInfoModule(AnsibleModuleBase):
def __init__(self, resource, definitions):
super().__init__(resource, definitions)
def __init__(self, resource):
super().__init__(resource)
self.name = resource.name + "_info"
self.default_operationIds = ["get", "list"]

Expand All @@ -821,29 +766,6 @@ class AnsibleInfoListOnlyModule(AnsibleInfoModule):
template_file = "info_list_and_get_module.j2"


class Definitions:
def __init__(self, data):
super().__init__()
self.definitions = data

def get(self, ref):
if isinstance(ref, dict):
# TODO: standardize the input to avoid this step
dotted = ref["$ref"].split("/")[2]
else:
dotted = ref

try:
definition = self.definitions[dotted]
except KeyError:
definition = self.definitions["com.vmware." + dotted]

if definition is None:
raise Exception("Cannot find ref for {ref}")

return definition


class Path:
def __init__(self, path, value):
super().__init__()
Expand All @@ -862,12 +784,67 @@ def is_tech_preview(self):
return False


from collections.abc import ValuesView


def expand_ref(raw):
def lookup(name):
# Swagger2 vSphere 7.0.3
if name.startswith("#/definitions/vapi.std"):
return {}
if name.startswith("#/definitions/VapiStd"):
return {}

# OpenAPI3 vSphere 8.0.0
if name.startswith("#/components/schemas/vapi.std."):
return {}

if name.startswith("#/components/schemas/VapiStd"):
return {}

if name.startswith("#/"): # Local lookup
loc = raw
try:
for i in name.split("/")[1:]:
loc = loc[i]
except KeyError:
print(f"Cannot look up {name}")
return {}
return loc
else:
raise NotImplementedError

def walk(cur_v):
if isinstance(cur_v, ValuesView):
for a in cur_v:
walk(a)
elif isinstance(cur_v, dict):
ref = cur_v.pop("$ref", None)
if ref:
for k, v in lookup(ref).items():
if k in ["summary", "description"] and k in cur_v:
# We keep the original value (c.f: 4.8.23.1 Fixed Fields)
pass
else:
cur_v[k] = v
walk(cur_v)
walk(cur_v.values())
elif isinstance(cur_v, list):
for a in cur_v:
walk(a)
else:
pass

walk(raw["paths"].values())


class SwaggerFile:
def __init__(self, raw_content):
super().__init__()
self.resources = {}
json_content = json.loads(raw_content)
self.definitions = Definitions(json_content["definitions"])

expand_ref(json_content)
self.paths = self.load_paths(json_content["paths"])

@staticmethod
Expand All @@ -879,7 +856,7 @@ def load_paths(paths):
continue
result[path.path] = path
for verb, desc in path.value.items():
operationId = desc["operationId"]
operationId = desc.get("operationId", "verb")
if desc.get("deprecated"):
continue
try:
Expand Down Expand Up @@ -1025,9 +1002,7 @@ def main():
if resource.name.startswith("vcenter_trustedinfrastructure"):
continue
if "list" in resource.operations:
module = AnsibleInfoListOnlyModule(
resource, definitions=swagger_file.definitions
)
module = AnsibleInfoListOnlyModule(resource)
if (
module.is_trusted("vmware_rest_code_generator")
and len(module.default_operationIds) > 0
Expand All @@ -1037,9 +1012,7 @@ def main():
)
module_list.append(module.name)
elif "get" in resource.operations:
module = AnsibleInfoNoListModule(
resource, definitions=swagger_file.definitions
)
module = AnsibleInfoNoListModule(resource)
if (
module.is_trusted("vmware_rest_code_generator")
and len(module.default_operationIds) > 0
Expand All @@ -1049,7 +1022,7 @@ def main():
)
module_list.append(module.name)

module = AnsibleModule(resource, definitions=swagger_file.definitions)
module = AnsibleModule(resource)

if (
module.is_trusted("vmware_rest_code_generator")
Expand Down
Loading

0 comments on commit b954271

Please sign in to comment.