diff --git a/deckard/layers/afr.py b/deckard/layers/afr.py index 67e567e6..21c4a816 100644 --- a/deckard/layers/afr.py +++ b/deckard/layers/afr.py @@ -323,10 +323,7 @@ def split_data_for_aft( }, } - weibull_layers = plot_partial_effects( - aft=wft, - **weibull_partial_dict_layers, - ) + weibull_layers = plot_partial_effects(aft=wft, **weibull_partial_dict_layers) wft_scores = score_model(wft, X_train, X_test) cox_replacement_dict = { @@ -364,16 +361,9 @@ def split_data_for_aft( "mtype": "cox", "replacement_dict": cox_replacement_dict, } - cox_afr, cft = plot_aft( - df=X_train, - event_col=target, - **cox_plot_dict, - ) + cox_afr, cft = plot_aft(df=X_train, event_col=target, **cox_plot_dict) cox_scores = score_model(cft, X_train, X_test) - cox_partial = plot_partial_effects( - aft=cft, - **cox_partial_dict, - ) + cox_partial = plot_partial_effects(aft=cft, **cox_partial_dict) log_normal_dict = { "Intercept: sigma_": "$\sigma$", # noqa w605