Skip to content

Commit

Permalink
feat(scheduler): account for number of model instances when scheduling (#6183)
Browse files Browse the repository at this point in the history

* setting the model runtime info in the scheduler

* linting

* proto cleanup

* improve nil pointer checks
  • Loading branch information
driev authored Jan 8, 2025
1 parent 8090437 commit 616cddc
Show file tree
Hide file tree
Showing 29 changed files with 1,750 additions and 1,669 deletions.
662 changes: 173 additions & 489 deletions apis/go/mlops/agent/agent.pb.go

Large diffs are not rendered by default.

2,309 changes: 1,315 additions & 994 deletions apis/go/mlops/scheduler/scheduler.pb.go

Large diffs are not rendered by default.

23 changes: 1 addition & 22 deletions apis/mlops/agent/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ message ModelEventMessage {
Event event = 5;
string message = 6;
uint64 availableMemoryBytes = 7;
ModelRuntimeInfo runtimeInfo = 8;
scheduler.ModelRuntimeInfo runtimeInfo = 8;
}

message ModelEventResponse {
Expand Down Expand Up @@ -93,29 +93,8 @@ message ModelOperationMessage {
message ModelVersion {
scheduler.Model model = 1;
uint32 version = 2;
ModelRuntimeInfo runtimeInfo = 3;
}

message ModelRuntimeInfo {
oneof modelRuntimeInfo {
MLServerModelSettings mlserver = 1;
TritonModelConfig triton = 2;
}
}

message MLServerModelSettings {
uint32 parallelWorkers = 1;
}

message TritonModelConfig {
repeated TritonCPU cpu = 1;
}

message TritonCPU {
uint32 instanceCount = 1;
}


// [END Messages]

// [START Services]
Expand Down
20 changes: 20 additions & 0 deletions apis/mlops/scheduler/scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ message ModelSpec {
optional string server = 6; // the particular model server to load the model. If unspecified will be chosen.
optional ExplainerSpec explainer = 7; // optional black box explainer details
repeated ParameterSpec parameters = 8; // parameters to load with model
optional ModelRuntimeInfo modelRuntimeInfo = 9; // model specific settings that are sent by the agent
}

message ParameterSpec {
Expand All @@ -58,6 +59,25 @@ message ExplainerSpec {
optional string pipelineRef = 3;
}

message ModelRuntimeInfo {
oneof modelRuntimeInfo {
MLServerModelSettings mlserver = 1;
TritonModelConfig triton = 2;
}
}

message MLServerModelSettings {
uint32 parallelWorkers = 1;
}

message TritonModelConfig {
repeated TritonCPU cpu = 1;
}

message TritonCPU {
uint32 instanceCount = 1;
}

message KubernetesMeta {
string namespace = 1;
int64 generation = 2;
Expand Down
6 changes: 3 additions & 3 deletions scheduler/pkg/agent/agent_debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
Name: "dummy_1_1",
},
ModelSpec: &pbs.ModelSpec{
Uri: "gs://dummy",
MemoryBytes: &mem,
Uri: "gs://dummy",
MemoryBytes: &mem,
ModelRuntimeInfo: getModelRuntimeInfo(1),
},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
)
g.Expect(err).To(BeNil())
Expand Down
4 changes: 3 additions & 1 deletion scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"google.golang.org/protobuf/encoding/protojson"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
"github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
seldontls "github.com/seldonio/seldon-core/components/tls/v2/pkg/tls"

Expand Down Expand Up @@ -683,6 +684,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// we do not care about model versions here
// model runtime info is retrieved from the existing version, so nil is passed here
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion(), nil)

unloaderFn := func() error {
Expand Down Expand Up @@ -751,7 +753,7 @@ func (c *Client) sendModelEventError(
func (c *Client) sendAgentEvent(
modelName string,
modelVersion uint32,
modelRuntimeInfo *agent.ModelRuntimeInfo,
modelRuntimeInfo *scheduler.ModelRuntimeInfo,
event agent.ModelEventMessage_Event,
) error {
// if the server is draining and the model load has succeeded, we need to "cancel"
Expand Down
26 changes: 8 additions & 18 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

Expand Down Expand Up @@ -66,8 +65,8 @@ func (f *FakeModelRepository) RemoveModelVersion(modelName string) error {
return nil
}

func (f *FakeModelRepository) GetModelRuntimeInfo(modelName string) (*pb.ModelRuntimeInfo, error) {
return &pb.ModelRuntimeInfo{ModelRuntimeInfo: &pb.ModelRuntimeInfo_Mlserver{Mlserver: &agent.MLServerModelSettings{ParallelWorkers: uint32(1)}}}, nil
func (f *FakeModelRepository) GetModelRuntimeInfo(modelName string) (*pbs.ModelRuntimeInfo, error) {
return &pbs.ModelRuntimeInfo{ModelRuntimeInfo: &pbs.ModelRuntimeInfo_Mlserver{Mlserver: &pbs.MLServerModelSettings{ParallelWorkers: uint32(1)}}}, nil
}

func (f *FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
Expand Down Expand Up @@ -254,16 +253,15 @@ func TestLoadModel(t *testing.T) {
models []string
replicaConfig *pb.ReplicaConfig
op *pb.ModelOperationMessage
modelConfig *pb.ModelRuntimeInfo
expectedAvailableMemory uint64
v2Status int
modelRepoErr error
success bool
autoscalingEnabled bool
}

smallMemory := uint64(500)
largeMemory := uint64(2000)
memory500 := uint64(500)
memory2000 := uint64(2000)

tests := []test{
{
Expand All @@ -276,13 +274,11 @@ func TestLoadModel(t *testing.T) {
Meta: &pbs.MetaData{
Name: "iris",
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &memory500, ModelRuntimeInfo: getModelRuntimeInfo(1)},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -297,14 +293,12 @@ func TestLoadModel(t *testing.T) {
Meta: &pbs.MetaData{
Name: "iris",
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &memory500, ModelRuntimeInfo: getModelRuntimeInfo(1)},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
AutoscalingEnabled: true,
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -320,13 +314,11 @@ func TestLoadModel(t *testing.T) {
Meta: &pbs.MetaData{
Name: "iris",
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &memory500, ModelRuntimeInfo: getModelRuntimeInfo(1)},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 400,
success: false,
Expand All @@ -341,13 +333,11 @@ func TestLoadModel(t *testing.T) {
Meta: &pbs.MetaData{
Name: "iris",
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &largeMemory},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &memory2000, ModelRuntimeInfo: getModelRuntimeInfo(1)},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 200,
success: false,
Expand Down
10 changes: 8 additions & 2 deletions scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"google.golang.org/protobuf/proto"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
"github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/agent/interfaces"
)
Expand Down Expand Up @@ -54,11 +55,16 @@ func isReady(service interfaces.DependencyServiceInterface, logger *log.Entry, m
return backoff.RetryNotify(readyToError, backoffWithMax, logFailure)
}

func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion, modelRuntimeInfo *agent.ModelRuntimeInfo) *agent.ModelVersion {
func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion, modelRuntimeInfo *scheduler.ModelRuntimeInfo) *agent.ModelVersion {
mv := proto.Clone(originalModelVersion)
mv.(*agent.ModelVersion).Model.Meta.Name = modelId
if modelRuntimeInfo != nil && modelRuntimeInfo.ModelRuntimeInfo != nil {
if mv.(*agent.ModelVersion).Model.ModelSpec == nil {
mv.(*agent.ModelVersion).Model.ModelSpec = &scheduler.ModelSpec{}
}
mv.(*agent.ModelVersion).Model.ModelSpec.ModelRuntimeInfo = modelRuntimeInfo
}
mv.(*agent.ModelVersion).Version = version
mv.(*agent.ModelVersion).RuntimeInfo = modelRuntimeInfo
return mv.(*agent.ModelVersion)
}

Expand Down
13 changes: 7 additions & 6 deletions scheduler/pkg/agent/model_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
"github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/util"
)
Expand Down Expand Up @@ -141,7 +142,7 @@ func (modelState *ModelState) getVersionsForAllModels() []*agent.ModelVersion {
mv := version.get()
versionedModelName := mv.Model.GetMeta().Name
originalModelName, originalModelVersion, _ := util.GetOrignalModelNameAndVersion(versionedModelName)
modelRuntimeInfo := mv.RuntimeInfo
modelRuntimeInfo := mv.Model.GetModelSpec().GetModelRuntimeInfo()
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv, modelRuntimeInfo))
}
return loadedModels
Expand All @@ -157,11 +158,11 @@ func (version *modelVersion) getVersionMemory() uint64 {
}

func getInstanceCount(version *modelVersion) uint64 {
switch version.versionInfo.RuntimeInfo.ModelRuntimeInfo.(type) {
case *agent.ModelRuntimeInfo_Mlserver:
return uint64(version.versionInfo.GetRuntimeInfo().GetMlserver().ParallelWorkers)
case *agent.ModelRuntimeInfo_Triton:
return uint64(version.versionInfo.GetRuntimeInfo().GetTriton().Cpu[0].InstanceCount)
switch version.get().GetModel().GetModelSpec().GetModelRuntimeInfo().ModelRuntimeInfo.(type) {
case *scheduler.ModelRuntimeInfo_Mlserver:
return uint64(version.get().GetModel().GetModelSpec().GetModelRuntimeInfo().GetMlserver().ParallelWorkers)
case *scheduler.ModelRuntimeInfo_Triton:
return uint64(version.get().GetModel().GetModelSpec().GetModelRuntimeInfo().GetTriton().Cpu[0].InstanceCount)
default:
return 1
}
Expand Down
Loading

0 comments on commit 616cddc

Please sign in to comment.