diff --git a/src/fairseq2/assets/__init__.py b/src/fairseq2/assets/__init__.py index a32315ed6..6df7d0bc3 100644 --- a/src/fairseq2/assets/__init__.py +++ b/src/fairseq2/assets/__init__.py @@ -26,6 +26,9 @@ from fairseq2.assets.metadata_provider import ( FileAssetMetadataProvider as FileAssetMetadataProvider, ) +from fairseq2.assets.metadata_provider import ( + InProcAssetMetadataProvider as InProcAssetMetadataProvider, +) from fairseq2.assets.store import AssetStore as AssetStore from fairseq2.assets.store import ProviderBackedAssetStore as ProviderBackedAssetStore from fairseq2.assets.store import asset_store as asset_store diff --git a/src/fairseq2/assets/metadata_provider.py b/src/fairseq2/assets/metadata_provider.py index 6a9e16550..27174cf62 100644 --- a/src/fairseq2/assets/metadata_provider.py +++ b/src/fairseq2/assets/metadata_provider.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, final +from typing import Any, Dict, Optional, Sequence, final import yaml from typing_extensions import NoReturn @@ -132,6 +132,47 @@ def clear_cache(self) -> None: self._cache = None +@final +class InProcAssetMetadataProvider(AssetMetadataProvider): + """Provides asset metadata stored in-memory.""" + + _metadata: Dict[str, Dict[str, Any]] + + def __init__(self, metadata: Sequence[Dict[str, Any]]) -> None: + self._metadata = {} + + for idx, m in enumerate(metadata): + try: + name = m["name"] + except KeyError: + raise AssetMetadataError( + f"The asset metadata at index {idx} in `metadata` does not have a name." + ) + + if not isinstance(name, str): + raise AssetMetadataError( + f"The asset metadata at index {idx} in `metadata` has an invalid name." + ) + + if name in self._metadata: + raise AssetMetadataError(f"Two assets have the same name '{name}'.") + + self._metadata[name] = m + + @finaloverride + def get_metadata(self, name: str) -> Dict[str, Any]: + try: + return deepcopy(self._metadata[name]) + except KeyError: + raise AssetNotFoundError( + f"An asset with the name '{name}' cannot be found." + ) + + @finaloverride + def clear_cache(self) -> None: + pass + + class AssetNotFoundError(AssetError): """Raised when an asset cannot be found."""