Skip to content

Commit

Permalink
fix(introspection): support inheritance in get_slots
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdoret committed Jan 12, 2024
1 parent 35b2249 commit 7819bf6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
6 changes: 3 additions & 3 deletions modo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def prompt_for_slot(slot_name: str, prefix: str = ""):


def prompt_for_slots(
target_class: str,
target_class: type,
) -> dict[str, Any]:
"""Prompt the user to provide values for the slots of input class."""

Expand Down Expand Up @@ -117,7 +117,7 @@ def create(
elif meta:
obj = json_loader.loads(meta, target_class=model.MODO)
else:
filled = prompt_for_slots("MODO")
filled = prompt_for_slots(model.MODO)
obj = model.MODO(**filled)

# Dump object to zarr metadata
Expand Down Expand Up @@ -197,7 +197,7 @@ def add(
elif meta:
obj = json_loader.loads(meta, target_class=target_class)
else:
filled = prompt_for_slots(target_class.__name__)
filled = prompt_for_slots(target_class)
obj = target_class(**filled)

modo = MODO(object_directory)
Expand Down
13 changes: 5 additions & 8 deletions modo/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,16 @@ def load_prefixmap() -> Any:
return SchemaView(SCHEMA_PATH, merge_imports=False).schema.prefixes


def get_slots(target_class: str, required_only=False) -> list[str]:
def get_slots(target_class: type, required_only=False) -> list[str]:
"""Return a list of required slots for a class."""
required = []
class_slots = load_schema().get_class(target_class).slots
if not class_slots:
return required
slots = []
class_slots = target_class.__match_args__

# NOTE: May need to check inheritance and slot_usage
for slot_name in class_slots:
if not required_only or load_schema().get_slot(slot_name).required:
required.append(slot_name)
slots.append(slot_name)

return required
return slots


def instance_to_graph(instance) -> Graph:
Expand Down

0 comments on commit 7819bf6

Please sign in to comment.