Skip to content

Commit

Permalink
apply_model_overrides should follow the base signature
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702894126
Change-Id: Ifc240e6dcd28b9b661ac791cb93940cf1467f89e
  • Loading branch information
alanwaketan authored and copybara-github committed Dec 5, 2024
1 parent cb393fb commit 561787f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions saxml/server/pax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,6 @@ pytype_strict_library(
"//third_party/py/praxis:pax_fiddle",
"//third_party/py/praxis:py_utils",
"//third_party/py/praxis:pytypes",
"//third_party/py/typing_extensions",
],
)
7 changes: 6 additions & 1 deletion saxml/server/pax/union_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from praxis import pytypes
from saxml.server.pax import servable_model
from saxml.server.pax import servable_model_params
from typing_extensions import override

PRNGKey = pytypes.PRNGKey
NestedMap = py_utils.NestedMap
Expand Down Expand Up @@ -69,10 +70,14 @@ def task(self) -> pax_fiddle.Config[base_task.BaseTask]:
def datasets(self) -> List[pax_fiddle.Config[base_input.BaseInput]]:
raise NotImplementedError('should not be called')

@override
@classmethod
def apply_model_overrides(cls, overrides: Dict[str, Any]) -> None:
def apply_model_overrides(
cls, overrides: Dict[str, Any]
) -> type['UnionModelParams']:
"""Delays the model overrides until child creation."""
cls.overrides = overrides
return cls


class UnionModel(servable_model.ServableModel):
Expand Down

0 comments on commit 561787f

Please sign in to comment.