Skip to content

Commit

Permalink
Expand supported types of recipes (#109) (#110)
Browse files Browse the repository at this point in the history
* Expand supported types of recipes

* fix for tests
  • Loading branch information
markurtz authored Oct 21, 2021
1 parent ff43990 commit 3c6091c
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/sparsezoo/objects/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class RecipeTypes(Enum):
"""

ORIGINAL = "original"
SPARSE = "sparse"
TRANSFER = "transfer"
TRANSFER_LEARN = "transfer_learn"


Expand Down Expand Up @@ -455,9 +457,6 @@ def search_sparse_recipes(
"""
from sparsezoo.objects.model import Model

if isinstance(recipe_type, str):
recipe_type = RecipeTypes(recipe_type).value

if not isinstance(model, Model):
model = Model.load_model_from_stub(model)

Expand Down Expand Up @@ -508,15 +507,21 @@ def recipe_type_original(self) -> bool:
:return: True if this is the original recipe that created the
model, False otherwise
"""
return self.recipe_type == RecipeTypes.ORIGINAL.value
return any(
self.recipe_type.startswith(start)
for start in [RecipeTypes.ORIGINAL.value, RecipeTypes.SPARSE.value]
)

@property
def recipe_type_transfer_learn(self) -> bool:
"""
:return: True if this is a recipe for transfer learning from the
created model, False otherwise
"""
return self.recipe_type == RecipeTypes.TRANSFER_LEARN.value
return any(
self.recipe_type.startswith(start)
for start in [RecipeTypes.TRANSFER.value, RecipeTypes.TRANSFER_LEARN.value]
)

@property
def display_name(self):
Expand Down Expand Up @@ -653,15 +658,17 @@ def download_base_framework_files(
return base_framework_files or framework_files


def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str:
def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> Optional[str]:
# check recipe type, default to original, and validate
recipe_type = stub_args.get("recipe_type")

# validate
valid_recipe_types = list(map(lambda typ: typ.value, RecipeTypes))
if recipe_type not in valid_recipe_types and recipe_type is not None:

if recipe_type is not None and not any(
recipe_type.startswith(start) for start in valid_recipe_types
):
raise ValueError(
f"Invalid recipe_type: '{recipe_type}'. "
f"Valid recipe types: {valid_recipe_types}"
f"Valid recipes must start with one of: {valid_recipe_types}"
)

return recipe_type

0 comments on commit 3c6091c

Please sign in to comment.