Skip to content

Commit

Permalink
Refactor Modality object (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi authored Oct 10, 2024
1 parent 187fb48 commit b291ea0
Show file tree
Hide file tree
Showing 33 changed files with 394 additions and 309 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ experimentation and research for new techniques.
## Quick Start
### Installation
#### Prerequisites
The library requires Python 3.9 or later. We recommend using a virtual environment to manage dependencies. You can create
The library requires Python 3.10 or later. We recommend using a virtual environment to manage dependencies. You can create
a virtual environment using the following command:
```bash
python3 -m venv /path/to/new/virtual/environment
Expand Down
3 changes: 2 additions & 1 deletion mmlearn/conf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Hydra/Hydra-zen-based configurations."""

import functools
import os
import warnings
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -29,7 +30,7 @@

def _get_default_ckpt_dir() -> Any:
"""Get the default checkpoint directory."""
return SI("/checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID}")
return SI("${hydra:runtime.output_dir}/checkpoints")


_DataLoaderConf = builds(
Expand Down
2 changes: 1 addition & 1 deletion mmlearn/datasets/chexpert.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __getitem__(self, idx: int) -> Example:

return Example(
{
Modalities.RGB: image,
Modalities.RGB.name: image,
Modalities.RGB.target: label,
"qid": entry["qid"],
EXAMPLE_INDEX_KEY: idx,
Expand Down
8 changes: 4 additions & 4 deletions mmlearn/datasets/core/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

from torch.utils.data import default_collate

from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.modalities import Modalities, Modality
from mmlearn.datasets.core.modalities import Modalities


@dataclass
Expand Down Expand Up @@ -43,9 +43,9 @@ def __call__(self, examples: list[Example]) -> dict[str, Any]:

if self.batch_processors is not None:
for key, processor in self.batch_processors.items():
batch_key: Union[str, Modality] = key
batch_key: str = key
if Modalities.has_modality(key):
batch_key = Modalities.get_modality(key)
batch_key = Modalities.get_modality(key).name

if batch_key in batch:
batch_processed = processor(batch[batch_key])
Expand Down
180 changes: 84 additions & 96 deletions mmlearn/datasets/core/modalities.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Module for managing supported modalities in the library."""

import re
from typing import TYPE_CHECKING, Any, Optional
import warnings
from dataclasses import dataclass, field
from typing import Any, ClassVar, Optional

from typing_extensions import Self


_default_supported_modalities = ["rgb", "depth", "thermal", "text", "audio", "video"]
_DEFAULT_SUPPORTED_MODALITIES = ["rgb", "depth", "thermal", "text", "audio", "video"]


class Modality(str):
@dataclass
class Modality:
"""Class to represent a modality in the library.
This class is used to represent a modality in the library. It contains the name of
Expand All @@ -24,61 +27,46 @@ class Modality(str):
modality_specific_properties : Optional[dict[str, str]], optional, default=None
Additional properties specific to the modality, by default None
Attributes
----------
value : str
The name of the modality.
properties : dict[str, str]
The properties associated with the modality. By default, the properties are
`target`, `mask`, `embedding`, `masked_embedding`, and `ema_embedding`.
These default properties apply to all newly created modality types
automatically. Modality-specific properties can be added using the
`add_property` method or by passing them as a dictionary to the constructor.
Raises
------
ValueError
If the format string of a property is invalid. A property that already exists
is overwritten with a `UserWarning` rather than raising.
"""

_default_properties = {
"target": "{}_target",
"attention_mask": "{}_attention_mask",
"mask": "{}_mask",
"embedding": "{}_embedding",
"masked_embedding": "{}_masked_embedding",
"ema_embedding": "{}_ema_embedding",
}

if TYPE_CHECKING:

def __getattr__(self, attr: str) -> Any:
"""Get the value of the attribute."""
...

def __setattr__(self, attr: str, value: Any) -> None:
"""Set the value of the attribute."""
...

def __new__(
cls, name: str, modality_specific_properties: Optional[dict[str, str]] = None
) -> Self:
name: str
target: str = field(init=False, repr=False)
attention_mask: str = field(init=False, repr=False)
mask: str = field(init=False, repr=False)
embedding: str = field(init=False, repr=False)
masked_embedding: str = field(init=False, repr=False)
ema_embedding: str = field(init=False, repr=False)
modality_specific_properties: Optional[dict[str, str]] = field(
default=None, repr=False
)

def __post_init__(self) -> None:
"""Initialize the modality with the name and properties."""
instance = super(Modality, cls).__new__(cls, name.lower())
properties = cls._default_properties.copy()
if modality_specific_properties is not None:
properties.update(modality_specific_properties)
instance._properties = properties

for property_name, format_string in instance._properties.items():
instance._set_property_as_attr(property_name, format_string)

return instance

@property
def value(self) -> str:
"""Return the name of the modality."""
return self.__str__()
self.name = self.name.lower()
self._properties = {}

for field_name in self.__dataclass_fields__:
if field_name not in ("name", "modality_specific_properties"):
field_value = f"{self.name}_{field_name}"
self._properties[field_name] = field_value
setattr(self, field_name, field_value)

if self.modality_specific_properties is not None:
for (
property_name,
format_string,
) in self.modality_specific_properties.items():
self.add_property(property_name, format_string)

@property
def properties(self) -> dict[str, str]:
"""Return the properties associated with the modality."""
return {name: getattr(self, name) for name in self._properties}
return self._properties

def add_property(self, name: str, format_string: str) -> None:
"""Add a new property to the modality.
Expand All @@ -92,49 +80,38 @@ def add_property(self, name: str, format_string: str) -> None:
placeholder that will be replaced with the name of the modality when the
property is accessed.
Warns
-----
UserWarning
If the property already exists for the modality. It will overwrite the
existing property.
Raises
------
ValueError
If the property already exists for the modality or if the format string is
invalid.
If `format_string` is invalid. A valid format string contains at least one
placeholder enclosed in curly braces.
"""
if name in self._properties:
raise ValueError(
warnings.warn(
f"Property '{name}' already exists for modality '{super().__str__()}'."
"Will overwrite the existing property.",
category=UserWarning,
stacklevel=2,
)
self._properties[name] = format_string
self._set_property_as_attr(name, format_string)

def _set_property_as_attr(self, name: str, format_string: str) -> None:
"""Set the property as an attribute of the modality."""
if not _is_format_string(format_string):
raise ValueError(
f"Invalid format string '{format_string}' for property "
f"'{name}' of modality '{super().__str__()}'."
)
setattr(self, name, format_string.format(self.value))

self._properties[name] = format_string.format(self.name)
setattr(self, name, self._properties[name])

def __str__(self) -> str:
"""Return the object as a string."""
return self.lower()

def __repr__(self) -> str:
"""Return the string representation of the modality."""
return f"<Modality: {self.upper()}>"

def __hash__(self) -> int:
"""Return the hash of the modality name and properties."""
return hash((self.value, tuple(self._properties.items())))

def __eq__(self, other: object) -> bool:
"""Check if two modality types are equal.
Two modality types are equal if they have the same name and properties.
"""
return isinstance(other, Modality) and (
(self.__str__() == other.__str__())
and (self._properties == other._properties)
)
return self.name.lower()


class ModalityRegistry:
Expand All @@ -146,16 +123,15 @@ class ModalityRegistry:
ensure that there is only one instance of the registry in the library.
"""

_instance = None
_instance: ClassVar[Any] = None
_modality_registry: dict[str, Modality] = {}

def __new__(cls) -> Self:
"""Create a new instance of the class if it does not exist."""
if cls._instance is None:
cls._instance = super(ModalityRegistry, cls).__new__(cls)
cls._instance._modality_registry = {} # type: ignore[attr-defined]
for modality in _default_supported_modalities:
cls._instance.register_modality(modality)
return cls._instance
cls._instance = super().__new__(cls)
cls._instance._modality_registry = {}
return cls._instance # type: ignore[no-any-return]

def register_modality(
self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
Expand All @@ -169,13 +145,19 @@ def register_modality(
modality_specific_properties : Optional[dict[str, str]], optional, default=None
Additional properties specific to the modality.
Raises
------
ValueError
If the modality already exists in the registry.
Warns
-----
UserWarning
If the modality already exists in the registry. It will overwrite the
existing modality.
"""
if name.lower() in self._modality_registry:
raise ValueError(f"Modality '{name}' already exists in the registry.")
warnings.warn(
f"Modality '{name}' already exists in the registry. Overwriting...",
category=UserWarning,
stacklevel=2,
)

name = name.lower()
modality = Modality(name, modality_specific_properties)
Expand All @@ -194,18 +176,21 @@ def add_default_property(self, name: str, format_string: str) -> None:
placeholder that will be replaced with the name of the modality when the
property is accessed.
Warns
-----
UserWarning
If the property already exists for the default properties. It will
overwrite the existing property.
Raises
------
ValueError
If the property already exists for the default properties or if the format
string is invalid.
If the format string is invalid. A valid format string contains at least one
placeholder enclosed in curly braces.
"""
for modality in self._modality_registry.values():
modality.add_property(name, format_string)

# add the property to the default properties for new modalities
Modality._default_properties[name.lower()] = format_string

def has_modality(self, name: str) -> bool:
"""Check if the modality exists in the registry.
Expand Down Expand Up @@ -234,7 +219,7 @@ def get_modality(self, name: str) -> Modality:
Modality
The modality object from the registry.
"""
return self._modality_registry[name.lower()] # type: ignore[index,return-value]
return self._modality_registry[name.lower()]

def get_modality_properties(self, name: str) -> dict[str, str]:
"""Get the properties of a modality from the registry.
Expand Down Expand Up @@ -264,7 +249,7 @@ def list_modalities(self) -> list[Modality]:
def __getattr__(self, name: str) -> Modality:
"""Access a modality as an attribute by its name."""
if name.lower() in self._modality_registry:
return self._modality_registry[name.lower()] # type: ignore[index,return-value]
return self._modality_registry[name.lower()]
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
Expand Down Expand Up @@ -292,3 +277,6 @@ def _is_format_string(string: str) -> bool:


Modalities = ModalityRegistry()

for modality in _DEFAULT_SUPPORTED_MODALITIES:
Modalities.register_modality(modality)
2 changes: 1 addition & 1 deletion mmlearn/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __getitem__(self, index: int) -> Example:
image, target = super().__getitem__(index)
example = Example(
{
Modalities.RGB: image,
Modalities.RGB.name: image,
Modalities.RGB.target: target,
EXAMPLE_INDEX_KEY: index,
}
Expand Down
4 changes: 2 additions & 2 deletions mmlearn/datasets/librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def __getitem__(self, idx: int) -> Example:

return Example(
{
Modalities.AUDIO: waveform,
Modalities.TEXT: transcript,
Modalities.AUDIO.name: waveform,
Modalities.TEXT.name: transcript,
EXAMPLE_INDEX_KEY: idx,
},
)
10 changes: 5 additions & 5 deletions mmlearn/datasets/llvip.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __getitem__(self, idx: int) -> Example:
rgb_image = PILImage.open(rgb_image_path).convert("RGB")
ir_image = PILImage.open(ir_image_path).convert("L")

sample = Example(
example = Example(
{
Modalities.RGB: self.transform(rgb_image),
Modalities.THERMAL: self.transform(ir_image),
Modalities.RGB.name: self.transform(rgb_image),
Modalities.THERMAL.name: self.transform(ir_image),
EXAMPLE_INDEX_KEY: idx,
},
)
Expand All @@ -85,11 +85,11 @@ def __getitem__(self, idx: int) -> Example:
.replace("train", "")
)
annot = self._get_bbox(annot_path)
sample["annotation"] = {
example["annotation"] = {
"bboxes": torch.from_numpy(annot["bboxes"]),
"labels": torch.from_numpy(annot["labels"]),
}
return sample
return example

def _get_bbox(self, filename: str) -> Dict[str, np.ndarray]:
"""Parse the XML file to get bounding boxes and labels.
Expand Down
Loading

0 comments on commit b291ea0

Please sign in to comment.