Skip to content

Commit

Permalink
add file finder functions
Browse files Browse the repository at this point in the history
  • Loading branch information
malihass committed Nov 29, 2023
1 parent c4303e9 commit 4839bdb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
15 changes: 2 additions & 13 deletions main_iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import phaseSpaceSampling.sampler as sampler
import phaseSpaceSampling.utils.parallel as par
from phaseSpaceSampling import DATA_DIR, INPUT_DIR
from phaseSpaceSampling.utils.dataUtils import prepareData
from phaseSpaceSampling.utils.fileFinder import find_input
from phaseSpaceSampling.utils.plotFun import *
from phaseSpaceSampling.utils.torchutils import get_num_parameters

Expand All @@ -33,18 +33,7 @@
# ~~~~ Parse input
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inpt_file = args.input
if not os.path.isfile(inpt_file):
new_inpt_file = os.path.join(INPUT_DIR, os.path.split(inpt_file)[-1])
par.printRoot(f"WARNING: {inpt_file} not found trying {new_inpt_file} ...")
if not os.path.isfile(new_inpt_file):
par.printRoot(
f"ERROR: could not open data {inpt_file} or {new_inpt_file}"
)
sys.exit()
else:
inpt_file = new_inpt_file

inpt_file = find_input(args.input)
inpt = parse_input_file(inpt_file)
use_normalizing_flow = inpt["pdf_method"].lower() == "normalizingflow"
use_bins = inpt["pdf_method"].lower() == "bins"
Expand Down
16 changes: 2 additions & 14 deletions phaseSpaceSampling/utils/dataUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

import phaseSpaceSampling.utils.parallel as par
from phaseSpaceSampling import DATA_DIR
from phaseSpaceSampling.utils.fileFinder import find_data

# from memory_profiler import profile

Expand Down Expand Up @@ -35,19 +35,7 @@ def checkData(shape, N, d, nWorkingData, nWorkingDataAdjustment, useNF):
# @profile
def prepareData(inpt):
# Set parameters from input
dataFile = inpt["dataFile"]
if not os.path.isfile(dataFile):
new_dataFile = os.path.join(DATA_DIR, os.path.split(dataFile)[-1])
par.printRoot(
f"WARNING: {dataFile} not found trying {new_dataFile} ..."
)
if not os.path.isfile(new_dataFile):
par.printRoot(
f"ERROR: could not open data {dataFile} or {new_dataFile}"
)
sys.exit()
else:
dataFile = new_dataFile
dataFile = find_data(inpt["dataFile"])
preShuffled = inpt["preShuffled"] == "True"
scalerFile = inpt["scalerFile"]
nWorkingDatas = [int(float(n)) for n in inpt["nWorkingData"].split()]
Expand Down
41 changes: 41 additions & 0 deletions phaseSpaceSampling/utils/fileFinder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

import phaseSpaceSampling.utils.parallel as par
from phaseSpaceSampling import DATA_DIR, INPUT_DIR


def find_input(inpt_file):
if not os.path.isfile(inpt_file):
new_inpt_file = os.path.join(INPUT_DIR, os.path.split(inpt_file)[-1])
par.printRoot(
f"WARNING: {inpt_file} not found trying {new_inpt_file} ..."
)
if not os.path.isfile(new_inpt_file):
par.printRoot(
f"ERROR: could not open data {inpt_file} or {new_inpt_file}"
)
sys.exit()
return None
else:
return new_inpt_file
else:
return inpt_file


def find_data(data_file):
if not os.path.isfile(data_file):
new_data_file = os.path.join(DATA_DIR, os.path.split(data_file)[-1])
par.printRoot(
f"WARNING: {data_file} not found trying {new_data_file} ..."
)
if not os.path.isfile(new_data_file):
par.printRoot(
f"ERROR: could not open data {data_file} or {new_data_file}"
)
sys.exit()
return None
else:
return new_data_file
else:
return data_file

0 comments on commit 4839bdb

Please sign in to comment.