Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
paradin committed Nov 30, 2024
1 parent 467f06b commit 03172c8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
45 changes: 29 additions & 16 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,30 +1180,43 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:

# Receive model infos of workers
@log_async(logger=logger)
async def sync_models(self, worker_address: str, model_desc: Dict[str, Dict[str, Any]]): # model_uid : ModelDescription{"address"}
async def sync_models(
self, worker_address: str, model_desc: Dict[str, Dict[str, Any]]
): # model_uid : ModelDescription{"address"}
for replica_model_uid, desc_dict in model_desc.items():
# Rebuild self._replica_model_uid_to_worker
if replica_model_uid in self._replica_model_uid_to_worker:
continue

model_name = desc_dict["model_name"] if "model_name" in desc_dict else ""
model_version = desc_dict["model_version"] if "model_version" in desc_dict else ""
logger.debug(f"Receive model replica: {replica_model_uid} {worker_address} {model_name}")

assert (worker_address in self._worker_address_to_worker), f"Worker {worker_address} not exists when sync_models"

self._replica_model_uid_to_worker[replica_model_uid] = self._worker_address_to_worker[worker_address]


model_version = (
desc_dict["model_version"] if "model_version" in desc_dict else ""
)
logger.debug(
f"Receive model replica: {replica_model_uid} {worker_address} {model_name}"
)

assert (
worker_address in self._worker_address_to_worker
), f"Worker {worker_address} not exists when sync_models"

self._replica_model_uid_to_worker[
replica_model_uid
] = self._worker_address_to_worker[worker_address]

# Rebuild self._model_uid_to_replica_info
model_uid, rep_id = parse_replica_model_uid(replica_model_uid)
replica = rep_id+1
replica = rep_id + 1
if model_uid not in self._model_uid_to_replica_info:
self._model_uid_to_replica_info[model_uid] = ReplicaInfo(replica=replica, scheduler=itertools.cycle(range(replica)))
self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
replica=replica, scheduler=itertools.cycle(range(replica))
)
else:
if replica > self._model_uid_to_replica_info[model_uid].replica:
self._model_uid_to_replica_info[model_uid] = ReplicaInfo(replica=replica, scheduler=itertools.cycle(range(replica)))

self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
replica=replica, scheduler=itertools.cycle(range(replica))
)

# Rebuild self._status_guard_ref
instance_info = InstanceInfo(
model_name=model_name,
Expand All @@ -1213,9 +1226,9 @@ async def sync_models(self, worker_address: str, model_desc: Dict[str, Dict[str,
replica=replica,
status=LaunchStatus.READY.name,
instance_created_ts=int(time.time()),
)
)
await self._status_guard_ref.set_instance_info(model_uid, instance_info)

def is_local_deployment(self) -> bool:
# TODO: temporary.
return (
Expand Down
10 changes: 6 additions & 4 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,17 +335,19 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
# Newly started (or restarted), has no model, notify supervisor
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

# Reconnect to Newly started supervisor, has running models
if add_worker and len(self._model_uid_to_model) > 0:
# Reconnect to Newly started supervisor, notify supervisor
await self._supervisor_ref.add_worker(self.address)
# Sync replical model infos
# Sync replica model infos
running_models = {}
running_models.update(await self.list_models())
await self._supervisor_ref.sync_models(self.address, running_models)
logger.info(f"Connected to supervisor as a old worker with {len(running_models)} models")

logger.info(
f"Connected to supervisor as a old worker with {len(running_models)} models"
)

self._status_guard_ref = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.default_uid()
)
Expand Down

0 comments on commit 03172c8

Please sign in to comment.