Skip to content

Commit

Permalink
Refactor _get_model_path. Add support for top-level loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
austinschneider committed Aug 29, 2024
1 parent 6ad185c commit 36088b4
Showing 1 changed file with 105 additions and 90 deletions.
195 changes: 105 additions & 90 deletions python/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,24 @@ def load_module(name, path, persist=True):
re.VERBOSE | re.IGNORECASE,
)

_UNVERSIONED_MODEL_PATTERN = (
r"""
(?P<model_name>
(?:
[a-zA-Z0-9]+
)
|
(?:
(?:[a-zA-Z0-9]+(?:[-_\.][a-zA-Z0-9]+)*(?:[-_\.][a-zA-Z]+[a-zA-Z0-9]*))?
)
)
(?:
-
(?P<version>"""
+ _VERSION_PATTERN
+ r"))?"
)

_MODEL_PATTERN = (
r"""
(?P<model_name>
Expand Down Expand Up @@ -418,128 +436,123 @@ def tokenize_version(version):
return tuple(token_list)


def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exist=True):
# Get the path to the model file
_model_regex = re.compile(
r"^\s*" + _MODEL_PATTERN + ("" if suffix is None else r"(?:" + suffix + r")?") + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
if suffix is None:
suffix = ""
# Get the path to the resources directory
resources_dir = resource_package_dir()
def _get_base_directory(resources_dir, prefix):
base_dir = resources_dir

# Add prefix if present
if prefix is not None:
base_dir = os.path.join(base_dir, prefix)
return base_dir

# Get the model name and version
d = _model_regex.match(model_name)
if d is None:
raise ValueError("Invalid model name: {}".format(model_name))
d = d.groupdict()
model_name = d["model_name"]
version = d["version"]

# Search for the model folder in the resources directory
def _find_model_folder_and_file(base_dir, model_name, must_exist, specific_file=None):
model_names = [
f for f in os.listdir(base_dir) if not os.path.isfile(os.path.join(base_dir, f))
]
model_names = [f for f in model_names if f.lower().startswith(model_name.lower())]

folder_exists = False
exact_model_names = [f for f in model_names if f.lower() == model_name.lower()]

if len(exact_model_names) == 0:
model_names = [f for f in model_names if f.lower().startswith(model_name.lower())]
else:
model_names = exact_model_names

if len(model_names) == 0 and must_exist:
# Whoops, we didn't find the model folder!
raise ValueError(
"No model folders found for {}\nSearched in ".format(model_name, base_dir)
)
raise ValueError(f"No model folders found for {model_name}\nSearched in {base_dir}")
elif len(model_names) == 0 and not must_exist:
# Let's use the provided model name as the folder name
model_name = model_name
return model_name, False, None
elif len(model_names) == 1:
# We found the model folder!
folder_exists = True
model_name = model_names[0]
name = model_names[0]
if specific_file is not None:
specific_file_path = os.path.join(base_dir, name, specific_file)
if os.path.isfile(specific_file_path):
return name, True, specific_file_path
else:
return name, True, None
else:
# Multiple model folders found, we cannot decide which one to use
raise ValueError(
"Multiple directories found for {}\nSearched in ".format(
model_name, base_dir
)
)
raise ValueError(f"Multiple directories found for {model_name}\nSearched in {base_dir}")

def _get_model_files(base_dir, model_name, is_file, folder_exists, version=None):
if folder_exists:
# Search for the model file in the model folder
model_files = [
f
for f in os.listdir(os.path.join(base_dir, model_name))
if version:
version_dir = os.path.join(base_dir, model_name, f"v{version}")
if os.path.isdir(version_dir):
return [
f for f in os.listdir(version_dir)
if is_file == os.path.isfile(os.path.join(version_dir, f))
]
return [
f for f in os.listdir(os.path.join(base_dir, model_name))
if is_file == os.path.isfile(os.path.join(base_dir, model_name, f))
]
else:
model_files = []
return []

# From the found model files, extract the model versions
def _extract_model_versions(model_files, model_regex, model_name):
model_versions = []
for f in model_files:
d = _model_regex.match(f)
d = model_regex.match(f)
if d is not None:
if d.groupdict()["version"] is not None:
model_versions.append(normalize_version(d.groupdict()["version"]))
else:
print(ValueError(
"Input model file has no version: {}\nSearched in ".format(
f, os.path.join(base_dir, model_name)
)
))
print(f"Warning: Input model file has no version: {f}")
elif f.lower().startswith(model_name.lower()):
print(ValueError(
"Unable to parse version from {}\nFound in ".format(
f, os.path.join(base_dir, model_name)
)
))

# Raise an error if no model file is found and we require it to exist
if len(model_versions) == 0 and must_exist:
raise ValueError(
"No model found for {}\nSearched in {}".format(
model_name, os.path.join(base_dir, model_name)
)
)
print(f"Warning: Unable to parse version from {f}")
return model_versions

def _get_model_file_name(version, model_versions, model_files, model_name, suffix, must_exist):
if version is None and must_exist:
# If no version is provided, use the latest version
version_idx, version = max(
enumerate(model_versions), key=lambda x: tokenize_version(x[1])
)
model_file_name = model_files[version_idx]
version_idx, version = max(enumerate(model_versions), key=lambda x: tokenize_version(x[1]))
return model_files[version_idx]
elif version is None and not must_exist:
# If no version is provided and we don't require it to exist, default to v1
version = "v1"
model_file_name = "{}-v{}{}".format(model_name, version, suffix)
return f"{model_name}-v{version}{suffix}"
else:
# A version is provided
version = normalize_version(version)
if must_exist:
# If the version must exist, raise an error if it doesn't
if version not in model_versions:
raise ValueError(
"No model found for {}-{}\nSearched in ".format(
model_name, version, os.path.join(base_dir, model_name)
)
)
raise ValueError(f"No model found for {model_name}-{version}")
version_idx = model_versions.index(version)
model_file_name = model_files[version_idx]
return model_files[version_idx]
else:
# The version doesn't have to exist
if version in model_versions:
# If the version exists, use it
version_idx = model_versions.index(version)
model_file_name = model_files[version_idx]
return model_files[version_idx]
else:
# Otherwise use the provided version
model_file_name = "{}-v{}{}".format(model_name, version, suffix)
return f"{model_name}-v{version}{suffix}"

def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exist=True, specific_file=None):
_model_regex = re.compile(
r"^\s*" + _MODEL_PATTERN + ("" if suffix is None else r"(?:" + suffix + r")?") + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
suffix = "" if suffix is None else suffix

resources_dir = resource_package_dir()
base_dir = _get_base_directory(resources_dir, prefix)

d = _model_regex.match(model_name)
if d is None:
raise ValueError(f"Invalid model name: {model_name}")
d = d.groupdict()
model_name, version = d["model_name"], d["version"]

model_name, folder_exists, specific_file_path = _find_model_folder_and_file(base_dir, model_name, must_exist, specific_file)

if specific_file_path and not version:
return os.path.dirname(specific_file_path)

model_files = _get_model_files(base_dir, model_name, is_file, folder_exists, version)
model_versions = _extract_model_versions(model_files, _model_regex, model_name)

if len(model_versions) == 0 and must_exist:
if specific_file_path:
return os.path.dirname(specific_file_path)
raise ValueError(f"No model found for {model_name}\nSearched in {os.path.join(base_dir, model_name)}")

model_file_name = _get_model_file_name(version, model_versions, model_files, model_name, suffix, must_exist)

if version:
version_dir = os.path.join(base_dir, model_name, f"v{version}")
if os.path.isdir(version_dir):
return os.path.join(version_dir, model_file_name)

return os.path.join(base_dir, model_name, model_file_name)

Expand All @@ -560,26 +573,28 @@ def get_material_model_file_path(model_name, must_exist=True):


def get_flux_model_path(model_name, must_exist=True):
return _get_model_path(model_name, prefix=_resource_folder_by_name["flux"], is_file=False, must_exist=must_exist)
return _get_model_path(model_name, prefix=_resource_folder_by_name["flux"], is_file=False, must_exist=must_exist, specific_file=f"flux.py")


def get_detector_model_path(model_name, must_exist=True):
return _get_model_path(model_name, prefix=_resource_folder_by_name["detector"], is_file=False, must_exist=must_exist)
return _get_model_path(model_name, prefix=_resource_folder_by_name["detector"], is_file=False, must_exist=must_exist, specific_file=f"detector.py")


def get_processes_model_path(model_name, must_exist=True):
return _get_model_path(model_name, prefix=_resource_folder_by_name["processes"], is_file=False, must_exist=must_exist)
return _get_model_path(model_name, prefix=_resource_folder_by_name["processes"], is_file=False, must_exist=must_exist, specific_file="processes.py")


def load_resource(resource_type, resource_name, *args, **kwargs):
folder = _resource_folder_by_name[resource_type]
specific_file = f"{resource_type}.py"

abs_dir = _get_model_path(resource_name, prefix=folder, is_file=False, must_exist=True)
abs_dir = _get_model_path(resource_name, prefix=folder, is_file=False, must_exist=True, specific_file=specific_file)

fname = os.path.join(abs_dir, f"{resource_name}.py")
fname = os.path.join(abs_dir, f"{resource_type}.py")
print(fname)
assert(os.path.isfile(fname))
resource_module = load_module(f"siren-{resource_type}-{resource_name}", fname, persist=False)
loader = getattr(resource_module, f"load_{resource_name}")
loader = getattr(resource_module, f"load_{resource_type}")
resource = loader(*args, **kwargs)
return resource

Expand Down

0 comments on commit 36088b4

Please sign in to comment.