From 4839bdbc7e59f846000e3e2ddfd5771255203353 Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 29 Nov 2023 14:45:09 -0700 Subject: [PATCH] add file finder functions --- main_iterative.py | 15 ++-------- phaseSpaceSampling/utils/dataUtils.py | 16 ++-------- phaseSpaceSampling/utils/fileFinder.py | 41 ++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 27 deletions(-) create mode 100644 phaseSpaceSampling/utils/fileFinder.py diff --git a/main_iterative.py b/main_iterative.py index d5aa59b..8700b8b 100644 --- a/main_iterative.py +++ b/main_iterative.py @@ -9,8 +9,8 @@ import phaseSpaceSampling.sampler as sampler import phaseSpaceSampling.utils.parallel as par -from phaseSpaceSampling import DATA_DIR, INPUT_DIR from phaseSpaceSampling.utils.dataUtils import prepareData +from phaseSpaceSampling.utils.fileFinder import find_input from phaseSpaceSampling.utils.plotFun import * from phaseSpaceSampling.utils.torchutils import get_num_parameters @@ -33,18 +33,7 @@ # ~~~~ Parse input # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -inpt_file = args.input -if not os.path.isfile(inpt_file): - new_inpt_file = os.path.join(INPUT_DIR, os.path.split(inpt_file)[-1]) - par.printRoot(f"WARNING: {inpt_file} not found trying {new_inpt_file} ...") - if not os.path.isfile(new_inpt_file): - par.printRoot( - f"ERROR: could not open data {inpt_file} or {new_inpt_file}" - ) - sys.exit() - else: - inpt_file = new_inpt_file - +inpt_file = find_input(args.input) inpt = parse_input_file(inpt_file) use_normalizing_flow = inpt["pdf_method"].lower() == "normalizingflow" use_bins = inpt["pdf_method"].lower() == "bins" diff --git a/phaseSpaceSampling/utils/dataUtils.py b/phaseSpaceSampling/utils/dataUtils.py index be1596c..7b78f00 100644 --- a/phaseSpaceSampling/utils/dataUtils.py +++ b/phaseSpaceSampling/utils/dataUtils.py @@ -4,7 +4,7 @@ import numpy as np import phaseSpaceSampling.utils.parallel as par -from phaseSpaceSampling import DATA_DIR +from phaseSpaceSampling.utils.fileFinder import find_data # from memory_profiler import profile @@ -35,19 +35,7 @@ def checkData(shape, N, d, nWorkingData, nWorkingDataAdjustment, useNF): # @profile def prepareData(inpt): # Set parameters from input - dataFile = inpt["dataFile"] - if not os.path.isfile(dataFile): - new_dataFile = os.path.join(DATA_DIR, os.path.split(dataFile)[-1]) - par.printRoot( - f"WARNING: {dataFile} not found trying {new_dataFile} ..." - ) - if not os.path.isfile(new_dataFile): - par.printRoot( - f"ERROR: could not open data {dataFile} or {new_dataFile}" - ) - sys.exit() - else: - dataFile = new_dataFile + dataFile = find_data(inpt["dataFile"]) preShuffled = inpt["preShuffled"] == "True" scalerFile = inpt["scalerFile"] nWorkingDatas = [int(float(n)) for n in inpt["nWorkingData"].split()] diff --git a/phaseSpaceSampling/utils/fileFinder.py b/phaseSpaceSampling/utils/fileFinder.py new file mode 100644 index 0000000..26b1bb1 --- /dev/null +++ b/phaseSpaceSampling/utils/fileFinder.py @@ -0,0 +1,41 @@ +import os +import sys + +import phaseSpaceSampling.utils.parallel as par +from phaseSpaceSampling import DATA_DIR, INPUT_DIR + + +def find_input(inpt_file): + if not os.path.isfile(inpt_file): + new_inpt_file = os.path.join(INPUT_DIR, os.path.split(inpt_file)[-1]) + par.printRoot( + f"WARNING: {inpt_file} not found trying {new_inpt_file} ..." + ) + if not os.path.isfile(new_inpt_file): + par.printRoot( + f"ERROR: could not open data {inpt_file} or {new_inpt_file}" + ) + sys.exit() + return None + else: + return new_inpt_file + else: + return inpt_file + + +def find_data(data_file): + if not os.path.isfile(data_file): + new_data_file = os.path.join(DATA_DIR, os.path.split(data_file)[-1]) + par.printRoot( + f"WARNING: {data_file} not found trying {new_data_file} ..." + ) + if not os.path.isfile(new_data_file): + par.printRoot( + f"ERROR: could not open data {data_file} or {new_data_file}" + ) + sys.exit() + return None + else: + return new_data_file + else: + return data_file