Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CosmoPower P(k,z) #85

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions soliket/cosmopower.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from cobaya.theories.cosmo import BoltzmannBase
from cobaya.typing import InfoDict
from cobaya.log import LoggedError

"""
Simple CosmoPower theory wrapper for Cobaya.
Expand All @@ -19,9 +20,13 @@
class CosmoPower(BoltzmannBase):
soliket_data_path: str = "soliket/data/CosmoPower"
network_path: str = "CP_paper/CMB"
network_path_pk: str = "CP_paper/PK"
cmb_tt_nn_filename: str = "cmb_TT_NN"
cmb_te_pcaplusnn_filename: str = "cmb_TE_PCAplusNN"
cmb_ee_nn_filename: str = "cmb_EE_NN"
pk_lin_nn_filename: str = "PKLIN_NN"
k_max: int = 1.
z: float = np.linspace(0.0, 2, 128)

extra_args: InfoDict = {}

Expand Down Expand Up @@ -51,29 +56,46 @@ def initialize(self):
restore=True,
restore_filename=os.path.join(base_path, self.cmb_ee_nn_filename),
)
self.cp_pk_nn = cp.cosmopower_NN(
restore=True,
restore_filename=os.path.join(self.soliket_data_path, self.network_path_pk, self.pk_lin_nn_filename),
)

if "lmax" not in self.extra_args:
self.extra_args["lmax"] = None

self.log.info(f"Loaded CosmoPower from directory {self.network_path}")

def calculate(self, state, want_derived=True, **params):
cmb_params = {}
network_params = {}

for par in self.renames:
if par in params:
cmb_params[par] = [params[par]]
network_params[par] = [params[par]]
else:
for r in self.renames[par]:
if r in params:
cmb_params[par] = [params[r]]
network_params[par] = [params[r]]
break

state["tt"] = self.cp_tt_nn.ten_to_predictions_np(cmb_params)[0, :]
state["te"] = self.cp_te_nn.predictions_np(cmb_params)[0, :]
state["ee"] = self.cp_ee_nn.ten_to_predictions_np(cmb_params)[0, :]
state["tt"] = self.cp_tt_nn.ten_to_predictions_np(network_params)[0, :]
state["te"] = self.cp_te_nn.predictions_np(network_params)[0, :]
state["ee"] = self.cp_ee_nn.ten_to_predictions_np(network_params)[0, :]
state["ell"] = self.cp_tt_nn.modes

# if "Pk_grid" in self.requested() or "Pk_interpolator" in self.requested():

network_params_z = {}

for k in network_params.keys():
network_params_z[k] = np.repeat(network_params[k], len(self.z))
network_params_z['z'] = self.z

var_pair=("delta_tot", "delta_tot")
nonlinear=True

state[("Pk_grid", bool(nonlinear)) + tuple(sorted(var_pair))] = self.cp_pk_nn.modes, self.z, self.cp_pk_nn.predictions_np(network_params_z)

def get_Cl(self, ell_factor=False, units="FIRASmuK2"):
cls_old = self.current_state.copy()

Expand Down Expand Up @@ -102,5 +124,48 @@ def get_Cl(self, ell_factor=False, units="FIRASmuK2"):

return cls

# def must_provide(self, **requirements):

# super().must_provide(**requirements)

# for k, v in self._must_provide.items():
# if isinstance(k, tuple) and k[0] == "Pk_grid":
# v = deepcopy(v)
# self.add_P_k_max(v.pop("k_max"), units="1/Mpc")
# # NB: Actually, only the max z is used, and the actual sampling in z
# # for computing P(k,z) is controlled by `perturb_sampling_stepsize`
# # (default: 0.1). But let's leave it like this in case this changes
# # in the future.
# self.add_z_for_matter_power(v.pop("z"))
# if v["nonlinear"]:
# if "non_linear" not in self.extra_args:
# # this is redundant with initialisation, but just in case
# self.extra_args["non_linear"] = non_linear_default_code
# elif self.extra_args["non_linear"] == non_linear_null_value:
# raise LoggedError(
# self.log, ("Non-linear Pk requested, but `non_linear: "
# f"{non_linear_null_value}` imposed in "
# "`extra_args`"))
# pair = k[2:]
# if pair == ("delta_tot", "delta_tot"):
# self.collectors[k] = Collector(
# method="get_pk_and_k_and_z",
# kwargs=v,
# post=(lambda P, kk, z: (kk, z, np.array(P).T)))
# else:
# raise LoggedError(self.log, "NotImplemented in cosmopower: %r", pair)

# return needs


# def get_Pk_grid(self, params, var_pair=("delta_tot", "delta_tot"), nonlinear=False,
# extrap_kmin=None, extrap_kmax=None):

# return self.current_state['k'], self.current_state['pk_z'], self.current_state['pk']


def get_can_support_params(self):
return ["omega_b", "omega_cdm", "h", "logA", "ns", "tau_reio"]

def get_can_provide(self):
return ['Cl', 'Pk_grid', 'Pk_interpolator']
56 changes: 56 additions & 0 deletions soliket/tests/test_cosmopower.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@
"H0": {"value": "lambda h: h * 100.0"},
}

des_params = {
'DES_DzL1': 0.001,
'DES_DzL2': 0.002,
'DES_DzL3': 0.001,
'DES_DzL4': 0.003,
'DES_DzL5': 0,
'DES_b1': 1.45,
'DES_b2': 1.55,
'DES_b3': 1.65,
'DES_b4': 1.8,
'DES_b5': 2.0,
'DES_DzS1': -0.001,
'DES_DzS2': -0.019,
'DES_DzS3': 0.009,
'DES_DzS4': -0.018,
'DES_m1': 0.012,
'DES_m2': 0.012,
'DES_m3': 0.012,
'DES_m4': 0.012,
'DES_AIA': 1,
'DES_alphaIA': 1,
'DES_z0IA': 0.62,
}

info_dict = {
"params": fiducial_params,
"likelihood": {
Expand All @@ -39,6 +63,27 @@
},
}

def lss_part_likelihood(_self):
results = _self.provider.get_Pk_interpolator().P(0.1, 1.0)
# results = _self.provider.get_Pk_grid()
return 1


info_dict_pk = {
"params": fiducial_params,
"likelihood": {'lss': {'external': lss_part_likelihood, 'requires': {'Pk_interpolator': {'z' : np.linspace(0.0, 2, 128),
'k_max': 1.}}}
},
"theory": {
"soliket.CosmoPower": {
"soliket_data_path": "soliket/data/CosmoPower",
"stop_at_error": True,
"provides": 'Pk_grid',
}
# 'camb': {"stop_at_error": True,}
}
}


@pytest.mark.skipif(not HAS_COSMOPOWER, reason='test requires cosmopower')
def test_cosmopower_theory():
Expand Down Expand Up @@ -77,3 +122,14 @@ def test_cosmopower_against_camb():

assert np.allclose(cp_cls['tt'][nanmask], camb_cls['tt'][nanmask], rtol=1.e-2)
assert np.isclose(logL_camb, logL_cp, rtol=1.e-1)


@pytest.mark.skipif(not HAS_COSMOPOWER, reason='test requires cosmopower')
def test_cosmopower_pkgrid():

model_cp = get_model(info_dict_pk)

logL_cp = float(model_cp.loglikes({})[0])

assert np.isfinite(logL_cp)