diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py
index 524a386..9fcdb3a 100644
--- a/sonar/models/mutox/classifier.py
+++ b/sonar/models/mutox/classifier.py
@@ -8,7 +8,7 @@
 from typing import Optional
 
 import torch
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.config_registry import ConfigRegistry
 from fairseq2.typing import DataType, Device
 from torch import nn
 
@@ -38,4 +38,4 @@ class MutoxConfig:
     input_size: int
 
 
-mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")
+mutox_archs = ConfigRegistry[MutoxConfig]()
diff --git a/sonar/models/mutox/loader.py b/sonar/models/mutox/loader.py
index 20d764d..b6ce64a 100644
--- a/sonar/models/mutox/loader.py
+++ b/sonar/models/mutox/loader.py
@@ -4,10 +4,10 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-import typing as tp
+from typing import Any, Dict
 
-from fairseq2.assets import asset_store, download_manager
-from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.config_loader import StandardModelConfigLoader
+from fairseq2.models.loader import StandardModelLoader, load_model
 
 from .builder import create_mutox_model
 from .classifier import MutoxClassifier, MutoxConfig, mutox_archs
@@ -23,8 +23,8 @@ def _base_mutox() -> MutoxConfig:
 
 
 def convert_mutox_checkpoint(
-    checkpoint: tp.Mapping[str, tp.Any], config: MutoxConfig
-) -> tp.Mapping[str, tp.Any]:
+    checkpoint: Dict[str, Any], config: MutoxConfig
+) -> Dict[str, Any]:
     new_dict = {}
     for key in checkpoint:
         if key.startswith("model_all."):
@@ -32,13 +32,13 @@
     return {"model": new_dict}
 
 
-load_mutox_config = ConfigLoader[MutoxConfig](asset_store, mutox_archs)
+load_mutox_config = StandardModelConfigLoader(family="mutox", config_kls=MutoxConfig, arch_configs=mutox_archs)
 
-
-load_mutox_model = ModelLoader[MutoxClassifier, MutoxConfig](
-    asset_store,
-    download_manager,
-    load_mutox_config,
-    create_mutox_model,
-    convert_mutox_checkpoint,
+load_mutox_model = StandardModelLoader(
+    config_loader=load_mutox_config,
+    factory=create_mutox_model,
+    checkpoint_converter=convert_mutox_checkpoint,
+    restrict_checkpoints=False,
 )
+
+load_model.register("mutox_classifier", load_mutox_model)