diff --git a/examples/score_testing_example.ipynb b/examples/score_testing_example.ipynb index cfc172e7..05a76e66 100644 --- a/examples/score_testing_example.ipynb +++ b/examples/score_testing_example.ipynb @@ -4,38 +4,618 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "***\n", - "# Implementing Score Testing\n", + "# Model Lifecycle : Build, Import, and Score Test Decision Tree Classifier Models\n", "\n", - "_**Note:** Before running this example, you will need to run an example that creates a model on a Viya server and copy the UUID. This will be used as a create_score_definition function parameter._" + "This notebook provides an example of implementing the entire model lifecycle using the HMEQ data set. Lines of code that must be modified by the user, such as directory paths or the host server are noted with the comment \"_Changes required by user._\".\n", + "\n", + "_**Note:** If you download only this notebook and not the rest of the repository, you must also download the hmeq.csv file and the HMEQPERF_1_Q1.csv file from the data folder in the examples directory. These files are used when executing this notebook example._\n", + "\n", + "_**Note:** This example has the option of utilizing CAS Gateway to run score testing quickly. This option is available for SAS Viya 2025.01 or later and if the user replaces False with True in the section noted by \"_Change to True if your Viya version is compatible with CAS Gateway_\"._\n", + "\n", + "Here are the steps shown in this notebook:\n", + "\n", + "1. Import, review, and preprocess HMEQ data for model training.\n", + "2. Build, train, and assess a Decision Tree Classifer Model.\n", + "3. Score the model and save the resulting JSON information.\n", + "4. Import the model and the associated JSON files into SAS Model Manager.\n", + "4. Score test the results either with or without CAS Gateway and display the results." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Python Package Imports" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "import requests\n", + "# Standard Library\n", + "from pathlib import Path\n", + "import warnings\n", + "from requests import HTTPError\n", "import sys\n", "\n", - "from sasctl._services.score_definitions import ScoreDefinitions as sd # Importing ScoreDefinitions service\n", - "from sasctl._services.score_execution import ScoreExecution as se # Importing ScoreExecution service" + "# Third Party\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n", + "from sklearn.metrics import classification_report, confusion_matrix\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "\n", + "# Application Specific\n", + "import sasctl.pzmm as pzmm\n", + "from sasctl import Session\n", + "from sasctl._services.score_definitions import ScoreDefinitions as sd\n", + "from sasctl._services.score_execution import ScoreExecution as se" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Global Package Options\n", + "pd.options.mode.chained_assignment = None # default=\"warn\"\n", + "plt.rc(\"font\", size=14)\n", + "# Ignore warnings from pandas about SWAT using a feature that will be depreciated soon\n", + "warnings.simplefilter(action=\"ignore\", category=FutureWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import and Review Data Set" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BADLOANMORTDUEVALUEREASONJOBYOJDEROGDELINQCLAGENINQCLNODEBTINC
002680046236.062711.0DebtConOffice17.00.00.0175.0750581.022.033.059934
102690074982.0126972.0DebtConOffice0.00.00.0315.8189110.023.038.325990
202690067144.092923.0DebtConOther16.00.00.089.1121731.017.032.791478
302690045763.073797.0DebtConOther23.0NaN0.0291.5916811.029.039.370858
4027000144901.0178093.0DebtConProfExe7.00.00.0331.1139720.034.040.566552
\n", + "
" + ], + "text/plain": [ + " BAD LOAN MORTDUE VALUE REASON JOB YOJ DEROG DELINQ \\\n", + "0 0 26800 46236.0 62711.0 DebtCon Office 17.0 0.0 0.0 \n", + "1 0 26900 74982.0 126972.0 DebtCon Office 0.0 0.0 0.0 \n", + "2 0 26900 67144.0 92923.0 DebtCon Other 16.0 0.0 0.0 \n", + "3 0 26900 45763.0 73797.0 DebtCon Other 23.0 NaN 0.0 \n", + "4 0 27000 144901.0 178093.0 DebtCon ProfExe 7.0 0.0 0.0 \n", + "\n", + " CLAGE NINQ CLNO DEBTINC \n", + "0 175.075058 1.0 22.0 33.059934 \n", + "1 315.818911 0.0 23.0 38.325990 \n", + "2 89.112173 1.0 17.0 32.791478 \n", + "3 291.591681 1.0 29.0 39.370858 \n", + "4 331.113972 0.0 34.0 40.566552 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hmeq_data = pd.read_csv(\"data/hmeq.csv\", sep= \",\") # Try \"data/hmeq.csv\" if this does not work\n", + "hmeq_data.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preprocess Data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "predictor_columns = [\"LOAN\", \"MORTDUE\", \"VALUE\", \"YOJ\", \"DEROG\", \"DELINQ\", \"CLAGE\", \"NINQ\", \"CLNO\", \"DEBTINC\"]\n", + "\n", + "target_column = \"BAD\"\n", + "x = hmeq_data[predictor_columns]\n", + "y = hmeq_data[target_column]\n", + "\n", + "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# For missing values, impute the data set's mean value\n", + "x_test.fillna(x_test.mean(), inplace=True)\n", + "x_train.fillna(x_train.mean(), inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create, Train, and Assess Model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "dtc = DecisionTreeClassifier(max_depth=7, min_samples_split=2, min_samples_leaf=2, max_leaf_nodes=500)\n", + "dtc = dtc.fit(x_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate the importance of a predictor \n", + "def sort_feature_importance(model, data):\n", + " features = {}\n", + " for importance, name in sorted(zip(model.feature_importances_, data.columns), reverse=True):\n", + " features[name] = str(np.round(importance*100, 2)) + \"%\"\n", + " return features" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DecisionTree
DEBTINC58.35%
DELINQ18.57%
CLAGE8.07%
DEROG4.86%
VALUE3.24%
YOJ2.78%
MORTDUE1.87%
CLNO1.2%
NINQ0.88%
LOAN0.17%
\n", + "
" + ], + "text/plain": [ + " DecisionTree\n", + "DEBTINC 58.35%\n", + "DELINQ 18.57%\n", + "CLAGE 8.07%\n", + "DEROG 4.86%\n", + "VALUE 3.24%\n", + "YOJ 2.78%\n", + "MORTDUE 1.87%\n", + "CLNO 1.2%\n", + "NINQ 0.88%\n", + "LOAN 0.17%" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Displays the percentage weight of the predictors\n", + "importances = pd.DataFrame.from_dict(sort_feature_importance(dtc, x_train), orient=\"index\").rename(columns={0: \"DecisionTree\"})\n", + "importances" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1427 14]\n", + " [ 272 75]]\n", + " precision recall f1-score support\n", + "\n", + " 0 0.84 0.99 0.91 1441\n", + " 1 0.84 0.22 0.34 347\n", + "\n", + " accuracy 0.84 1788\n", + " macro avg 0.84 0.60 0.63 1788\n", + "weighted avg 0.84 0.84 0.80 1788\n", + "\n", + "Decision Tree Model Accuracy = 84.0%\n" + ] + } + ], + "source": [ + "# Displays model score metrics\n", + "y_dtc_predict = dtc.predict(x_test)\n", + "y_dtc_proba = dtc.predict_proba(x_test)\n", + "print(confusion_matrix(y_test, y_dtc_predict))\n", + "print(classification_report(y_test, y_dtc_predict))\n", + "print(\"Decision Tree Model Accuracy = \" + str(np.round(dtc.score(x_test, y_test)*100,2)) + \"%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Register Model in SAS Model Manager with pzmm" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model DecisionTreeClassifier was successfully pickled and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/DecisionTreeClassifier.pickle.\n" + ] + } + ], + "source": [ + "# Output variables expected in SAS Model Manager. If a classification value is expected to be output, it should be the first metric.\n", + "score_metrics = [\"EM_CLASSIFICATION\", \"EM_EVENTPROBABILITY\"]\n", + "\n", + "# Path to where the model should be stored\n", + "path = Path.cwd() / \"data/hmeqModels/DecisionTreeClassifier\"\n", + "\n", + "# Serialize the models to a pickle format\n", + "pzmm.PickleModel.pickle_trained_model(\n", + " model_prefix=\"DecisionTreeClassifier\",\n", + " trained_model=dtc,\n", + " pickle_path=path,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputVar.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/inputVar.json\n", + "outputVar.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/outputVar.json\n", + "ModelProperties.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/ModelProperties.json\n", + "fileMetadata.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/fileMetadata.json\n" + ] + } + ], + "source": [ + "def write_json_files(data, predict, target, path, prefix): \n", + " # Write input variable mapping to a json file\n", + " pzmm.JSONFiles.write_var_json(input_data=data[predict], is_input=True, json_path=path)\n", + " \n", + " # Set output variables and assign an event threshold, then write output variable mapping\n", + " output_var = pd.DataFrame(columns=score_metrics, data=[[\"A\", 0.5]]) # data argument includes example expected types for outputs\n", + " pzmm.JSONFiles.write_var_json(output_var, is_input=False, json_path=path)\n", + " \n", + " # Write model properties to a json file\n", + " pzmm.JSONFiles.write_model_properties_json(\n", + " model_name=prefix, \n", + " target_variable=target, # Target variable to make predictions about (BAD in this case)\n", + " target_values=[\"1\", \"0\"], # Possible values for the target variable (1 or 0 for binary classification of BAD)\n", + " json_path=path, \n", + " model_desc=f\"Description for the {prefix} model.\",\n", + " model_algorithm=\"\",\n", + " modeler=\"sasdemo\",\n", + " )\n", + " \n", + " # Write model metadata to a json file so that SAS Model Manager can properly identify all model files\n", + " pzmm.JSONFiles.write_file_metadata_json(model_prefix=prefix, json_path=path)\n", + "\n", + "\n", + "write_json_files(hmeq_data, predictor_columns, target_column, path, \"DecisionTreeClassifier\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dmcas_fitstat.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/dmcas_fitstat.json\n", + "dmcas_roc.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/dmcas_roc.json\n", + "dmcas_lift.json was successfully written and saved to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/dmcas_lift.json\n" + ] + } + ], + "source": [ + "import getpass\n", + "def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path, prefix):\n", + " # Calculate train predictions\n", + " train_predict = model.predict(x_train)\n", + " train_proba = model.predict_proba(x_train)\n", + " \n", + " # Assign data to lists of actual and predicted values\n", + " train_data = pd.concat([y_train.reset_index(drop=True), pd.Series(train_predict), pd.Series(data=train_proba[:,1])], axis=1)\n", + " test_data = pd.concat([y_test.reset_index(drop=True), pd.Series(test_predict), pd.Series(data=test_proba[:,1])], axis=1)\n", + " \n", + " # Calculate the model statistics, ROC chart, and Lift chart; then write to json files\n", + " pzmm.JSONFiles.calculate_model_statistics(\n", + " target_value=1, \n", + " prob_value=0.5, \n", + " train_data=train_data, \n", + " test_data=test_data, \n", + " json_path=path\n", + " )\n", + "\n", + " full_training_data = pd.concat([y_train.reset_index(drop=True), x_train.reset_index(drop=True)], axis=1)\n", + " \n", + "username = getpass.getpass()\n", + "password = getpass.getpass()\n", + "host = \"demo.sas.com\" # Changes required by user\n", + "sess = Session(host, username, password, protocol=\"http\") # For TLS-enabled servers, change protocol value to \"https\"\n", + "conn = sess.as_swat() # Connect to SWAT through the sasctl authenticated connection\n", + "\n", + "test_predict = y_dtc_predict\n", + "test_proba = y_dtc_proba\n", + "\n", + "write_model_stats(x_train, y_train, test_predict, test_proba, y_test, dtc, path, \"DecisionTreeClassifier\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "UserWarning: Due to the ambiguity of the provided metrics and prediction return types, the score code assumes that a classification and the target event probability should be returned.\n", + " warn(\n", + "UserWarning: No project with the name or UUID HMEQModels was found.\n", + " warn(f\"No project with the name or UUID {project} was found.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model score code was written successfully to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier/score_DecisionTreeClassifier.py and uploaded to SAS Model Manager.\n", + "All model files were zipped to ~/python-sasctl/examples/data/hmeqModels/DecisionTreeClassifier.\n", + "A new project named HMEQModels was created.\n", + "Model was successfully imported into SAS Model Manager as DecisionTreeClassifier with the following UUID: 1481eec5-48a4-4f52-9baa-d44b3f62f9af.\n" + ] + } + ], + "source": [ + "pzmm.ImportModel.import_model(\n", + " model_files=path,\n", + " model_prefix=\"DecisionTreeClassifier\", # What is the model name?\n", + " project=\"HMEQModels\", # What is the project name?\n", + " input_data=x, # What does example input data look like?\n", + " predict_method=[dtc.predict_proba, [int, int]], # What is the predict method and what does it return?\n", + " score_metrics=score_metrics, # What are the output variables?\n", + " overwrite_model=True, # Overwrite the model if it already exists?\n", + " target_values=[\"0\", \"1\"], # What are the expected values of the target variable?\n", + " target_index=1, # What is the index of the target value in target_values?\n", + " model_file_name=\"DecisionTreeClassifier\" + \".pickle\", # How was the model file serialized?\n", + " missing_values=True # Does the data include missing values?\n", + ")\n", + "# Reinitialize the score_code variable when writing more than one model's score code\n", + "pzmm.ScoreCode.score_code = \"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Implementing Score Testing" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Creating the score definition for this model using the model UUID generated two steps before\n", "score_definition = sd.create_score_definition(\n", - " \"example_score_def_name\", # Name of the score_definition, which can be any string\n", - " \"model_id\", # Use Model UUID generated two steps before\n", - " \"table_name\", # Table name for input data, which must exist in host server or it will throw an HTTP error and prompt you to upload a data file\n", - " # True, # Uncomment 'True' if your Viya version is compatible with CAS Gateway\n", - ")" + " score_def_name=\"example_score_def_name\", # Name of the score_definition, which can be any string\n", + " model='DecisionTreeClassifier', # Can use model name, UUID, or dictionary representation of the model\n", + " table_name=\"HMEQPERF_1_Q1\", # Table name for input data\n", + " use_cas_gateway=False, # Change to True if your Viya version is compatible with CAS Gateway. \n", + " table_file='data/HMEQPERF_1_Q1.csv' # add the file path of HMEQPERF_1_Q1 if HMEQPERF_1_Q1 does not yet exist on the server. If the user doesn't need the file path argument, they can comment out this line completely.\n", + ")\n" ] }, { @@ -46,25 +626,84 @@ "source": [ "# Executing the score definition\n", "score_execution = se.create_score_execution(\n", - " score_definition.get(\"id\") # Score definition id created in the previous cell\n", + " score_definition.get(\"id\"), # Score definition id created in the previous cell\n", ")\n", "\n", "# Prints score_execution_id\n", "print(score_execution)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use this function call to wait until the score execution is finished, as get_score_execution_results will throw an error if it hasn't finished\n", + "se.poll_score_execution_state(score_execution)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The following lines print the output table with scoring results. Ensure that the use_cas_gateway argument is the same as it is in the score definition call.\n", + "score_results = se.get_score_execution_results(score_execution, use_cas_gateway=False)\n", + "score_results" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can view our scored model information within Model Manager under Projects -> Choose your model -> Scoring. \n", "***" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using the Score Testing Task\n", + "\n", + "The above commands can be run in a single function call, found in the tasks module of sasctl." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sasctl.tasks import score_model_with_cas\n", + "\n", + "score_model_with_cas(\n", + " score_def_name=\"score_definition_example\",\n", + " model='DecisionTreeClassifier',\n", + " table_name='HMEQPERF_1_Q1', # If this call is made before running the code above, the table_file argument must be included if the file is not yet on the server\n", + " use_cas_gateway=False # Change to True if your Viya version is compatible with CAS Gateway. \n", + ")" + ] } ], "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } }, "nbformat": 4, diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index 6146878e..448a28c5 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -83,15 +83,18 @@ def create_score_definition( else: object_descriptor_type = "sas.models.model.ds2" - model = cls._model_repository.get_model(model) + if cls._model_repository.is_uuid(model): + model_id = model + elif isinstance(model, dict) and "id" in model: + model_id = model["id"] + else: + model = cls._model_repository.get_model(model) + if not model: + raise HTTPError( + "This model may not exist in a project or the model may not exist at all." + ) + model_id = model["id"] - if not model: - raise HTTPError( - { - f"This model may not exist in a project or the model may not exist at all." - } - ) - model_id = model.id model_project_id = model.get("projectId") model_project_version_id = model.get("projectVersionId") model_name = model.get("name") diff --git a/src/sasctl/_services/score_execution.py b/src/sasctl/_services/score_execution.py index eeac893e..5d464da8 100644 --- a/src/sasctl/_services/score_execution.py +++ b/src/sasctl/_services/score_execution.py @@ -147,8 +147,7 @@ def poll_score_execution_state( @classmethod def get_score_execution_results( - cls, - score_execution: Union[dict, str], + cls, score_execution: Union[dict, str], use_cas_gateway: False ): """Generates an output table for the score_execution results. @@ -183,13 +182,13 @@ def get_score_execution_results( else: session = current_session() cas = session.as_swat() - response = cas.loadActionSet("gateway") - if not response: + if not use_cas_gateway: output_table = cls._no_gateway_get_results( server_name, library_name, table_name ) return output_table else: + cas.loadActionSet("gateway") gateway_code = f""" import pandas as pd import numpy as np @@ -235,12 +234,14 @@ def _no_gateway_get_results(cls, server_name, library_name, table_name): f"caslibs/{library_name}/" f"tables/{table_name}/columns?limit=10000" ) - columns = json_normalize(output_columns.json(), "items") - column_names = columns["names"].to_list() + columns = json_normalize(output_columns) + column_names = columns["name"].to_list() + + session = current_session() - output_rows = cls._services.get( - f"casRowSets/servers/{server_name}" - f"caslibs/{library_name}" + output_rows = session.get( + f"casRowSets/servers/{server_name}/" + f"caslibs/{library_name}/" f"tables/{table_name}/rows?limit=10000" ) output_table = pd.DataFrame( diff --git a/src/sasctl/tasks.py b/src/sasctl/tasks.py index 1da7d3fb..d466c10f 100644 --- a/src/sasctl/tasks.py +++ b/src/sasctl/tasks.py @@ -996,5 +996,5 @@ def score_model_with_cas( score_execution = se.create_score_execution(score_definition.id) score_execution_poll = se.poll_score_execution_state(score_execution) print(score_execution_poll) - score_results = se.get_score_execution_results(score_execution) + score_results = se.get_score_execution_results(score_execution, use_cas_gateway) return score_results diff --git a/tests/unit/test_score_definitions.py b/tests/unit/test_score_definitions.py index 8a75294a..d1210866 100644 --- a/tests/unit/test_score_definitions.py +++ b/tests/unit/test_score_definitions.py @@ -73,16 +73,13 @@ def test_create_score_definition(): model="12345", table_name="test_table", ) - # Valid model id but invalid table name with no table_file argument test case - get_model_mock = CustomMock( - json_info={ - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - }, - ) + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + } get_model.return_value = get_model_mock get_table.return_value = None with pytest.raises(HTTPError): @@ -107,9 +104,7 @@ def test_create_score_definition(): # Valid table_file argument that successfully creates a table test case get_table.return_value = None upload_file.return_value = RestObj - get_table_mock = CustomMock( - json_info={"tableName": "test_table"}, - ) + get_table_mock = {"tableName": "test_table"} get_table.return_value = get_table_mock response = sd.create_score_definition( score_def_name="test_create_sd", @@ -130,19 +125,17 @@ def test_create_score_definition(): assert response # Checking response with inputVariables in model elements - get_model_mock = CustomMock( - json_info={ - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - "inputVariables": [ - {"name": "first"}, - {"name": "second"}, - {"name": "third"}, - ], - }, - ) + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + "inputVariables": [ + {"name": "first"}, + {"name": "second"}, + {"name": "third"}, + ], + } get_model.return_value = get_model_mock get_table.return_value = get_table_mock response = sd.create_score_definition(