Commit

format
malihass committed Nov 30, 2023
1 parent 1577b54 commit 837f9b0
Showing 3 changed files with 31 additions and 27 deletions.
1 change: 1 addition & 0 deletions main.py
@@ -1,4 +1,5 @@
import argparse
+
from phaseSpaceSampling.wrapper import downsample_dataset_from_input

parser = argparse.ArgumentParser(description="Downsampler")
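For orientation, the hunk above only inserts a blank line between the standard-library and package imports. A minimal sketch of how such an entry point is typically wired up is shown below; the -i/--input flag, its requirement, and the final call site are assumptions for illustration, not taken from this diff.

import argparse

from phaseSpaceSampling.wrapper import downsample_dataset_from_input

parser = argparse.ArgumentParser(description="Downsampler")
# Hypothetical argument definition; the real flag names live outside this hunk.
parser.add_argument(
    "-i",
    "--input",
    type=str,
    required=True,
    help="path to the input file (assumed flag, not from this diff)",
)
args = parser.parse_args()

# Hand the input file to the library wrapper, which drives the run.
downsample_dataset_from_input(args.input)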
8 changes: 6 additions & 2 deletions phaseSpaceSampling/utils/fileFinder.py
@@ -7,7 +7,9 @@

def find_input(inpt_file):
    if not os.path.isfile(inpt_file):
-        new_inpt_file = os.path.join(PSS_INPUT_DIR, os.path.split(inpt_file)[-1])
+        new_inpt_file = os.path.join(
+            PSS_INPUT_DIR, os.path.split(inpt_file)[-1]
+        )
        par.printRoot(
            f"WARNING: {inpt_file} not found trying {new_inpt_file} ..."
        )
@@ -25,7 +27,9 @@ def find_input(inpt_file):

def find_data(data_file):
    if not os.path.isfile(data_file):
-        new_data_file = os.path.join(PSS_DATA_DIR, os.path.split(data_file)[-1])
+        new_data_file = os.path.join(
+            PSS_DATA_DIR, os.path.split(data_file)[-1]
+        )
        par.printRoot(
            f"WARNING: {data_file} not found trying {new_data_file} ..."
        )
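Both helpers follow the same fallback-lookup pattern: use the given path if it points to an existing file, otherwise warn and retry with the file's basename inside a package default directory (PSS_INPUT_DIR or PSS_DATA_DIR). A generic sketch of that pattern follows; the default directory, the plain print, and the final FileNotFoundError are stand-ins for illustration, and the real helpers' behaviour after the warning lies outside this hunk.

import os

DEFAULT_DIR = "/path/to/package/defaults"  # assumed stand-in for PSS_INPUT_DIR / PSS_DATA_DIR


def find_file(path, default_dir=DEFAULT_DIR):
    # Use the path as given when it already exists.
    if os.path.isfile(path):
        return path
    # Otherwise retry with the basename inside the default directory.
    candidate = os.path.join(default_dir, os.path.split(path)[-1])
    print(f"WARNING: {path} not found trying {candidate} ...")
    if os.path.isfile(candidate):
        return candidate
    raise FileNotFoundError(f"neither {path} nor {candidate} exists")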
49 changes: 24 additions & 25 deletions phaseSpaceSampling/wrapper.py
@@ -15,7 +15,6 @@


def downsample_dataset_from_input(inpt_file):
-
    inpt_file = find_input(inpt_file)
    inpt = parse_input_file(inpt_file)
    use_normalizing_flow = inpt["pdf_method"].lower() == "normalizingflow"
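Values parsed from the input file are plain strings, so method selection is a case-insensitive string comparison. A tiny illustration, where the dictionary literal is hypothetical and stands in for the object returned by parse_input_file:

inpt = {"pdf_method": "NormalizingFlow"}  # hypothetical parsed value
use_normalizing_flow = inpt["pdf_method"].lower() == "normalizingflow"
assert use_normalizing_flow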
@@ -24,7 +23,7 @@ def downsample_dataset_from_input(inpt_file):
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ~~~~ Parameters to save
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # List of sample size
    nSamples = [int(float(n)) for n in inpt["nSamples"].split()]
    # Data size used to adjust the sampling probability
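The nSamples entry is read as a whitespace-separated list of sample counts; going through float() before int() lets the counts be written in scientific notation. For example, with an illustrative input string not taken from the repository:

inpt = {"nSamples": "1e3 1e4 1e5"}  # illustrative value
nSamples = [int(float(n)) for n in inpt["nSamples"].split()]
print(nSamples)  # [1000, 10000, 100000]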
@@ -46,12 +45,11 @@
        nSampleCriterionLimit = int(inpt["nSampleCriterionLimit"])
    except:
        nSampleCriterionLimit = int(1e5)



    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ~~~~ Environment
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
    use_gpu = (
        (inpt["use_gpu"] == "True")
@@ -69,16 +67,15 @@
    # REPRODUCIBILITY
    torch.manual_seed(int(inpt["seed"]) + par.irank)
    np.random.seed(int(inpt["seed"]) + par.irank)



    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ~~~~ Prepare Data and scatter across processors
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    data_to_downsample_, dataInd_, working_data, nFullData = prepareData(inpt)

    dim = data_to_downsample_.shape[1]

    # Compute uniform sampling criterion of random data
    randomCriterion = np.zeros(len(nSamples))
    if par.irank == par.iroot and computeCriterion:
@@ -93,39 +90,41 @@
            par.printRoot(
                f"\t nSample {nSample} mean dist = {mean:.4f}, std dist = {std:.4f}"
            )

    # Prepare arrays used for sanity checks
    meanCriterion = np.zeros((int(inpt["num_pdf_iter"]), len(nSamples)))
    stdCriterion = np.zeros((int(inpt["num_pdf_iter"]), len(nSamples)))
    flow_nll_loss = np.zeros(int(inpt["num_pdf_iter"]))

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ~~~~ Downsample
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    data_for_pdf_est = working_data

    for pdf_iter in range(int(inpt["num_pdf_iter"])):
        if use_normalizing_flow:
            # Create the normalizing flow
            flow = sampler.createFlow(dim, pdf_iter, inpt)
            # flow = flow.to(device)
            n_params = get_num_parameters(flow)
            par.printRoot(
-                "There are {} trainable parameters in this model.".format(n_params)
+                "There are {} trainable parameters in this model.".format(
+                    n_params
+                )
            )

            # Train (happens on 1 proc)
            flow_nll_loss[pdf_iter] = sampler.trainFlow(
                data_for_pdf_est, flow, pdf_iter, inpt
            )
            sampler.checkLoss(pdf_iter, flow_nll_loss)

            # Evaluate probability: This is the expensive step (happens on multi processors)
            log_density_np_ = sampler.evalLogProbNF(
                flow, data_to_downsample_, pdf_iter, inpt
            )

        if use_bins:
            bin_pdfH, bin_pdfEdges = sampler.trainBinPDF(
                data_for_pdf_est, pdf_iter, inpt
@@ -134,14 +133,14 @@ def downsample_dataset_from_input(inpt_file):
            log_density_np_ = sampler.evalLogProbBIN(
                data_to_downsample_, pdf_iter, inpt
            )

        if use_serial_adjustment:
            log_density_np_for_adjust = par.gatherNelementsInArray(
                log_density_np_, nWorkingDataAdjustment
            )
        else:
            log_density_np_for_adjust = None

        # Correct probability estimate
        if pdf_iter > 0:
            log_density_np_ = log_density_np_ - log_samplingProb_
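The hunk above ends with the probability correction: for every iteration after the first, the previous iteration's log sampling probability is subtracted from the fresh log-density estimate before it is reused. As background, here is a schematic and deliberately simplified sketch of density-based downsampling in general, not the library's downSample implementation: points are kept with probability roughly inversely proportional to the estimated density so that sparse regions of phase space survive.

import numpy as np

# Schematic illustration only; array sizes and the density values are made up.
rng = np.random.default_rng(0)
log_density = rng.normal(size=1000)  # stand-in for log_density_np_
weights = np.exp(-log_density)       # ~ 1 / pdf
prob = weights / weights.sum()
kept_indices = rng.choice(log_density.size, size=100, replace=False, p=prob)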
@@ -151,9 +150,9 @@
                )
            else:
                log_density_np_for_adjust = None

        par.printRoot(f"TRAIN ITER {pdf_iter}")

        for inSample, nSample in enumerate(nSamples):
            # Downsample
            (
@@ -170,7 +169,7 @@
                nFullData,
                inpt,
            )

            # Plot
            # cornerPlotScatter(downSampledData,title='downSampled npts='+str(nSample)+', iter='+str(pdf_iter))
            # Get criterion
@@ -184,7 +183,7 @@
                par.printRoot(
                    f"\t nSample {nSample} mean dist = {mean:.4f}, std dist = {std:.4f}"
                )

            if pdf_iter == int(inpt["num_pdf_iter"]) - 1:
                # Last pdf iter : Root proc saves downsampled data, and checks the outcome
                sampler.checkProcedure(
@@ -198,7 +197,7 @@
                    data=downSampledData,
                    indices=downSampledIndices,
                )

            if not (pdf_iter == int(inpt["num_pdf_iter"]) - 1):
                # Prepare data for the next training iteration
                (
