Skip to content

Commit

Permalink
[Serving] Fix warmup failure when using session_group. (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
shanshanpt authored and liutongxuan committed Nov 2, 2022
1 parent a7c37f9 commit 33a6f53
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
12 changes: 6 additions & 6 deletions serving/processor/serving/model_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,10 @@ Status LocalSessionInstance::Warmup(
int left_try_count = WARMUP_COUNT;
while (left_try_count > 0) {
if (warmup_session) {
s = warmup_session->LocalPredict(
s = warmup_session->Warmup(
call.request, call.response);
} else {
s = session_mgr_->LocalPredict(
s = session_mgr_->Warmup(
call.request, call.response);
}
if (!s.ok()) return s;
Expand Down Expand Up @@ -563,11 +563,11 @@ Status RemoteSessionInstance::Warmup(
int left_try_count = WARMUP_COUNT;
while (left_try_count > 0) {
if (warmup_session) {
s = warmup_session->LocalPredict(
call.request, call.response);
s = warmup_session->Warmup(
call.request, call.response, false);
} else {
s = session_mgr_->LocalPredict(
call.request, call.response);
s = session_mgr_->Warmup(
call.request, call.response, false);
}
if (!s.ok()) return s;

Expand Down
53 changes: 48 additions & 5 deletions serving/processor/serving/model_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ int ModelSession::GetServingSessionId() {
}

// Remote-path prediction entry point: asks GetServingSessionId() for the
// session id to use and forwards the request to InternalPredict with it.
Status ModelSession::Predict(Request& req, Response& resp) {
  const int chosen_id = GetServingSessionId();
  return InternalPredict(req, resp, chosen_id);
}

// Overload of Predict that runs on the caller-supplied session id instead of
// the one chosen by GetServingSessionId(). Forwards to InternalPredict.
Status ModelSession::Predict(Request& req, Response& resp, int sess_id) {
  Status outcome = InternalPredict(req, resp, sess_id);
  return outcome;
}

Status ModelSession::InternalPredict(Request& req, Response& resp,
int sess_id) {
if (is_local_) {
return Status(error::Code::INTERNAL,
"Local sparse storage, please use LocalPredict.");
Expand All @@ -278,17 +288,31 @@ Status ModelSession::Predict(Request& req, Response& resp) {
// TODO: which session selected to run on, add some policy here
status = session_group_->Run(run_options, req.inputs,
req.output_tensor_names, {}, &resp.outputs,
&run_metadata, GetServingSessionId());
&run_metadata, sess_id);
Tracer::GetTracer()->GenTimeline(run_metadata);
} else {
status = session_group_->Run(req.inputs, req.output_tensor_names,
{}, &resp.outputs, GetServingSessionId());
{}, &resp.outputs, sess_id);
}
--counter_;
return status;
}

Status ModelSession::LocalPredict(Request& req, Response& resp) {
// Local-path prediction entry point: asks GetServingSessionId() for the
// session id to use and forwards the request to InternalLocalPredict with it.
Status ModelSession::LocalPredict(Request& req, Response& resp) {
  const int chosen_id = GetServingSessionId();
  return InternalLocalPredict(req, resp, chosen_id);
}

// Overload of LocalPredict that runs on the caller-supplied session id
// instead of the one chosen by GetServingSessionId(). Forwards to
// InternalLocalPredict.
Status ModelSession::LocalPredict(Request& req, Response& resp, int sess_id) {
  Status outcome = InternalLocalPredict(req, resp, sess_id);
  return outcome;
}

Status ModelSession::InternalLocalPredict(Request& req,
Response& resp,
int sess_id) {
if (!is_local_) {
return Status(error::Code::INTERNAL,
"Remote sparse storage, please use Predict.");
Expand All @@ -302,16 +326,31 @@ Status ModelSession::LocalPredict(Request& req, Response& resp) {
// TODO: which session selected to run on, add some policy here
status = session_group_->Run(run_options, req.inputs,
req.output_tensor_names, {}, &resp.outputs,
&run_metadata, GetServingSessionId());
&run_metadata, sess_id);
Tracer::GetTracer()->GenTimeline(run_metadata);
} else {
status = session_group_->Run(req.inputs, req.output_tensor_names,
{}, &resp.outputs, GetServingSessionId());
{}, &resp.outputs, sess_id);
}
--counter_;
return status;
}

// Issues the warmup request once against every session id in
// [0, session_group_->GetSessionNum()), so that each session in the group is
// exercised rather than only the one the serving policy would pick.
//
// `local` selects the call path: LocalPredict when true, Predict otherwise.
// The same `req`/`resp` objects are passed to every run. Returns the first
// non-OK status encountered, or Status::OK() if every run succeeds.
Status ModelSession::Warmup(Request& req, Response& resp, bool local) {
  const int session_count = session_group_->GetSessionNum();
  int sess_id = 0;
  while (sess_id < session_count) {
    Status run_status = local ? LocalPredict(req, resp, sess_id)
                              : Predict(req, resp, sess_id);
    if (!run_status.ok()) {
      return run_status;
    }
    ++sess_id;
  }
  return Status::OK();
}

// Forwards the prediction request to the current serving_session_.
Status ModelSessionMgr::Predict(Request& req, Response& resp) {
  Status outcome = serving_session_->Predict(req, resp);
  return outcome;
}
Expand All @@ -320,6 +359,10 @@ Status ModelSessionMgr::LocalPredict(Request& req, Response& resp) {
return serving_session_->LocalPredict(req, resp);
}

// Forwards the warmup request to the current serving_session_; `local` is
// passed through unchanged to ModelSession::Warmup.
Status ModelSessionMgr::Warmup(Request& req, Response& resp, bool local) {
  Status outcome = serving_session_->Warmup(req, resp, local);
  return outcome;
}

Status ModelSessionMgr::CreateModelSession(
const Version& version, const char* ckpt_name,
IFeatureStoreMgr* sparse_storage, bool is_incr_ckpt,
Expand Down
6 changes: 6 additions & 0 deletions serving/processor/serving/model_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ struct ModelSession {
virtual ~ModelSession();

Status Predict(Request& req, Response& resp);
Status Predict(Request& req, Response& resp, int sess_id);
Status LocalPredict(Request& req, Response& resp);
Status LocalPredict(Request& req, Response& resp, int sess_id);
Version GetVersion() {return version_;}
void UpdateVersion(const Version& v) { version_ = v; }
Session* GetSession();
Status Warmup(Request& req, Response& resp, bool local=true);

SessionGroup* session_group_ = nullptr;
SelectSessionPolicy select_session_policy_ =
Expand All @@ -54,6 +57,8 @@ struct ModelSession {

private:
int GetServingSessionId();
Status InternalPredict(Request& req, Response& resp, int sess_id);
Status InternalLocalPredict(Request& req, Response& resp, int sess_id);
};

class ModelSessionMgr {
Expand All @@ -64,6 +69,7 @@ class ModelSessionMgr {

Status Predict(Request& req, Response& resp);
Status LocalPredict(Request& req, Response& resp);
Status Warmup(Request& req, Response& resp, bool local=true);

Status CreateModelSession(
const Version& version,
Expand Down

0 comments on commit 33a6f53

Please sign in to comment.