Skip to content

Commit

Permalink
Merge pull request #7 from LSSTDESC/yanza/custom_useful_clusters
Browse files Browse the repository at this point in the history
Yanza/custom useful clusters
  • Loading branch information
yanzastro authored Nov 18, 2024
2 parents e428692 + 6fff3f2 commit b455da8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
57 changes: 38 additions & 19 deletions src/rail/estimation/algos/somoclu_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ class SOMocluInformer(CatInformer):
gridtype=Param(str, 'rectangular', msg="Optional parameter to specify the grid form of the nodes:"
+ "* 'rectangular': rectangular neurons (default)"
+ "* 'hexagonal': hexagonal neurons"),
n_epochs=Param(int, 10, msg="number of training epochs."),
initialization=Param(str, 'pca', msg="method of initializing the SOM:"
+"* 'pca': principal componant analysis (default)"
+"* 'random' randomly initialize the SOM"),
maptype=Param(str, 'planar', msg="Optional parameter to specify the map topology:"
+ "* 'planar': Planar map (default)"
+ "* 'toroid': Toroid map"),
Expand Down Expand Up @@ -220,17 +224,17 @@ def run(self):
if np.isnan(self.config.nondetect_val): # pragma: no cover
mask = np.isnan(training_data[col])
else:
mask = np.isclose(training_data[col], self.config.nondetect_val)
mask = np.logical_or(np.isinf(training_data[col]), np.isclose(training_data[col], self.config.nondetect_val))
training_data[col][mask] = self.config.mag_limits[col]

colors = _computemagcolordata(training_data, self.config.ref_band,
self.config.bands, self.config.column_usage)

som = Somoclu(self.config.n_columns, self.config.n_rows,
gridtype=self.config.gridtype, compactsupport=False,
maptype=self.config.maptype, initialization='pca')
maptype=self.config.maptype, initialization=self.config.initialization)

som.train(colors)
som.train(colors, epochs=self.config.n_epochs,)

modeldict = dict(som=som, usecols=self.config.bands,
ref_column=self.config.ref_band,
Expand Down Expand Up @@ -323,7 +327,9 @@ class SOMocluSummarizer(SZPZSummarizer):
phot_weightcol=Param(str, "", msg="name of photometry weight, if present"),
spec_weightcol=Param(str, "", msg="name of specz weight col, if present"),
split=Param(int, 200, msg="the size of data chunks when calculating the distances between the codebook and data"),
nsamples=Param(int, 20, msg="number of bootstrap samples to generate"))
nsamples=Param(int, 20, msg="number of bootstrap samples to generate"),
useful_clusters=Param(np.ndarray, np.array([]), msg="the cluster indices that are used for calibration. If not given, then "
+"all the clusters containing spec sample are used."),)
outputs = [('output', QPHandle),
('single_NZ', QPHandle),
('cellid_output', Hdf5Handle),
Expand Down Expand Up @@ -400,7 +406,6 @@ def run(self):

self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1)


if self.config.n_clusters > self.n_rows * self.n_columns: # pragma: no cover
print("Warning: number of clusters cannot be greater than the number of cells ("+str(self.n_rows * self.n_columns)+"). The SOM will NOT be grouped into clusters.")
n_clusters = self.n_rows * self.n_columns
Expand Down Expand Up @@ -445,7 +450,7 @@ def run(self):
N_eff_num = 0.
N_eff_den = 0.
phot_cluster_set = set()

bad_clusters = set()
# make dictionary of ID data to be written out with cell IDs
id_dict = {}

Expand All @@ -455,7 +460,7 @@ def run(self):
print(f"Process {self.rank} running summarizer on chunk {s} - {e}")

chunk_number = s//self.config.chunk_size
tmp_neff_num, tmp_neff_den = self._process_chunk(test_data, bootstrap_matrix, som_cluster_inds, spec_cluster_set, phot_cluster_set, sz, spec_data['weight'], spec_som_clusterind, N_eff_p_num, N_eff_p_den, hist_vals, id_dict, s, e, first)
tmp_neff_num, tmp_neff_den = self._process_chunk(test_data, bootstrap_matrix, som_cluster_inds, spec_cluster_set, phot_cluster_set, sz, spec_data['weight'], spec_som_clusterind, N_eff_p_num, N_eff_p_den, hist_vals, id_dict, s, e, first, bad_clusters)
N_eff_num += tmp_neff_num
N_eff_den += tmp_neff_den
first = False
Expand All @@ -474,7 +479,8 @@ def run(self):
hist_vals = self.comm.reduce(hist_vals)
N_eff_num = self.comm.reduce(N_eff_num)
N_eff_den = self.comm.reduce(N_eff_den)

bad_clusters = self.comm.reduce(bad_clusters)

phot_cluster_list=np.array(list(phot_cluster_set),dtype=int)
phot_cluster_total=self.comm.gather(phot_cluster_list)

Expand All @@ -483,12 +489,8 @@ def run(self):
return
phot_cluster_total=np.concatenate(phot_cluster_total)
phot_cluster_set = set(phot_cluster_total)
uncovered_clusters = phot_cluster_set - spec_cluster_set
bad_cluster = dict(uncovered_clusters=np.array(list(uncovered_clusters)))
print("the following clusters contain photometric data but not spectroscopic data:")
print(uncovered_clusters)
useful_clusters = phot_cluster_set - uncovered_clusters
print(f"{len(useful_clusters)} out of {n_clusters} have usable data")

print(f"{len(self.useful_clusters)} out of {n_clusters} have usable data")

# effective number defined in Heymans et al. (2012) to quantify the photometric representation.
# also see Eq.7 in Wright et al. (2020).
Expand All @@ -507,9 +509,9 @@ def run(self):
qp_d = qp.Ensemble(qp.hist, data=dict(bins=self.zgrid, pdfs=fid_hist))
self.add_data('output', sample_ens)
self.add_data('single_NZ', qp_d)
self.add_data('uncovered_cluster_file', bad_cluster)
self.add_data('uncovered_cluster_file', bad_clusters)

def _process_chunk(self, test_data, bootstrap_matrix, som_cluster_inds, spec_cluster_set, phot_cluster_set, sz, sweight, spec_som_clusterind, N_eff_p_num, N_eff_p_den, hist_vals, id_dict, start, end, first):
def _process_chunk(self, test_data, bootstrap_matrix, som_cluster_inds, spec_cluster_set, phot_cluster_set, sz, sweight, spec_som_clusterind, N_eff_p_num, N_eff_p_den, hist_vals, id_dict, start, end, first, bad_clusters):

for col in self.usecols:
if col not in test_data.keys(): # pragma: no cover
Expand All @@ -534,13 +536,30 @@ def _process_chunk(self, test_data, bootstrap_matrix, som_cluster_inds, spec_clu
self._do_chunk_output(id_dict, start, end, first)

chunk_phot_cluster_set = set(phot_som_clusterind)
useful_clusters = chunk_phot_cluster_set.intersection(spec_cluster_set)
phot_cluster_set.update(chunk_phot_cluster_set)

uncovered_clusters = phot_cluster_set - spec_cluster_set
bad_cluster = dict(uncovered_clusters=np.array(list(uncovered_clusters)))
print("the following clusters contain photometric data but not spectroscopic data:")
print(uncovered_clusters)

covered_clusters = phot_cluster_set - uncovered_clusters
if self.config.useful_clusters.size == 0:
self.useful_clusters = covered_clusters
else: # pragma: no cover
if set(self.config.useful_clusters) <= covered_clusters:
self.useful_clusters = self.config.useful_clusters
else:
print("Warning: input useful clusters is not a subset of spec-covered clusters."
+"Taking the intersection.")
self.useful_clusters = np.intersect1d(self.config.useful_clusters, np.asarray(list(covered_clusters)))
if self.useful_clusters.size == 0: # pragma: no cover
raise ValueError("Input useful clusters have no intersection with spec-covered clusters!")

useful_clusters = self.useful_clusters

tmp_neff_num = np.sum(test_data['weight'])
tmp_neff_den = np.sum(test_data['weight'] ** 2)


for i in range(self.config.nsamples):
bootstrap_indices = bootstrap_matrix[:,i]
bs_specz = sz[bootstrap_indices]
Expand Down
32 changes: 26 additions & 6 deletions tests/som/test_somoclu_summarizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,20 @@ def one_algo(key, inform_class, summarizer_class, summary_kwargs):
fid_ens = qp.read(summarizer2.get_output(summarizer2.get_aliased_tag("single_NZ"), final_name=True))
meanz = fid_ens.mean().flatten()
assert np.isclose(meanz[0], 0.14414913252122552, atol=0.025)

full_useful_clusters = np.asarray(list(summarizer2.useful_clusters))
full_uncovered_clusters = np.asarray(list(np.setdiff1d(np.arange(31*31), full_useful_clusters)))

os.remove(summarizer2.get_output(summarizer2.get_aliased_tag("output"), final_name=True))
os.remove(f"tmpsomoclu_" + key + ".pkl")
return summary_ens
return summary_ens, full_useful_clusters, full_uncovered_clusters


def test_SomocluSOM():
summary_config_dict = {"n_rows": 21, "n_columns": 21, "column_usage": "colors"}
inform_class = somoclu_som.SOMocluInformer
summarizerclass = somoclu_som.SOMocluSummarizer
_ = one_algo("SOMomoclu", inform_class, summarizerclass, summary_config_dict)
_,_,_ = one_algo("SOMomoclu", inform_class, summarizerclass, summary_config_dict)


def test_SomocluSOM_with_mag_and_colors():
Expand Down Expand Up @@ -98,19 +102,35 @@ def test_SomocluSOM_with_columns():
}
inform_class = somoclu_som.SOMocluInformer
summarizerclass = somoclu_som.SOMocluSummarizer
_,_,_ = one_algo("SOMoclu_wmag", inform_class, summarizerclass, summary_config_dict)


_ = one_algo("SOMoclu_wmag", inform_class, summarizerclass, summary_config_dict)
def test_SomocluSOM_useful_clusters():
    """Drive ``one_algo`` with three different ``useful_clusters`` settings.

    The first run leaves ``useful_clusters`` unset on a 21x21 configuration
    and keeps the uncovered-cluster array that ``one_algo`` returns.  The two
    following runs use a 31x31 configuration and pass ``useful_clusters``
    explicitly: first every index in ``np.arange(31*31)``, then the uncovered
    clusters reported by the first run.

    Nothing is asserted on the returned values here; the test succeeds when
    none of the three runs raises.
    """
    informer_cls = somoclu_som.SOMocluInformer
    summarizer_cls = somoclu_som.SOMocluSummarizer

    def _make_config(side, **extra):
        # Square grid of ``side`` x ``side`` with a fixed seed; extra
        # options are appended after the common keys.
        return {
            "n_rows": side,
            "n_columns": side,
            "column_usage": "colors",
            "seed": 0,
            **extra,
        }

    # Run 1: no explicit useful_clusters; only the uncovered clusters are
    # reused below.
    _, _, uncovered_from_default = one_algo(
        "SOMomoclu1", informer_cls, summarizer_cls, _make_config(21)
    )

    # Run 2: request every cluster index of a 31x31 grid.
    one_algo(
        "SOMomoclu2",
        informer_cls,
        summarizer_cls,
        _make_config(31, useful_clusters=np.arange(31 * 31)),
    )

    # Run 3: request exactly the clusters the first run reported as
    # uncovered.
    one_algo(
        "SOMomoclu4",
        informer_cls,
        summarizer_cls,
        _make_config(31, useful_clusters=uncovered_from_default),
    )

def test_SomocluSOM_with_badinput():
def test_SomocluSOM_wrong_column():
summary_config_dict = {
"n_rows": 21,
"n_columns": 21,
"column_usage": "something",
"column_usage": "wrong_column",
"objid_name": "id",
}
inform_class = somoclu_som.SOMocluInformer
summarizerclass = somoclu_som.SOMocluSummarizer
try:
one_algo("SOMoclu_wrong", inform_class, summarizerclass, summary_config_dict)
_ = one_algo("SOMoclu_wrongcolumn", inform_class, summarizerclass, summary_config_dict)
except:
return

0 comments on commit b455da8

Please sign in to comment.