From 3129398cde4077b7af7bb7d79e536471a51d4dfa Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Wed, 11 Sep 2024 19:06:01 +0000 Subject: [PATCH] linting --- deckard/layers/afr.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/deckard/layers/afr.py b/deckard/layers/afr.py index 6984c62f..a3e29623 100644 --- a/deckard/layers/afr.py +++ b/deckard/layers/afr.py @@ -687,8 +687,17 @@ def clean_data_for_aft( subset = subset.drop(columns=list(dummy_dict.keys())) cleaned = pd.concat([subset, dummies], axis=1) else: - cleaned = subset.astype(float) - cleaned = subset.dropna(axis=0, how="any") + # Assume that some categorical variables exist and need to be one-hot encoded + cleaned = subset.copy() + dummy_cols = [] + for col in cleaned.columns: + if cleaned[col].dtype == "object": + dummy_cols.append(col) + dummies = pd.get_dummies(cleaned[dummy_cols], prefix="", prefix_sep="") + cleaned = cleaned.drop(columns=dummy_cols) + cleaned = pd.concat([cleaned, dummies], axis=1) + cleaned = cleaned.astype(float) + cleaned = cleaned.dropna(axis=0, how="any") assert ( target in cleaned.columns ), f"Target {target} not in dataframe with columns {cleaned.columns}"