From 8dc344bfcc0e7e40b503f97bfa52b0f1a75badf5 Mon Sep 17 00:00:00 2001 From: drv850 Date: Mon, 4 Nov 2024 18:28:08 +0100 Subject: [PATCH] update test --- GenNet_utils/Interpret.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/GenNet_utils/Interpret.py b/GenNet_utils/Interpret.py index 3f52018..1e414fa 100644 --- a/GenNet_utils/Interpret.py +++ b/GenNet_utils/Interpret.py @@ -80,10 +80,11 @@ def get_DeepExplainer_scores(args): xval = xval[0] xtest = xtest[0] - + yval = yval.flatten() + ytest = ytest.flatten() - xval = xval if args.regression else xval[yval==0] - xtest = xtest if args.regression else xtest[ytest==1] + xval = xval if args.regression else xval[yval==0,:] + xtest = xtest if args.regression else xtest[ytest==1,:] explainer = shap.DeepExplainer((model.input, model.output), ) print("Created explainer") @@ -168,23 +169,17 @@ def get_DFIM_scores(args): xval = xval[0] xtest = xtest[0] + yval = yval.flatten() + ytest = ytest.flatten() - print("xval shape", xval.shape) - print("yval shape", yval.shape) - - - print("xtest shape", xtest.shape) - print("ytestl shape", ytest.shape) - - if np.unique(np.array(ytest)).shape[0] > 2: args.regression = True else: args.regression = False - xval = xval if args.regression else xval[yval==0] - xtest = xtest if args.regression else xtest[ytest==1] + xval = xval if args.regression else xval[yval==0,:] + xtest = xtest if args.regression else xtest[ytest==1,:] explainer = shap.DeepExplainer((model.input, model.output), xval) print("Created explainer") @@ -246,8 +241,7 @@ def get_pathexplain_scores(args): yval = yval.flatten() ytest = ytest.flatten() - - print("Shapes",xval.shape, xtest.shape) + if np.unique(np.array(ytest)).shape[0] > 2: args.regression = True