diff --git a/utils/validate_data.py b/utils/validate_data.py index 402e724..2632eaf 100644 --- a/utils/validate_data.py +++ b/utils/validate_data.py @@ -34,6 +34,8 @@ "open-mistral-7b": 5720, "open-mixtral-8x7b": 2966, "open-mixtral-8x22b": 1007, + "mistral-large-latest": 567, + 'open-mistral-nemo': 3337, } MIN_NUM_JSONL_LINES = 10 @@ -104,7 +106,11 @@ def get_train_stats( model_id = "open-mixtral-8x7b" elif params_config["dim"] == 6144: model_id = "open-mixtral-8x22b" - else: + elif params_config["dim"] == 12288: + model_id = "mistral-large-latest" + elif params_config["dim"] == 5120: + model_id = "open-mistral-nemo" + else: raise ValueError("Provided model folder seems incorrect.") else: model_id = train_args.model_id_or_path