Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Min Required Tuning Memory #440

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/klog/v2"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand Down Expand Up @@ -107,7 +108,7 @@ func UnmarshalTrainingConfig(cm *corev1.ConfigMap) (*Config, *apis.FieldError) {
return &config, nil
}

func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap, modelName, methodLowerCase, sku string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
return err
Expand Down Expand Up @@ -136,13 +137,47 @@ func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
}
}

// TODO: Here we perform the tuning GPU Memory Checks!
fmt.Println(trainingArgsRaw)
// Validate GPU Memory Requirements for batch size of 1 using model and tuning method
errs := validateTuningParameters(modelName, methodLowerCase, sku)
if errs != nil {
return errs
}
}
}
return nil
}

func validateTuningParameters(modelName, methodLowerCase, sku string) *apis.FieldError {
skuHandler, err := utils.GetSKUHandler()
if err != nil {
return apis.ErrInvalidValue(fmt.Sprintf("Failed to get SKU handler: %v", err), "sku")
}

skuConfig, skuExists := skuHandler.GetGPUConfigs()[sku]
if !skuExists {
return apis.ErrInvalidValue(fmt.Sprintf("Unsupported SKU: '%s'", sku), "sku")
}
skuGPUMem := skuConfig.GPUMem

modelTuningConfig, modelExists := modelTuningConfigs[modelName]
if !modelExists {
//klog.Infof("Model '%s' hasn't been tested yet for fine-tuning. Proceed at your own risk.", modelName)
return nil
}

minGPURequired, methodExists := modelTuningConfig[methodLowerCase]
if !methodExists {
//klog.Infof("Tuning method '%s' for model '%s' hasn't been tested yet.", methodLowerCase, modelName)
return nil
}

if skuGPUMem < minGPURequired {
klog.Warningf("Insufficient GPU memory: For model '%s' with tuning method '%s', the SKU '%s' with %dGi GPU memory does not support even a batch size of 1 in testing. Proceed at your own risk.", modelName, methodLowerCase, sku, skuGPUMem)
return nil
}
return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
Expand Down Expand Up @@ -249,7 +284,7 @@ func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
return nil
}

func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, methodLowerCase string, configMapName string) (errs *apis.FieldError) {
func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace, methodLowerCase, sku string, configMapName string) (errs *apis.FieldError) {
var cm corev1.ConfigMap
if k8sclient.Client == nil {
errs = errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context"))
Expand All @@ -269,7 +304,10 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
errs = errs.Also(err)
}
if err := validateTrainingArgsViaConfigMap(&cm); err != nil {

if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if err := validateTrainingArgsViaConfigMap(&cm, string(r.Preset.Name), methodLowerCase, sku); err != nil {
errs = errs.Also(err)
}
}
Expand Down
11 changes: 11 additions & 0 deletions api/v1alpha1/tuning_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package v1alpha1

// Map Representing Minimum Per GPU Memory required for Batch Size of 1
// ModelName, TuningMethod, MinGPUMemory
var modelTuningConfigs = map[string]map[string]int{
"falcon-7b": {
//string(TuningMethodLora): 24,
string(TuningMethodQLora): 16,
},
// Add more configurations as needed
}
8 changes: 4 additions & 4 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
if w.Tuning != nil {
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Resource.validateCreateWithTuning(w.Tuning).ViaField("resource"),
w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
w.Tuning.validateCreate(ctx, w.Namespace, w.Resource.InstanceType).ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
Expand Down Expand Up @@ -131,7 +131,7 @@ func (r *AdapterSpec) validateCreateorUpdate() (errs *apis.FieldError) {
return errs
}

func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string, sku string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
Expand All @@ -148,11 +148,11 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
} else if methodLowerCase == string(TuningMethodQLora) {
defaultConfigMapTemplateName = DefaultQloraConfigMapTemplate
}
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, defaultConfigMapTemplateName); err != nil {
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, sku, defaultConfigMapTemplateName); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.Config); err != nil {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, sku, r.Config); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}
Expand Down
3 changes: 2 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,8 @@ func TestTuningSpecValidateCreate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE")
os.Setenv("CLOUD_PROVIDER", "azure") // Manually set for testing env, normally defined in helm chart
errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE", "Standard_NC6s_v3")
hasErrs := errs != nil

if hasErrs != tt.wantErr {
Expand Down
Loading