45 parallelize point est hist and naive stack #64

Merged · 3 commits · Oct 30, 2023
12 changes: 12 additions & 0 deletions src/rail/core/data.py
@@ -35,6 +35,7 @@
self.fileObj = None
self.groups = None
self.partial = False
self.length = None

def open(self, **kwargs):
"""Open and return the associated file
@@ -309,6 +310,17 @@
if not isinstance(data, qp.Ensemble):
raise TypeError(f"Expected `data` to be a `qp.Ensemble`, but {type(data)} was provided. Perhaps you meant to use `TableHandle`?")

def _size(self, path, **kwargs):
"""Return the number of PDFs without reading the full file"""
if path == 'None':
# Data live in memory rather than in a file: ask the ensemble directly
return self.data.npdf

[Codecov: added line src/rail/core/data.py#L316 was not covered by tests]
return tables_io.io.getInputDataLengthHdf5(path, groupname='data')

@classmethod
def _iterator(cls, path, **kwargs):
"""Iterate over the data"""
# qp.iterator does not take a groupname argument, so discard it if present
kwargs.pop('groupname', None)
return qp.iterator(path, **kwargs)
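Together these two hooks let a QPHandle report its total length cheaply and stream an on-disk ensemble chunk by chunk. A minimal usage sketch, assuming qp.iterator accepts the chunk_size keyword that input_iterator passes through and yields (start, end, ensemble) triples, as the stage code below implies; the tag and file path are hypothetical:

from rail.core.data import QPHandle

handle = QPHandle("pdfs", path="estimates.hdf5")  # hypothetical ensemble file
print(handle.size())  # total number of PDFs, read from HDF5 metadata only
for start, end, chunk in handle.iterator(chunk_size=10_000):
    print(f"rows {start}:{end} -> {chunk.npdf} PDFs")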

def default_model_read(modelfile):
"""Default function to read model files, simply used pickle.load"""
21 changes: 13 additions & 8 deletions src/rail/core/stage.py
@@ -326,10 +326,17 @@ def input_iterator(self, tag, **kwargs):
These will be passed to the Handle's iterator method
"""
handle = self.get_handle(tag, allow_missing=True)
if self.config.hdf5_groupname and handle.path:
self._input_length = handle.size(groupname=self.config.hdf5_groupname)

try:
self.config.hdf5_groupname
except Exception:
# Not every stage config defines hdf5_groupname; default to no group
self.config.hdf5_groupname = None
self._input_length = handle.size(groupname=self.config.hdf5_groupname)

if handle.path and handle.path != 'None':
total_chunks_needed = ceil(self._input_length/self.config.chunk_size)
# If the number of processes is larger than we need, we remove some of them
if total_chunks_needed < self.size: #pragma: no cover
color = self.rank+1 <= total_chunks_needed
newcomm = self.comm.Split(color=color,key=self.rank)
if color:
Expand All @@ -342,17 +349,15 @@ def input_iterator(self, tag, **kwargs):
parallel_size=self.size)
kwcopy.update(**kwargs)
return handle.iterator(**kwcopy)
# If the data are in memory and not in a file, they are small enough
# to process in a single chunk.
else: #pragma: no cover
if self.config.hdf5_groupname:
test_data = self.get_data('input')[self.config.hdf5_groupname]
else:
test_data = self.get_data('input')
s = 0
e = len(list(test_data.items())[0][1])
self._input_length = e
iterator=[[s, e, test_data]]
iterator=[[s, self._input_length, test_data]]
return iterator
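For orientation, a stage's run() consumes this method as a stream of (start, end, data) triples; a minimal consumption sketch, where process_chunk stands in for a hypothetical per-chunk worker:

# inside a RailStage subclass with a connected 'input' tag
for start, end, chunk in self.input_iterator('input'):
    # under MPI, each rank receives only its own interleaved share of chunks
    process_chunk(start, end, chunk)  # hypothetical worker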

def connect_input(self, other, inputTag=None, outputTag=None):
62 changes: 51 additions & 11 deletions src/rail/estimation/algos/naive_stack.py
@@ -44,18 +44,58 @@
self.zgrid = None

def run(self):
rng = np.random.default_rng(seed=self.config.seed)
test_data = self.get_data('input')
iterator = self.input_iterator('input')
self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1)
pdf_vals = test_data.pdf(self.zgrid)
yvals = np.expand_dims(np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.), axis=0), 0)
qp_d = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=yvals))
# Initializing the stacking PDFs
yvals = np.zeros((1, len(self.zgrid)))
bvals = np.zeros((self.config.nsamples, len(self.zgrid)))
bootstrap_matrix = self.broadcast_bootstrap_matrix()

first = True
for s, e, test_data in iterator:
print(f"Process {self.rank} running estimator on chunk {s} - {e}")
self._process_chunk(s, e, test_data, first, bootstrap_matrix, yvals, bvals)
first = False
if self.comm is not None: # pragma: no cover
bvals, yvals = self.join_histograms(bvals, yvals)

if self.rank == 0:
sample_ens = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=bvals))
qp_d = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=yvals))
self.add_data('output', sample_ens)
self.add_data('single_NZ', qp_d)

bvals = np.empty((self.config.nsamples, len(self.zgrid)))

def _process_chunk(self, start, end, data, first, bootstrap_matrix, yvals, bvals):
pdf_vals = data.pdf(self.zgrid)
yvals += np.expand_dims(np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.), axis=0), 0)
# qp_d is the normalized probability of the stack; we need to know how many galaxies were stacked
for i in range(self.config.nsamples):
bootstrap_draws = rng.integers(low=0, high=test_data.npdf, size=test_data.npdf)
bvals[i] = np.sum(pdf_vals[bootstrap_draws], axis=0)
sample_ens = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=bvals))
bootstrap_draws = bootstrap_matrix[:, i]
# Not all of the bootstrap_draws fall inside this chunk, and the chunk's
# global indices begin at "start", so mask the draws and shift them to local indices
mask = (bootstrap_draws >= start) & (bootstrap_draws < end)
bootstrap_draws = bootstrap_draws[mask] - start
bvals[i] += np.sum(pdf_vals[bootstrap_draws], axis=0)
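A toy numpy illustration of that mask-and-shift step, with made-up values:

import numpy as np

start, end = 100, 200                 # this chunk holds global rows [100, 200)
draws = np.array([5, 120, 150, 340])  # global bootstrap indices
mask = (draws >= start) & (draws < end)
local = draws[mask] - start           # array([20, 50]): chunk-local indices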

def broadcast_bootstrap_matrix(self):
rng = np.random.default_rng(seed=self.config.seed)
# Only one of the nodes needs to produce the bootstrap indices
ngal = self._input_length
print(f"Rank {self.rank}: stacking over {ngal} galaxies")
if self.rank == 0:
bootstrap_matrix = rng.integers(low=0, high=ngal, size=(ngal, self.config.nsamples))
else: # pragma: no cover
bootstrap_matrix = None
if self.comm is not None: # pragma: no cover
self.comm.Barrier()
bootstrap_matrix = self.comm.bcast(bootstrap_matrix, root=0)
return bootstrap_matrix
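The same draw-on-root-then-broadcast pattern as a standalone mpi4py sketch (assumes an MPI launch, e.g. mpiexec -n 4 python demo.py; the array sizes are made up):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rng = np.random.default_rng(seed=87)
# only rank 0 draws, so every rank ends up with identical bootstrap indices
matrix = rng.integers(0, 10, size=(10, 3)) if comm.rank == 0 else None
matrix = comm.bcast(matrix, root=0)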

def join_histograms(self, bvals, yvals):
# comm.reduce defaults to op=SUM and root=0, summing the per-rank
# partial histograms elementwise onto rank 0 (other ranks receive None)
bvals_r = self.comm.reduce(bvals)
yvals_r = self.comm.reduce(yvals)
return bvals_r, yvals_r
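A self-contained mpi4py sketch of that reduction, with a stand-in array:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
partial = np.ones(3)          # stand-in for this rank's partial histogram
total = comm.reduce(partial)  # default op=MPI.SUM, root=0
if comm.rank == 0:
    print(total)              # elementwise sum over all ranks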

[Codecov: added lines src/rail/estimation/algos/naive_stack.py#L95-L97 were not covered by tests]




Binary file not shown.
2 changes: 1 addition & 1 deletion tests/core/test_core.py
@@ -107,7 +107,7 @@ def test_pq_handle():


def test_qp_handle():
datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.fits")
datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.hdf5")
handle = do_data_handle(datapath, QPHandle)
qpfile = handle.open()
assert qpfile
2 changes: 1 addition & 1 deletion tests/estimation/test_classifier.py
@@ -12,7 +12,7 @@
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

inputdata = os.path.join(RAILDIR, 'rail/examples_data/testdata/output_BPZ_lite.fits')
inputdata = os.path.join(RAILDIR, 'rail/examples_data/testdata/output_BPZ_lite.hdf5')

@pytest.mark.parametrize(
"input_param",
2 changes: 1 addition & 1 deletion tests/estimation/test_summarizers.py
@@ -5,7 +5,7 @@
from rail.core.utils import RAILDIR
from rail.estimation.algos import naive_stack, point_est_hist, var_inf

testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.fits")
testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5")
DS = RailStage.data_store

