Skip to content

Commit

Permalink
Merge pull request #208 from sassoftware/score_update
Browse files Browse the repository at this point in the history
Updates to score code and model card example
  • Loading branch information
djm21 authored Dec 24, 2024
2 parents d945855 + 18c7c64 commit b936354
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 6 deletions.
16 changes: 15 additions & 1 deletion examples/pzmm_generate_complete_model_card.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,8 @@
" df.columns = df.columns.str.replace(' ', '')\n",
" df.columns = df.columns.str.replace('-', '_')\n",
" df = df.drop(['Sex_Male'], axis=1)\n",
" df = pd.concat([df, cat_vals], axis=1).drop('index', axis=1)\n",
" if 'index' in df.columns or 'index' in cat_vals.columns:\n",
" df = pd.concat([df, cat_vals], axis=1).drop('index', axis=1)\n",
" # For the model to score correctly, all OHE columns must exist\n",
" input_cols = [\n",
" \"Education_9th\", \"Education_10th\", \"Education_11th\", \"Education_12th\", \"Education_Assoc_voc\", \"Education_Assoc_acdm\", \"Education_Masters\", \"Education_Prof_school\",\n",
Expand All @@ -579,9 +580,20 @@
" 'Relationship_Not_in_family', 'Relationship_Own_child', 'Relationship_Unmarried', 'Relationship_Wife', 'Relationship_Other_relative', 'WorkClass_Private',\n",
" 'Education_Bachelors'\n",
" ]\n",
" # OHE columns must be removed after data combination\n",
" predictor_columns = ['Age', 'HoursPerWeek', 'WorkClass_Private', 'WorkClass_Self', 'WorkClass_Gov', \n",
" 'WorkClass_Other', 'Education_HS_grad', 'Education_Some_HS', 'Education_Assoc', 'Education_Some_college',\n",
" 'Education_Bachelors', 'Education_Adv_Degree', 'Education_No_HS', 'MartialStatus_Married_civ_spouse',\n",
" 'MartialStatus_Never_married', 'MartialStatus_Divorced', 'MartialStatus_Separated', 'MartialStatus_Widowed',\n",
" 'MartialStatus_Other', 'Relationship_Husband', 'Relationship_Not_in_family', 'Relationship_Own_child', 'Relationship_Unmarried',\n",
" 'Relationship_Wife', 'Relationship_Other_relative', 'Race_White', 'Race_Black', 'Race_Asian_Pac_Islander',\n",
" 'Race_Amer_Indian_Eskimo', 'Race_Other', 'Sex_Female']\n",
"\n",
" for col in input_cols:\n",
" if col not in df.columns:\n",
" df[col] = 0\n",
" \n",
"\n",
" df[\"Education_Some_HS\"] = df[\"Education_9th\"] | df[\"Education_10th\"] | df[\"Education_11th\"] | df[\"Education_12th\"]\n",
" df[\"Education_Assoc\"] = df[\"Education_Assoc_voc\"] | df[\"Education_Assoc_acdm\"]\n",
" df[\"Education_Adv_Degree\"] = df[\"Education_Masters\"] | df[\"Education_Prof_school\"] | df[\"Education_Doctorate\"]\n",
Expand All @@ -593,6 +605,8 @@
"\n",
" df[\"MartialStatus_Other\"] = df[\"MartialStatus_Married_spouse_absent\"] | df[\"MartialStatus_Married_AF_spouse\"]\n",
"\n",
" df = df[predictor_columns]\n",
"\n",
" return df"
]
},
Expand Down
39 changes: 38 additions & 1 deletion src/sasctl/pzmm/write_score_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _write_imports(
import codecs
binary_string = "<binary string>"
model = pickle.load(codecs.decode(binary_string.encode(), "base64"))
model = pickle.loads(codecs.decode(binary_string.encode(), "base64"))
"""

def _viya35_model_load(
Expand Down Expand Up @@ -562,6 +562,26 @@ def _viya35_model_load(
f'{model_id}/{model_file_name}")))'
)
else:
if pickle_type.lower() == "pickle":
self.score_code += (
f'model_path = Path("/models/resources/viya/{model_id}'
f'")\nwith open(model_path / "{model_file_name}", '
f"\"rb\") as pickle_model:\n{'':4}model = pd.read_pickle"
"(pickle_model)\n\n"
)
"""
model_path = Path("/models/resources/viya/<UUID>")
with open(model_path / "model.pickle", "rb") as pickle_model:
model = pd.read_pickle(pickle_model)
"""
return (
f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}"
f"\")\n{'':8}with open(model_path / \"{model_file_name}\", "
f"\"rb\") as pickle_model:\n{'':12}model = pd.read_pickle"
"(pickle_model)"
)

self.score_code += (
f'model_path = Path("/models/resources/viya/{model_id}'
f'")\nwith open(model_path / "{model_file_name}", '
Expand Down Expand Up @@ -658,6 +678,23 @@ def _viya4_model_load(
f"safe_mode=True)\n"
)
else:
if pickle_type.lower() == "pickle":
self.score_code += (
f"with open(Path(settings.pickle_path) / "
f'"{model_file_name}", "rb") as pickle_model:\n'
f"{'':4}model = pd.read_pickle(pickle_model)\n\n"
)
"""
with open(Path(settings.pickle_path) / "model.pickle", "rb") as pickle_model:
model = pd.read_pickle(pickle_model)
"""
return (
f"{'':8}with open(Path(settings.pickle_path) / "
f'"{model_file_name}", "rb") as pickle_model:\n'
f"{'':12}model = pd.read_pickle(pickle_model)\n\n"
)

self.score_code += (
f"with open(Path(settings.pickle_path) / "
f'"{model_file_name}", "rb") as pickle_model:\n'
Expand Down
18 changes: 14 additions & 4 deletions tests/unit/test_write_score_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ def test_viya35_model_load():
"""
sc = ScoreCode()
load_text = sc._viya35_model_load("1234", "normal")
assert "pickle.load(pickle_model)" in sc.score_code
assert "pickle.load(pickle_model)" in load_text
assert "pd.read_pickle(pickle_model)" in sc.score_code
assert "pd.read_pickle(pickle_model)" in load_text

sc = ScoreCode()
load_text = sc._viya35_model_load("1234", "normal", pickle_type="dill")
assert "dill.load(pickle_model)" in sc.score_code
assert "dill.load(pickle_model)" in load_text

sc = ScoreCode()
mojo_text = sc._viya35_model_load("2345", "mojo", mojo_model=True)
Expand All @@ -142,8 +147,13 @@ def test_viya4_model_load():
"""
sc = ScoreCode()
load_text = sc._viya4_model_load("normal")
assert "pickle.load(pickle_model)" in sc.score_code
assert "pickle.load(pickle_model)" in load_text
assert "pd.read_pickle(pickle_model)" in sc.score_code
assert "pd.read_pickle(pickle_model)" in load_text

sc = ScoreCode()
load_text = sc._viya35_model_load("1234", "normal", pickle_type="dill")
assert "dill.load(pickle_model)" in sc.score_code
assert "dill.load(pickle_model)" in load_text

sc = ScoreCode()
mojo_text = sc._viya4_model_load("mojo", mojo_model=True)
Expand Down

0 comments on commit b936354

Please sign in to comment.