Skip to content

Commit

Permalink
mutox with new fs2
Browse files Browse the repository at this point in the history
  • Loading branch information
artyomko committed Dec 10, 2024
1 parent 91ce6da commit 76db1fc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions sonar/models/mutox/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional

import torch
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataType, Device
from torch import nn

Expand Down Expand Up @@ -38,4 +38,4 @@ class MutoxConfig:
input_size: int


mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")
mutox_archs = ConfigRegistry[MutoxConfig]()
26 changes: 13 additions & 13 deletions sonar/models/mutox/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import typing as tp
from typing import Any, Dict

from fairseq2.assets import asset_store, download_manager
from fairseq2.models.utils import ConfigLoader, ModelLoader
from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model

from .builder import create_mutox_model
from .classifier import MutoxClassifier, MutoxConfig, mutox_archs
Expand All @@ -23,22 +23,22 @@ def _base_mutox() -> MutoxConfig:


def convert_mutox_checkpoint(
checkpoint: tp.Mapping[str, tp.Any], config: MutoxConfig
) -> tp.Mapping[str, tp.Any]:
checkpoint: Dict[str, Any], config: MutoxConfig
) -> Dict[str, Any]:
new_dict = {}
for key in checkpoint:
if key.startswith("model_all."):
new_dict[key] = checkpoint[key]
return {"model": new_dict}


load_mutox_config = ConfigLoader[MutoxConfig](asset_store, mutox_archs)
load_mutox_config = StandardModelConfigLoader(family="mutox", config_kls=MutoxConfig, arch_configs=mutox_archs)


load_mutox_model = ModelLoader[MutoxClassifier, MutoxConfig](
asset_store,
download_manager,
load_mutox_config,
create_mutox_model,
convert_mutox_checkpoint,
load_mutox_model = StandardModelLoader(
config_loader=load_mutox_config,
factory=create_mutox_model,
checkpoint_converter=convert_mutox_checkpoint,
restrict_checkpoints=False,
)

load_model.register("mutox_classifier", load_mutox_model)

0 comments on commit 76db1fc

Please sign in to comment.