Skip to content

Commit

Permalink
Fix a bug in the usage of apply_model_overrides in UnionModel.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704338054
Change-Id: I3b75ea79f377eb6ca9cf318d00bac8374589ab55
  • Loading branch information
Sax Authors authored and copybara-github committed Dec 9, 2024
1 parent 561787f commit f6dc38a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion saxml/server/pax/union_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def load_state(
raise ValueError('No children in UnionModelParams')
self._models = []
for child in children:
child = child.apply_model_overrides(union_config.overrides)
child_inst = child()
child_inst.apply_model_overrides(union_config.overrides)
self._models.append(child_inst.create_model(self.primary_process_id))
return self._models[0].load_state(checkpoint_path, prng_key, precompile)

Expand Down
4 changes: 2 additions & 2 deletions saxml/server/servable_model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import abc
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Self, Tuple, Union

from absl import logging
import numpy as np
Expand Down Expand Up @@ -143,7 +143,7 @@ def sax_registration_name(cls) -> Optional[str]:
@classmethod
def apply_model_overrides(
cls, overrides: Dict[str, Any]
) -> type['ServableModelParams']:
) -> type[Self]:
"""Applies model config overrides received from Publish.
The default handling of overrides is as follows:
Expand Down

0 comments on commit f6dc38a

Please sign in to comment.