Skip to content

Commit

Permalink
Teach admin server how to activate dormant model server based on repo…
Browse files Browse the repository at this point in the history
…rted stats.

PiperOrigin-RevId: 651227398
Change-Id: I79194979d25eb168388a885b51f4ba65a20b1904
  • Loading branch information
Sax Authors authored and copybara-github committed Jul 11, 2024
1 parent dd75f08 commit 777b80b
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 3 deletions.
21 changes: 21 additions & 0 deletions saxml/admin/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ go_library(
":protobuf",
":state",
":validator",
":wakerpolicy",
# unused internal flag dependency,
"//saxml/common:errors",
"//saxml/common:eventlog",
Expand Down Expand Up @@ -150,3 +151,23 @@ go_test(
"//saxml/protobuf:admin_go_proto_grpc",
],
)

go_library(
name = "wakerpolicy",
srcs = ["waker_policy.go"],
deps = [
":state",
"//saxml/common:naming",
],
)

go_test(
name = "wakerpolicy_test",
srcs = ["waker_policy_test.go"],
library = ":wakerpolicy",
deps = [
":state",
"//saxml/common:naming",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
18 changes: 18 additions & 0 deletions saxml/admin/mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"saxml/admin/protobuf"
"saxml/admin/state"
"saxml/admin/validator"
"saxml/admin/wakerpolicy"
"saxml/common/errors"
"saxml/common/eventlog"
"saxml/common/naming"
Expand Down Expand Up @@ -889,6 +890,23 @@ func (m *Mgr) Refresh(ctx context.Context) {
// unpublished before the ComputeAssignment call available for use
// again.
m.freeUnpublishedNames(pendingUnpublished)

// Walk though all known model servers, wake up servers based on waking policy to balance the
// load within servers in the same cell.
wakerPolicy := wakerpolicy.NewWakerPolicy()
m.mu.RLock()
for addr, state := range m.modelets {
wakerPolicy.AddServerStatus(string(addr), state)
}
m.mu.RUnlock()

candidates := wakerPolicy.Decide()

m.mu.RLock()
for addr := range candidates {
m.modelets[modeletAddr(rune(addr))].WakeUp(ctx)
}
m.mu.RUnlock()
}

// Restore restores the manager state from its backing store.
Expand Down
28 changes: 25 additions & 3 deletions saxml/admin/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ var (
// Various RPC timeout thresholds.
dialTimeout = time.Second * 10
getStatusTimeout = time.Second * 10
wakeUpTimeout = time.Second * 10
)

// SetOptionsForTesting updates refreshPeriod so that during tests, the state issues more frequent
Expand Down Expand Up @@ -123,9 +124,10 @@ type ModelWithStatus struct {
Info ModelInfo
}

// ServerStatus tracks a model server usability state.
// ServerStatus tracks a model server usability state and early rejection stats.
type ServerStatus struct {
IsDormant bool
IsDormant bool
EarlyRejectionErrorsPerSecond [2]float32 // perserve most recent two values, [0] is the latest.
}

func (m *ModelWithStatus) clone() *ModelWithStatus {
Expand All @@ -138,6 +140,7 @@ const (
load actionKind = iota
unload
update
wakeup
)

func (k actionKind) String() string {
Expand All @@ -148,12 +151,14 @@ func (k actionKind) String() string {
return "unload"
case update:
return "update"
case wakeup:
return "wakeup"
default:
return "invalid"
}
}

// action represents an administrative method the model server can perform on a model.
// action represents an administrative method the model server can perform on a model or the server.
type action struct {
kind actionKind
ctx context.Context
Expand Down Expand Up @@ -343,6 +348,13 @@ func (s *State) act(a *action) {
s.wanted[a.fullName] = a.model
s.mu.Unlock()
}
case wakeup:
log.V(0).Infof("Waking up server %v", s.Addr)
ctx, cancel := context.WithTimeout(a.ctx, wakeUpTimeout)
defer cancel()
if _, err := s.client.WakeUp(ctx, &mpb.WakeUpRequest{}); err != nil {
log.WarningContextf(ctx, "Failed to wake up server %v (%v)", s.Addr, err)
}
default:
log.Warningf("Unknown action type %T", a)
}
Expand Down Expand Up @@ -396,10 +408,20 @@ func (s *State) getStatus(ctx context.Context) (map[naming.ModelFullName]*ModelI
if serverStatus := res.GetServerStatus(); serverStatus != nil {
s.lastReportedStatus.IsDormant =
serverStatus.GetState() == mpb.GetStatusResponse_ServerStatus_DORMANT
errorRates := s.lastReportedStatus.EarlyRejectionErrorsPerSecond
errorRates[1] = errorRates[0] // [1] always perserves the historical value.
errorRates[0] = serverStatus.GetStats().GetEarlyRejectionErrorsPerSecond()
}
return seen, nil
}

// WakeUp wakes up a dormant server.
func (s *State) WakeUp(ctx context.Context) {
s.mu.Lock()
defer s.mu.Unlock()
s.queue <- &action{wakeup, ctx, naming.ModelFullName{}, nil, nil}
}

// initialize sets wanted and seen models of a just created State instance from a running server.
func (s *State) initialize(ctx context.Context, modelFinder ModelFinder) error {
seen, err := s.getStatus(ctx)
Expand Down
114 changes: 114 additions & 0 deletions saxml/admin/waker_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package wakerpolicy decides which dormant servers should be woken up based on the current model server status.
package wakerpolicy

import (
"fmt"

"saxml/admin/state"
"saxml/common/naming"
)

// serverAddr represents a model server address. E.g., 1.2.3.4:14001.
type serverAddr string

// modelFullName identifies a model in the form of /sax/<cell>/<model>.
type modelName = naming.ModelFullName

// ServerStatusFeeder provides the current model server status.
type ServerStatusFeeder interface {
SeenModels() map[naming.ModelFullName]*state.ModelWithStatus
LastReportedServerStatus() state.ServerStatus
}

type dormantStats struct {
totalServerCount int
dormantServerCount int
candidateToWakeUp serverAddr
sumErrorsRateChange float32
}

// WakerPolicy models per-model load stats and decides which dormant servers should be woken up.
type WakerPolicy struct {
perModelDormantServerStats map[modelName]*dormantStats
}

// NewWakerPolicy creates a new WakerPolicy instance.
func NewWakerPolicy() *WakerPolicy {
return &WakerPolicy{
perModelDormantServerStats: make(map[modelName]*dormantStats),
}
}

// AddServerStatus adds a model server status to the policy module for consideration.
func (w *WakerPolicy) AddServerStatus(server string, statusFeeder ServerStatusFeeder) {
addr := serverAddr(server)
lastReported := statusFeeder.LastReportedServerStatus()
for name := range statusFeeder.SeenModels() {
if w.perModelDormantServerStats[name] == nil {
w.perModelDormantServerStats[name] = &dormantStats{}
}
stats := w.perModelDormantServerStats[name]
stats.totalServerCount++
if lastReported.IsDormant {
stats.dormantServerCount++
stats.candidateToWakeUp = max(stats.candidateToWakeUp, addr)
hist := lastReported.EarlyRejectionErrorsPerSecond
stats.sumErrorsRateChange = hist[0] - hist[1]
}
}
}

// Decide returns a list of server addresses that should be woken up.
func (w *WakerPolicy) Decide() []string {
// Policy for wake-up a candidate server for a model: "the first derivative of rejection rate on
// dormant server > the number of active servers."
//
// On the one hand, this policy is based on a simple intuition "if we see increasing rejection
// error rate per-dormant-server, this suggests we could use one more server." Therefore on the
// opposite side, if we see decreasing rejection error rate, this suggests we have sufficient
// active servers and it's no-ops.
//
// On the other hand, the "# active servers" is based on the fact that: "when k client queries and
// there is k active servers, each model server should see ≤k rejection error because of retry."
// So if we see error rate change >k per dormant server, we know there is more traffic than
// current active servers may handle.
//
// The goal is to save resources by minimizing amount of dormant server waking-up, for example,
// - No-ops when there is no early rejection error to a model.
// - No-ops if previous server is not finished waking-up yet.
// - No-ops if early rejection error rate is not increasing.
candidates := map[serverAddr]bool{}
for mname, stats := range w.perModelDormantServerStats {
logMsg := fmt.Sprintf("Running server wake policy against model: %v with stats: %v \n", mname, stats)
if stats.dormantServerCount > 0 {
averageErrorRateChange := stats.sumErrorsRateChange / float32(stats.dormantServerCount)
activeServersCount := float32(stats.totalServerCount - stats.dormantServerCount)
shouldWakeUp := averageErrorRateChange > activeServersCount
if shouldWakeUp {
addr := stats.candidateToWakeUp
candidates[addr] = true
}
logMsg += fmt.Sprintf("Decision: %v as seeing avg_error_rate_delta: %v, active_server_count: %v",
shouldWakeUp, averageErrorRateChange, activeServersCount)
}
}
servers := []string{}
for cand := range candidates {
servers = append(servers, string(cand))
}
return servers
}
Loading

0 comments on commit 777b80b

Please sign in to comment.