diff --git a/deckard/layers/afr.py b/deckard/layers/afr.py index 9d798b9b..90d243c6 100644 --- a/deckard/layers/afr.py +++ b/deckard/layers/afr.py @@ -675,15 +675,17 @@ def clean_data_for_aft( subset = subset.drop(columns=list(dummy_dict.keys())) cleaned = pd.concat([subset, dummies], axis=1) else: - # Find non-numeric columns - non_numeric = subset.select_dtypes(exclude=[np.number]).columns - dummy_subset = subset[non_numeric] - dummies = pd.get_dummies( - dummy_subset, - columns=dummy_subset.columns, - ) - subset = subset.drop(columns=dummy_subset.columns) - cleaned = pd.concat([subset, dummies], axis=1) + # Assume that some categorical variables exist and need to be one-hot encoded + cleaned = subset.copy() + dummy_cols = [] + for col in cleaned.columns: + if cleaned[col].dtype == "object": + dummy_cols.append(col) + dummies = pd.get_dummies(cleaned[dummy_cols], prefix="", prefix_sep="") + cleaned = cleaned.drop(columns=dummy_cols) + cleaned = pd.concat([cleaned, dummies], axis=1) + cleaned = cleaned.astype(float) + cleaned = cleaned.dropna(axis=0, how="any") assert ( target in cleaned.columns ), f"Target {target} not in dataftame with columns {cleaned.columns}"