Skip to content

Commit

Permalink
Deprecate modelet addresses in List responses and add num_active_repl…
Browse files Browse the repository at this point in the history
…icas.

PiperOrigin-RevId: 581459169
Change-Id: I3498b5a2b579a4f358bae75d3f32a0e6c3355c7c
  • Loading branch information
jiawenhao authored and copybara-github committed Nov 11, 2023
1 parent e3bc48f commit 6ce5b07
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 69 deletions.
39 changes: 20 additions & 19 deletions saxml/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,32 +180,33 @@ func (s *Server) List(ctx context.Context, in *pb.ListRequest) (*pb.ListResponse
return &pb.ListResponse{PublishedModels: s.Mgr.ListAll()}, nil
}

func (s *Server) findAddresses(ctx context.Context, modelFullName string) ([]*pb.JoinedModelServer, error) {
if err := validator.ValidateModelFullName(modelFullName, s.saxCell); err != nil {
return nil, err
}
fullName, err := naming.NewModelFullName(modelFullName)
if err != nil {
return nil, err
}
func (s *Server) locate(ctx context.Context, modelFullName string) ([]*pb.JoinedModelServer, error) {
// Locate joined model servers for one model specifically asked about.
if modelFullName != "" {
if err := validator.ValidateModelFullName(modelFullName, s.saxCell); err != nil {
return nil, err
}
fullName, err := naming.NewModelFullName(modelFullName)
if err != nil {
return nil, err
}

addrs, err := s.Mgr.FindAddresses(fullName)
if err != nil {
return nil, err
pubModel, err := s.Mgr.List(fullName)
if err != nil {
return nil, err
}
addrs := pubModel.GetModeletAddresses()
return s.Mgr.LocateSome(addrs)
}
return s.Mgr.LocateSome(addrs)

// List all joined model servers if no model is specifically asked about.
return s.Mgr.LocateAll()
}

func (s *Server) Stats(ctx context.Context, in *pb.StatsRequest) (*pb.StatsResponse, error) {
modelFullName := in.GetModelId()

var servers []*pb.JoinedModelServer
var err error
if modelFullName == "" {
servers, err = s.Mgr.LocateAll()
} else {
servers, err = s.findAddresses(ctx, modelFullName)
}
servers, err := s.locate(ctx, modelFullName)
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions saxml/admin/admin_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ func (s *Server) handleModel(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("Failed to list model %q: %v", modelFullName.ModelFullName(), err), http.StatusInternalServerError)
return
}
addrs, err := s.Mgr.FindAddresses(modelFullName)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to find model %q: %v", modelFullName.ModelFullName(), err), http.StatusInternalServerError)
return
}
addrs := model.GetModeletAddresses()
servers, err := s.Mgr.LocateSome(addrs)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to locate servers %v: %v", addrs, err), http.StatusInternalServerError)
Expand Down
24 changes: 6 additions & 18 deletions saxml/admin/mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,16 @@ func (m *Mgr) Unpublish(fullName modelFullName) error {
}

func (m *Mgr) makePublishedModelLocked(fullName modelFullName, model *apb.Model) *apb.PublishedModel {
addrs := []string{}
for _, addr := range m.assignment[fullName] {
addrs = append(addrs, string(addr))
}
cloned := proto.Clone(model).(*apb.Model)
// Clean Uuid field to not expose it to users.
cloned.Uuid = nil
return &apb.PublishedModel{
Model: cloned,
NumActiveReplicas: int32(len(m.assignment[fullName])),
Model: cloned,
ModeletAddresses: addrs,
}
}

Expand Down Expand Up @@ -288,22 +292,6 @@ func (m *Mgr) FindModel(fullName modelFullName) *apb.Model {
return nil
}

// FindAddresses reports, as plain strings, the addresses recorded in
// m.assignment for the model identified by fullName.
//
// The lookup runs under m.mu's read lock. If fullName is not a key of
// m.models, the returned error wraps errors.ErrNotFound. A known model with
// no assignment yields an empty, non-nil slice.
func (m *Mgr) FindAddresses(fullName modelFullName) ([]string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	_, known := m.models[fullName]
	if !known {
		return nil, fmt.Errorf("model %s not found: %w", fullName, errors.ErrNotFound)
	}

	assigned := m.assignment[fullName]
	result := make([]string, 0, len(assigned))
	for _, a := range assigned {
		result = append(result, string(a))
	}
	return result, nil
}

// Join lets one model server join from an address.
func (m *Mgr) Join(ctx context.Context, addr, debugAddr string, specs *apb.ModelServer) error {
maddr := modeletAddr(addr)
Expand Down
18 changes: 4 additions & 14 deletions saxml/bin/saxutil_cmd_admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,22 +311,12 @@ func (c *ListCmd) handleSaxModel(ctx context.Context, modelFullName naming.Model
model := publishedModel.GetModel()

// Print out list results in tables.
numActiveReplicas := int(publishedModel.GetNumActiveReplicas())
if numActiveReplicas == 0 {
// Maintain compatibility with old admin server binaries.
// TODO(jiawenhao): Remove when most/all admin servers are new.
numActiveReplicas = len(publishedModel.GetModeletAddresses())
}
if c.modelDetails {
// Extra logic: display one random address if there are multiple.
randomSelectedAddress := randomSelectAddress(publishedModel.GetModeletAddresses())
table := NewResultRenderer(os.Stdout, c.outputCsv)
table.SetHeader([]string{"Model", "Model Path", "Checkpoint Path", "Max Replicas", "Active Replicas"})
table.Append([]string{
modelFullName.ModelName(),
model.GetModelPath(),
model.GetCheckpointPath(),
strconv.Itoa(int(model.GetRequestedNumReplicas())),
strconv.Itoa(numActiveReplicas),
})
table.SetHeader([]string{"Model", "Model Path", "Checkpoint Path", "# of Replicas", "(Selected) ReplicaAddress"})
table.Append([]string{modelFullName.ModelName(), model.GetModelPath(), model.GetCheckpointPath(), strconv.Itoa(len(publishedModel.GetModeletAddresses())), randomSelectedAddress})
table.Render()
}

Expand Down
7 changes: 1 addition & 6 deletions saxml/client/cc/sax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,12 +1024,7 @@ absl::Status List(const AdminOptions& options, absl::string_view id,
model->model = one_model.model_path();
model->ckpt = one_model.checkpoint_path();
model->max_replicas = one_model.requested_num_replicas();
model->active_replicas = pub_model.num_active_replicas();
if (model->active_replicas == 0) {
// Maintain compatibility with old admin server binaries.
// TODO(jiawenhao): Remove when most/all admin servers are new.
model->active_replicas = pub_model.modelet_addresses_size();
}
model->active_replicas = pub_model.modelet_addresses_size();
model->overrides = std::map<std::string, std::string>(
one_model.overrides().begin(), one_model.overrides().end());

Expand Down
6 changes: 3 additions & 3 deletions saxml/client/go/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ type Admin struct {
client pbgrpc.AdminClient

// addrs maintains an addrReplica for every model seen by this
// admin through FindAddresses(). Each addrReplica is the set of
// admin through FindAddress(). Each addrReplica is the set of
// model server addresses for the model. The set is lazily
// replicated from the admin server through WatchAddresses().
addrs map[string]*addrReplica
}

// TODO(zhifengc): Consider abstracting out module providing a
// resettable sync.Once interface, which can be tested separately.
// TODO(zhifengc): consider abstracting out a module providing a
// resettable sync.Once interface, which can be tested separately.
func (a *Admin) getAdminClient(ctx context.Context) (pbgrpc.AdminClient, error) {
a.mu.Lock()
// A quick check if a.client is established already.
Expand Down
2 changes: 1 addition & 1 deletion saxml/common/platform/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func modelTableRows(models []*pb.PublishedModel) []tmplModelTableRow {
path := model.GetModel().GetModelPath()
ckpt := model.GetModel().GetCheckpointPath()
requested := fmt.Sprintf("%v", model.GetModel().GetRequestedNumReplicas())
assigned := fmt.Sprintf("%v", model.GetNumActiveReplicas())
assigned := fmt.Sprintf("%v", len(model.GetModeletAddresses()))
items = append(items, tmplModelTableRow{id, requested, assigned, path, ckpt})
}
return items
Expand Down
2 changes: 1 addition & 1 deletion saxml/common/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *stubAdminServer) List(ctx context.Context, in *apb.ListRequest) (*apb.L
RequestedNumReplicas: 1,
Overrides: map[string]string{"foo": "bar"},
},
NumActiveReplicas: int32(len(addresses)),
ModeletAddresses: []string{addresses[0]},
},
},
}
Expand Down
3 changes: 1 addition & 2 deletions saxml/protobuf/admin.proto
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ message Model {
// The state of a published model.
message PublishedModel {
Model model = 1;
repeated string modelet_addresses = 2 [deprecated = true];
int32 num_active_replicas = 3;
repeated string modelet_addresses = 2;
}

// The capabilities of a model server.
Expand Down

0 comments on commit 6ce5b07

Please sign in to comment.